/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.util.MutablePair;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.PolyNull;
import org.joda.time.Instant;

@Internal
class Aggregators {
    /** Holder for static {@link Aggregator} factory methods; not meant to be instantiated. */
    private Aggregators() {
    }

    /**
     * Creates a window-unaware {@link Aggregator} backed by a Beam {@link Combine.CombineFn}.
     *
     * <p>The aggregation buffer is the combine function's accumulator itself.
     *
     * @param fn combine function providing accumulator creation, input addition, merging and
     *     output extraction
     * @param valueFn extracts the value to combine from each input element
     * @param accEnc encoder of the accumulator, used as the buffer encoder
     * @param outEnc encoder of the combined result, used as the output encoder
     * @return an aggregator whose {@code finish} step yields the combine function's output
     */
    static <ValT, AccT, ResT, InT> Aggregator<InT, ?, ResT> value(Combine.CombineFn<ValT, AccT, ResT> fn, ScalaInterop.Fun1<InT, ValT> valueFn, Encoder<AccT> accEnc, Encoder<ResT> outEnc) {
        final ValueAggregator<ValT, AccT, ResT, InT> aggregator =
            new ValueAggregator<>(fn, valueFn, accEnc, outEnc);
        return aggregator;
    }

    /**
     * Creates a window-aware {@link Aggregator} backed by a Beam {@link Combine.CombineFn}.
     *
     * <p>The implementation is chosen from the windowing strategy:
     *
     * <ul>
     *   <li>non-merging windows: {@link NonMergingWindowedAggregator};
     *   <li>{@link Sessions}: {@link SessionsAggregator}, which keeps a sorted buffer;
     *   <li>any other merging {@link WindowFn}: {@link MergingWindowedAggregator}, which
     *       delegates merging to {@link WindowFn#mergeWindows}.
     * </ul>
     *
     * @param fn combine function used per window
     * @param valueFn extracts the value to combine from each windowed input
     * @param windowing windowing strategy (merge behavior and timestamp combiner)
     * @param windowEnc encoder of the windows used as buffer keys
     * @param accEnc encoder of the accumulator
     * @param outEnc encoder of a single windowed result
     * @return an aggregator emitting one windowed result per buffered window
     */
    static <ValT, AccT, ResT, InT> Aggregator<WindowedValue<InT>, ?, Collection<WindowedValue<ResT>>> windowedValue(Combine.CombineFn<ValT, AccT, ResT> fn, ScalaInterop.Fun1<WindowedValue<InT>, ValT> valueFn, WindowingStrategy<?, ?> windowing, Encoder<BoundedWindow> windowEnc, Encoder<AccT> accEnc, Encoder<WindowedValue<ResT>> outEnc) {
        if (!windowing.needsMerge()) {
            return new NonMergingWindowedAggregator<ValT, AccT, ResT, InT>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
        }
        if (windowing.getWindowFn().getClass().equals(Sessions.class)) {
            // Sessions assigns IntervalWindows, so the BoundedWindow encoder can be viewed as an
            // IntervalWindow encoder (generic types are erased at runtime).
            @SuppressWarnings("unchecked")
            Encoder<IntervalWindow> sessionWindowEnc = (Encoder<IntervalWindow>) (Encoder<?>) windowEnc;
            return new SessionsAggregator<ValT, AccT, ResT, InT>(fn, valueFn, windowing, sessionWindowEnc, accEnc, outEnc);
        }
        return new MergingWindowedAggregator<ValT, AccT, ResT, InT>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
    }

    private static abstract class CombineFnAggregator<ValT, AccT, ResT, InT, BuffT, OutT>
    extends Aggregator<InT, BuffT, OutT> {
        private final Combine.CombineFn<ValT, AccT, ResT> fn;
        private final ScalaInterop.Fun1<InT, ValT> valueFn;
        private final Encoder<BuffT> bufferEnc;
        private final Encoder<OutT> outputEnc;

        public CombineFnAggregator(Combine.CombineFn<ValT, AccT, ResT> fn, ScalaInterop.Fun1<InT, ValT> valueFn, Encoder<BuffT> bufferEnc, Encoder<OutT> outputEnc) {
            this.fn = fn;
            this.valueFn = valueFn;
            this.bufferEnc = bufferEnc;
            this.outputEnc = outputEnc;
        }

        protected final ValT value(InT in) {
            return (ValT)this.valueFn.apply(in);
        }

        protected final AccT emptyAcc() {
            return (AccT)this.fn.createAccumulator();
        }

        protected final AccT mergeAccs(AccT a1, AccT a2) {
            return (AccT)this.fn.mergeAccumulators((Iterable)ImmutableList.of(a1, a2));
        }

        protected final AccT addToAcc(AccT acc, ValT val) {
            return (AccT)this.fn.addInput(acc, val);
        }

        protected final ResT extract(AccT acc) {
            return (ResT)this.fn.extractOutput(acc);
        }

        public Encoder<BuffT> bufferEncoder() {
            return this.bufferEnc;
        }

        public Encoder<OutT> outputEncoder() {
            return this.outputEnc;
        }
    }

