/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.ImmutableProjectAggregateMergeRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.immutables.value.Value;

/**
 * Planner rule that matches a {@link Project} whose single input is an
 * {@link Aggregate} and simplifies the pair.
 *
 * <p>Two rewrites are performed:
 * <ul>
 *   <li>aggregate calls whose output fields are not referenced by the project
 *       are removed from the aggregate;</li>
 *   <li>project expressions of the form
 *       {@code CASE WHEN $i IS NOT NULL THEN CAST($i AS t) ELSE 0 END}, where
 *       field {@code $i} is a {@code SUM} aggregate call, are replaced by a
 *       reference to a {@code SUM0} call with the same arguments and flags,
 *       which is added to the aggregate if it is not already present.</li>
 * </ul>
 */
@Value.Enclosing
public class ProjectAggregateMergeRule
extends RelRule<Config>
implements TransformationRule {
    /** Creates the rule from its configuration. */
    protected ProjectAggregateMergeRule(Config config) {
        super(config);
    }

    /**
     * Rewrites the matched {@code Project}/{@code Aggregate} pair: replaces
     * zero-defaulted {@code SUM} expressions with {@code SUM0} and prunes
     * aggregate calls that the (rewritten) project no longer uses.
     *
     * @param call rule call whose operand 0 is the {@link Project} and whose
     *     operand 1 is the {@link Aggregate} beneath it
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Project project = (Project) call.rel(0);
        final Aggregate aggregate = (Aggregate) call.rel(1);
        final RelOptCluster cluster = aggregate.getCluster();
        final int groupCount = aggregate.getGroupCount();

        // Quick exit: every aggregate call is already referenced and there is no
        // CASE expression that could be turned into a SUM0 reference.
        final ImmutableBitSet bits = RelOptUtil.InputFinder.bits(project.getProjects(), null);
        if (bits.contains(ImmutableBitSet.range(groupCount, aggregate.getRowType().getFieldCount()))
                && kindCount(project.getProjects(), SqlKind.CASE) == 0) {
            return;
        }

        // Working copy of the aggregate calls: SUM0 calls may be appended while
        // rewriting the project, and unused calls are removed afterwards.
        final List<AggregateCall> aggCallList = new ArrayList<>(aggregate.getAggCallList());
        final RexShuttle shuttle = new RexShuttle() {
            @Override
            public RexNode visitCall(RexCall rexCall) {
                final AggregateCall sum = matchZeroDefaultedSum(rexCall, aggregate);
                if (sum != null) {
                    final int sum0Index = findSum0(cluster.getTypeFactory(), sum, aggCallList);
                    // findSum0 returns a position among the aggregate calls only; the
                    // aggregate's output row starts with the group keys, so the field
                    // ordinal of that call is offset by the group count.
                    return cluster.getRexBuilder().makeInputRef(rexCall.type, groupCount + sum0Index);
                }
                return super.visitCall(rexCall);
            }
        };
        final List<RexNode> projects2 = shuttle.visitList(project.getProjects());
        final ImmutableBitSet bits2 = RelOptUtil.InputFinder.bits(projects2, null);

        // Map every retained field of the (possibly extended) aggregate output to
        // its ordinal in the new aggregate, dropping aggregate calls nobody uses.
        // The source count is fixed here, so removals below don't affect the loop.
        final Mapping mapping =
                Mappings.create(MappingType.FUNCTION, groupCount + aggCallList.size(), -1);
        int target = 0;
        for (int source = 0; source < mapping.getSourceCount(); ++source) {
            if (source < groupCount || bits2.get(source)) {
                // Group keys are always kept; aggregate calls only when referenced.
                mapping.set(source, target++);
            } else {
                // Unused aggregate call: the calls kept so far occupy positions
                // [0, target - groupCount) of aggCallList, so this one is next.
                aggCallList.remove(target - groupCount);
            }
        }

        final RelBuilder builder = call.builder();
        builder.push(aggregate.getInput());
        builder.aggregate(
                builder.groupKey(aggregate.getGroupSet(),
                        (Iterable<? extends ImmutableBitSet>) aggregate.groupSets),
                aggCallList);
        builder.project(RexPermuteInputsShuttle.of(mapping).visitList(projects2));
        call.transformTo(builder.build());
    }

    /**
     * Recognizes {@code CASE WHEN $i IS NOT NULL THEN CAST($i AS t) ELSE 0 END}
     * where field {@code $i} of {@code aggregate}'s output is a {@code SUM}
     * aggregate call.
     *
     * @param call      expression to inspect
     * @param aggregate aggregate whose output fields {@code $i} refers to
     * @return the matched {@code SUM} call, or {@code null} if {@code call} does
     *     not have that shape
     */
    private static AggregateCall matchZeroDefaultedSum(RexCall call, Aggregate aggregate) {
        if (call.getKind() != SqlKind.CASE) {
            return null;
        }
        final List<RexNode> operands = call.operands;
        if (operands.size() != 3
                || operands.get(0).getKind() != SqlKind.IS_NOT_NULL
                || operands.get(1).getKind() != SqlKind.CAST
                || operands.get(2).getKind() != SqlKind.LITERAL) {
            return null;
        }
        final RexNode checked = ((RexCall) operands.get(0)).operands.get(0);
        final RexNode casted = ((RexCall) operands.get(1)).operands.get(0);
        if (checked.getKind() != SqlKind.INPUT_REF || casted.getKind() != SqlKind.INPUT_REF) {
            return null;
        }
        final int fieldIndex = ((RexInputRef) checked).getIndex();
        if (fieldIndex != ((RexInputRef) casted).getIndex()) {
            return null;
        }
        // Fields before groupCount are group keys, not aggregate calls.
        final int aggCallIndex = fieldIndex - aggregate.getGroupCount();
        if (aggCallIndex < 0) {
            return null;
        }
        final AggregateCall aggCall = aggregate.getAggCallList().get(aggCallIndex);
        if (aggCall.getAggregation().getKind() != SqlKind.SUM) {
            return null;
        }
        // Checked last, as in the original condition chain, so the literal is only
        // converted to BigDecimal once everything else has matched.
        final RexLiteral literal = (RexLiteral) operands.get(2);
        if (!Objects.equals(literal.getValueAs(BigDecimal.class), BigDecimal.ZERO)) {
            return null;
        }
        return aggCall;
    }

    /**
     * Finds, or creates and appends, a {@code SUM0} call equivalent to the given
     * {@code SUM} call: same distinct/approximate/ignore-nulls flags, arguments,
     * filter, distinct keys and collation, with the result type made non-nullable.
     *
     * @param typeFactory type factory used to derive the non-nullable type
     * @param sum         the {@code SUM} call to mirror
     * @param aggCallList aggregate calls to search; mutated when a call is added
     * @return position of the {@code SUM0} call within {@code aggCallList} (not
     *     a field ordinal of the aggregate's output row)
     */
    private static int findSum0(RelDataTypeFactory typeFactory, AggregateCall sum,
            List<AggregateCall> aggCallList) {
        final AggregateCall sum0 = AggregateCall.create(SqlStdOperatorTable.SUM0,
                sum.isDistinct(), sum.isApproximate(), sum.ignoreNulls(), sum.getArgList(),
                sum.filterArg, sum.distinctKeys, sum.collation,
                typeFactory.createTypeWithNullability(sum.type, false), null);
        final int existing = aggCallList.indexOf(sum0);
        if (existing >= 0) {
            return existing;
        }
        aggCallList.add(sum0);
        return aggCallList.size() - 1;
    }

    /**
     * Counts the calls of the given kind in {@code nodes}, including calls nested
     * inside other calls.
     */
    private static int kindCount(Iterable<? extends RexNode> nodes, final SqlKind kind) {
        final AtomicInteger kindCount = new AtomicInteger(0);
        new RexVisitorImpl<Void>(true) {
            @Override
            public Void visitCall(RexCall call) {
                if (call.getKind() == kind) {
                    kindCount.incrementAndGet();
                }
                return super.visitCall(call);
            }
        }.visitEach(nodes);
        return kindCount.get();
    }

    /** Rule configuration; the default matches a Project on top of an Aggregate. */
    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableProjectAggregateMergeRule.Config.of()
                .withOperandSupplier(b0 -> b0.operand(Project.class)
                        .oneInput(b1 -> b1.operand(Aggregate.class).anyInputs()));

        @Override
        default public ProjectAggregateMergeRule toRule() {
            return new ProjectAggregateMergeRule(this);
        }
    }
}