    /**
     * Base class for window-aware aggregators.
     *
     * <p>The buffer maps each window to a {@link MutablePair} of (timestamp, accumulator).
     * Timestamps are combined using the windowing strategy's {@link TimestampCombiner}.
     * {@link #finish} emits one {@link WindowedValue} per buffered window with {@link
     * PaneInfo#NO_FIRING}.
     *
     * @param <W> window type used as buffer key
     * @param <MapT> concrete map type of the buffer
     */
    private static abstract class WindowedAggregator<ValT, AccT, ResT, InT, W extends @NonNull BoundedWindow, MapT extends Map<W, @NonNull MutablePair<Instant, AccT>>>
    extends CombineFnAggregator<ValT, AccT, ResT, WindowedValue<InT>, MapT, Collection<WindowedValue<ResT>>> {
        private final TimestampCombiner tsCombiner;

        public WindowedAggregator(Combine.CombineFn<ValT, AccT, ResT> combineFn, ScalaInterop.Fun1<WindowedValue<InT>, ValT> valueFn, WindowingStrategy<?, ?> windowing, Encoder<W> windowEnc, Encoder<AccT> accEnc, Encoder<WindowedValue<ResT>> outEnc, Class<MapT> clazz) {
            super(combineFn, valueFn, EncoderHelpers.mapEncoder(windowEnc, EncoderHelpers.mutablePairEncoder(EncoderHelpers.encoderOf(Instant.class), accEnc), clazz), EncoderHelpers.collectionEncoder(outEnc));
            this.tsCombiner = windowing.getTimestampCombiner();
        }

        /** Combines two timestamps for the given target window using the strategy's combiner. */
        protected final Instant resolveTimestamp(BoundedWindow w, Instant t1, Instant t2) {
            return tsCombiner.merge(w, t1, t2);
        }

        /** Creates a new buffer entry holding {@code timestamp} and an accumulator of {@code value}. */
        protected final MutablePair<Instant, AccT> initAcc(ValT value, Instant timestamp) {
            return new MutablePair<>(timestamp, addToAcc(emptyAcc(), value));
        }

        /**
         * Merges {@code a2} into {@code a1} for {@code window}. If either side is null, the other
         * is returned unchanged. Otherwise {@code a1} is updated in place and returned.
         */
        @SuppressWarnings("unchecked")
        protected final <T extends MutablePair<Instant, AccT>> @PolyNull T mergeAccs(W window, @PolyNull T a1, @PolyNull T a2) {
            if (a1 == null || a2 == null) {
                return a1 == null ? a2 : a1;
            }
            // MutablePair.update mutates the pair and returns it.
            return (T) a1.update(resolveTimestamp(window, a1._1, a2._1), mergeAccs(a1._2, a2._2));
        }

        /** Returns a null-tolerant merge function that resolves timestamps against {@code target}. */
        protected BinaryOperator<@Nullable MutablePair<Instant, AccT>> combiner(W target) {
            return (a1, a2) -> mergeAccs(target, a1, a2);
        }

        /**
         * Adds {@code val} with timestamp {@code ts} to {@code acc} for {@code window}. A new entry
         * is created if {@code acc} is null; otherwise {@code acc} is updated in place.
         */
        protected final MutablePair<Instant, AccT> addToAcc(W window, @Nullable MutablePair<Instant, AccT> acc, ValT val, Instant ts) {
            if (acc == null) {
                return initAcc(val, ts);
            }
            return acc.update(resolveTimestamp(window, acc._1, ts), addToAcc(acc._2, val));
        }

        /**
         * Converts every buffered window into a {@link WindowedValue}. The result is a lazy
         * {@link Collections2#transform} view over the buffer's entries.
         */
        @Override
        public final Collection<WindowedValue<ResT>> finish(MapT buffer) {
            return Collections2.transform(buffer.entrySet(), this::windowedValue);
        }

        private WindowedValue<ResT> windowedValue(Map.Entry<W, MutablePair<Instant, AccT>> e) {
            MutablePair<Instant, AccT> pair = e.getValue();
            return WindowedValue.of(extract(pair._2), pair._1, e.getKey(), PaneInfo.NO_FIRING);
        }
    }

    /**
     * Windowed aggregator for window functions that never merge windows. The buffer is a
     * {@link HashMap} keyed by window.
     */
    private static class NonMergingWindowedAggregator<ValT, AccT, ResT, InT>
    extends WindowedAggregator<ValT, AccT, ResT, InT, BoundedWindow, Map<BoundedWindow, MutablePair<Instant, AccT>>> {
        // Class literals cannot carry type arguments, so the raw Map.class is viewed as the
        // parameterized buffer class (unchecked, erased at runtime).
        @SuppressWarnings("unchecked")
        public NonMergingWindowedAggregator(Combine.CombineFn<ValT, AccT, ResT> combineFn, ScalaInterop.Fun1<WindowedValue<InT>, ValT> valueFn, WindowingStrategy<?, ?> windowing, Encoder<BoundedWindow> windowEnc, Encoder<AccT> accEnc, Encoder<WindowedValue<ResT>> outEnc) {
            super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class<Map<BoundedWindow, MutablePair<Instant, AccT>>>) (Class<?>) Map.class);
        }

        @Override
        public Map<BoundedWindow, MutablePair<Instant, AccT>> zero() {
            return new HashMap<BoundedWindow, MutablePair<Instant, AccT>>();
        }

        @Override
        public final Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(Map<BoundedWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
            // getWindows() is typed Collection<? extends BoundedWindow>; it is only read here.
            @SuppressWarnings("unchecked")
            Collection<BoundedWindow> windows = (Collection<BoundedWindow>) input.getWindows();
            return reduce(buff, windows, value(input), input.getTimestamp());
        }

        /** Adds {@code value} to the accumulator of each window, creating entries as needed. */
        protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(Map<BoundedWindow, MutablePair<Instant, AccT>> buff, Collection<BoundedWindow> windows, ValT value, Instant timestamp) {
            for (BoundedWindow window : windows) {
                buff.compute(window, (w, acc) -> addToAcc(w, acc, value, timestamp));
            }
            return buff;
        }

        /**
         * Merges the smaller buffer into the larger one and returns the larger, mutated buffer.
         * Accumulators of equal windows are combined.
         */
        @Override
        public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(Map<BoundedWindow, MutablePair<Instant, AccT>> b1, Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
            if (b1.isEmpty()) {
                return b2;
            }
            if (b2.isEmpty()) {
                return b1;
            }
            if (b2.size() > b1.size()) {
                return merge(b2, b1);
            }
            b2.forEach((w, acc) -> b1.merge(w, acc, combiner(w)));
            return b1;
        }
    }

    /**
     * Windowed aggregator for generic merging window functions. Merging is delegated to {@link
     * WindowFn#mergeWindows}. Windows that are not merged fall back to the non-merging logic.
     */
    private static class MergingWindowedAggregator<ValT, AccT, ResT, InT>
    extends NonMergingWindowedAggregator<ValT, AccT, ResT, InT> {
        private final WindowFn<ValT, BoundedWindow> windowFn;

        @SuppressWarnings("unchecked")
        public MergingWindowedAggregator(Combine.CombineFn<ValT, AccT, ResT> combineFn, ScalaInterop.Fun1<WindowedValue<InT>, ValT> valueFn, WindowingStrategy<?, ?> windowing, Encoder<BoundedWindow> windowEnc, Encoder<AccT> accEnc, Encoder<WindowedValue<ResT>> outEnc) {
            super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc);
            this.windowFn = (WindowFn<ValT, BoundedWindow>) (WindowFn<?, ?>) windowing.getWindowFn();
        }

        @Override
        protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(Map<BoundedWindow, MutablePair<Instant, AccT>> buff, Collection<BoundedWindow> windows, ValT value, Instant timestamp) {
            if (buff.isEmpty()) {
                return super.reduce(buff, windows, value, timestamp);
            }
            // For each merge target, fold its source windows. A buffered window contributes (and
            // loses) its accumulator; a new window contributes the input value. Timestamps are
            // resolved against the merged target window.
            Function<BoundedWindow, ReduceFn<AccT>> accFn = target -> (acc, w) -> {
                MutablePair<Instant, AccT> accW = buff.remove(w);
                return accW != null ? mergeAccs(target, acc, accW) : addToAcc(target, acc, value, timestamp);
            };
            Set<BoundedWindow> unmerged = mergeWindows(buff, ImmutableSet.copyOf(windows), accFn);
            if (!unmerged.isEmpty()) {
                // Input windows that did not take part in any merge are added individually.
                return super.reduce(buff, unmerged, value, timestamp);
            }
            return buff;
        }

        @Override
        public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(Map<BoundedWindow, MutablePair<Instant, AccT>> b1, Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
            // Each merged source window is removed from whichever buffer(s) contain it, and its
            // accumulator is folded into the target window's accumulator.
            Function<BoundedWindow, ReduceFn<AccT>> reduceFn = target -> (acc, w) -> mergeAccs(target, mergeAccs(target, acc, b1.remove(w)), b2.remove(w));
            Set<BoundedWindow> unmerged = mergeWindows(b1, b2.keySet(), reduceFn);
            if (!unmerged.isEmpty()) {
                // Keep only b2's windows that were not merged and combine them by equal window.
                b2.keySet().retainAll(unmerged);
                return super.merge(b1, b2);
            }
            return b1;
        }

        /**
         * Runs the window function's merge over {@code buff}'s windows plus {@code newWindows}.
         * Each merge result is stored in {@code buff} under its target window.
         *
         * @return the subset of {@code newWindows} that did not take part in any merge
         * @throws RuntimeException wrapping any exception raised while merging
         */
        private Set<BoundedWindow> mergeWindows(final Map<BoundedWindow, MutablePair<Instant, AccT>> buff, final Set<BoundedWindow> newWindows, final Function<BoundedWindow, ReduceFn<AccT>> reduceFn) {
            final Set<BoundedWindow> newUnmerged = new HashSet<>(newWindows);
            try {
                windowFn.mergeWindows(windowFn.new MergeContext() {
                    @Override
                    public Collection<BoundedWindow> windows() {
                        return Sets.union(buff.keySet(), newWindows);
                    }

                    @Override
                    public void merge(Collection<BoundedWindow> merges, BoundedWindow target) {
                        @Nullable MutablePair<Instant, AccT> merged = merges.stream().reduce(null, reduceFn.apply(target), combiner(target));
                        if (merged != null) {
                            buff.put(target, merged);
                        }
                        newUnmerged.removeAll(merges);
                    }
                });
            }
            catch (Exception e) {
                throw new RuntimeException("Unable to merge accumulators windows", e);
            }
            return newUnmerged;
        }

        /** Folds one source window into a (possibly null) accumulator for a merge target. */
        @FunctionalInterface
        private static interface ReduceFn<AccT>
        extends BiFunction<MutablePair<Instant, AccT>, BoundedWindow, MutablePair<Instant, AccT>> {
        }
    }

    /**
     * Windowed aggregator specialized for {@link Sessions}.
     *
     * <p>The buffer is a {@link TreeMap} ordered by {@link IntervalWindow}, so overlapping
     * sessions can be found and merged with a sorted scan instead of a generic {@link
     * WindowFn#mergeWindows} call.
     */
    private static class SessionsAggregator<ValT, AccT, ResT, InT>
    extends WindowedAggregator<ValT, AccT, ResT, InT, IntervalWindow, TreeMap<IntervalWindow, MutablePair<Instant, AccT>>> {
        // Class literals cannot carry type arguments, so the raw TreeMap.class is viewed as the
        // parameterized buffer class (unchecked, erased at runtime).
        @SuppressWarnings("unchecked")
        SessionsAggregator(Combine.CombineFn<ValT, AccT, ResT> combineFn, ScalaInterop.Fun1<WindowedValue<InT>, ValT> valueFn, WindowingStrategy<?, ?> windowing, Encoder<IntervalWindow> windowEnc, Encoder<AccT> accEnc, Encoder<WindowedValue<ResT>> outEnc) {
            super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class<TreeMap<IntervalWindow, MutablePair<Instant, AccT>>>) (Class<?>) TreeMap.class);
            Preconditions.checkArgument(windowing.getWindowFn().getClass().equals(Sessions.class));
        }

        @Override
        public final TreeMap<IntervalWindow, MutablePair<Instant, AccT>> zero() {
            return new TreeMap<IntervalWindow, MutablePair<Instant, AccT>>();
        }

        /**
         * Adds the input to every one of its windows. Each window absorbs all buffered sessions it
         * overlaps (transitively, as it grows) before the value is added.
         */
        @Override
        public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> reduce(TreeMap<IntervalWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
            // The Sessions window function assigns IntervalWindows.
            @SuppressWarnings("unchecked")
            Collection<IntervalWindow> windows = (Collection<IntervalWindow>) input.getWindows();
            ValT value = value(input);
            for (IntervalWindow inputWindow : windows) {
                IntervalWindow window = inputWindow;
                @Nullable MutablePair<Instant, AccT> acc = null;
                @Nullable IntervalWindow first = null;
                @Nullable IntervalWindow last = null;
                // Merge with the closest session starting at or before this window, if it overlaps.
                @Nullable Map.Entry<IntervalWindow, MutablePair<Instant, AccT>> lower = buff.floorEntry(window);
                if (lower != null && window.intersects(lower.getKey())) {
                    acc = lower.getValue();
                    window = window.span(lower.getKey());
                    first = last = lower.getKey();
                }
                // Absorb following sessions for as long as they overlap the growing window.
                for (Map.Entry<IntervalWindow, MutablePair<Instant, AccT>> entry : buff.tailMap(window, false).entrySet()) {
                    IntervalWindow entryWindow = entry.getKey();
                    if (!window.intersects(entryWindow)) {
                        break;
                    }
                    window = window.span(entryWindow);
                    acc = acc == null ? entry.getValue() : mergeAccs(window, acc, entry.getValue());
                    if (first == null) {
                        first = entryWindow;
                    }
                    last = entryWindow;
                }
                // Drop the absorbed sessions (a contiguous key range) and store the merged one.
                if (first != null && last != null) {
                    buff.navigableKeySet().subSet(first, true, last, true).clear();
                }
                buff.put(window, addToAcc(window, acc, value, input.getTimestamp()));
            }
            return buff;
        }

        /**
         * Merges two sorted session buffers with a single ordered pass. Overlapping sessions from
         * either side are combined into one spanning session.
         */
        @Override
        public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> merge(TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b1, TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b2) {
            if (b1.isEmpty()) {
                return b2;
            }
            if (b2.isEmpty()) {
                return b1;
            }
            TreeMap<IntervalWindow, MutablePair<Instant, AccT>> res = zero();
            PeekingIterator<Map.Entry<IntervalWindow, MutablePair<Instant, AccT>>> it1 = Iterators.peekingIterator(b1.entrySet().iterator());
            PeekingIterator<Map.Entry<IntervalWindow, MutablePair<Instant, AccT>>> it2 = Iterators.peekingIterator(b2.entrySet().iterator());
            @Nullable MutablePair<Instant, AccT> acc = null;
            @Nullable IntervalWindow window = null;
            while (it1.hasNext() || it2.hasNext()) {
                // Advance whichever iterator holds the smallest window (ties go to b1).
                Map.Entry<IntervalWindow, MutablePair<Instant, AccT>> nextMin;
                if (it1.hasNext() && it2.hasNext()) {
                    nextMin = it1.peek().getKey().compareTo(it2.peek().getKey()) <= 0 ? it1.next() : it2.next();
                } else {
                    nextMin = it1.hasNext() ? it1.next() : it2.next();
                }
                IntervalWindow nextWindow = nextMin.getKey();
                if (window != null && window.intersects(nextWindow)) {
                    window = window.span(nextWindow);
                    acc = mergeAccs(window, acc, nextMin.getValue());
                } else {
                    // The current session cannot grow any further; emit it and start a new one.
                    if (window != null && acc != null) {
                        res.put(window, acc);
                    }
                    acc = nextMin.getValue();
                    window = nextWindow;
                }
            }
            if (window != null && acc != null) {
                res.put(window, acc);
            }
            return res;
        }
    }

    private static class ValueAggregator<ValT, AccT, ResT, InT>
    extends CombineFnAggregator<ValT, AccT, ResT, InT, AccT, ResT> {
        public ValueAggregator(Combine.CombineFn<ValT, AccT, ResT> fn, ScalaInterop.Fun1<InT, ValT> valueFn, Encoder<AccT> accEnc, Encoder<ResT> outEnc) {
            super(fn, valueFn, accEnc, outEnc);
        }

        public AccT zero() {
            return this.emptyAcc();
        }

        public AccT reduce(AccT buff, InT in) {
            return this.addToAcc(buff, this.value(in));
        }

        public AccT merge(AccT b1, AccT b2) {
            return this.mergeAccs(b1, b2);
        }

        public ResT finish(AccT buff) {
            return this.extract(buff);
        }
    }
}

