package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.FlinkMultiJoin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/table/planner/plan/rules/logical/FlinkProjectMultiJoinTransposeRule.class */
/**
 * Planner rule that trims the inputs of a {@link FlinkMultiJoin} down to the fields that are
 * actually used, either by the {@link Project} sitting on top of it or by any of the multi-join's
 * join conditions.
 *
 * <p>Every input that has unused fields gets a new project keeping only the needed ones. The
 * multi-join row type is then re-derived from the trimmed inputs, and both the join conditions and
 * the parent project are rewritten against that new row type. Fields are matched between the old
 * and the new multi-join row type by name and type. The name includes the numeric suffix
 * ("", "0", "1", ...) that this rule assumes is used to make duplicate field names unique.
 *
 * <p>If anything goes wrong while rewriting, the failure is logged and the rule is not applied.
 */
public class FlinkProjectMultiJoinTransposeRule extends RelOptRule {
    private static final Logger LOG =
            LoggerFactory.getLogger(FlinkProjectMultiJoinTransposeRule.class);

    /** Default instance of the rule, using {@link RelFactories#LOGICAL_BUILDER}. */
    public static final FlinkProjectMultiJoinTransposeRule INSTANCE =
            new FlinkProjectMultiJoinTransposeRule(RelFactories.LOGICAL_BUILDER);

    /**
     * Creates the rule, matching a {@link Project} whose input is a {@link FlinkMultiJoin}.
     *
     * @param relBuilderFactory factory for the builder used to create the rewritten nodes
     */
    public FlinkProjectMultiJoinTransposeRule(RelBuilderFactory relBuilderFactory) {
        super(
                operand(Project.class, operand(FlinkMultiJoin.class, any())),
                relBuilderFactory,
                null);
    }

    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        final Project project = (Project) relOptRuleCall.rel(0);
        final FlinkMultiJoin flinkMultiJoin = (FlinkMultiJoin) relOptRuleCall.rel(1);
        try {
            final List<RelNode> newInputs = pruneInputs(relOptRuleCall, project, flinkMultiJoin);
            final RelDataType newRecordType = getNewRecordType(newInputs, flinkMultiJoin);
            final FlinkMultiJoin newMultiJoin =
                    (FlinkMultiJoin)
                            flinkMultiJoin.copy(
                                    flinkMultiJoin.getTraitSet(),
                                    newInputs,
                                    newRecordType,
                                    getNewConditionsFromOldJoin(flinkMultiJoin, newRecordType));
            relOptRuleCall.transformTo(
                    getNewProject(
                            relOptRuleCall.builder(),
                            project,
                            flinkMultiJoin,
                            newRecordType,
                            newMultiJoin));
        } catch (Exception e) {
            // Best effort: the rule is an optimization, so a failure only means it is skipped.
            // Keep the cause so the failure can actually be diagnosed.
            LOG.warn("Failed to apply the rule", e);
        }
    }

    /**
     * Builds the new inputs of the multi-join, each restricted to the fields that are still needed.
     * An input is returned unchanged when all of its fields are needed.
     */
    private List<RelNode> pruneInputs(
            RelOptRuleCall relOptRuleCall, Project project, FlinkMultiJoin flinkMultiJoin) {
        final List<RelNode> newInputs = new ArrayList<>();
        // Per-name occurrence counter across all inputs: the first occurrence maps to -1 (no
        // suffix), later ones to 0, 1, ... so that "name" + suffix matches the field's name in the
        // multi-join row type (assumption of this rule; see the class javadoc).
        final Map<String, Integer> nameOccurrences = new HashMap<>();
        for (RelNode input : flinkMultiJoin.getInputs()) {
            final RelDataType inputRowType = input.getRowType();
            final List<RexNode> keptRefs = new ArrayList<>();
            final List<String> keptNames = new ArrayList<>();
            for (int i = 0; i < inputRowType.getFieldCount(); i++) {
                final RelDataTypeField field = inputRowType.getFieldList().get(i);
                final int occurrence =
                        nameOccurrences.compute(
                                field.getName(), (name, count) -> count == null ? -1 : count + 1);
                final String suffix = occurrence == -1 ? "" : String.valueOf(occurrence);
                if (isFieldNeeded(project, flinkMultiJoin, field, suffix)) {
                    keptRefs.add(RexInputRef.of(i, inputRowType));
                    keptNames.add(field.getName());
                }
            }
            if (inputRowType.getFieldCount() > keptRefs.size()) {
                newInputs.add(
                        relOptRuleCall
                                .builder()
                                .push(input)
                                .projectNamed(keptRefs, keptNames, true)
                                .build());
            } else {
                newInputs.add(input);
            }
        }
        return newInputs;
    }

    /**
     * Returns whether the given input field is referenced by the parent project or by ANY join
     * condition. A condition attached to one input may reference fields of other inputs, so all
     * conditions are checked; keeping an extra field is always safe.
     */
    private boolean isFieldNeeded(
            Project project, FlinkMultiJoin flinkMultiJoin, RelDataTypeField field, String suffix) {
        final RelDataType joinRowType = flinkMultiJoin.getRowType();
        if (checkFieldInProject(project, joinRowType, field, suffix)) {
            return true;
        }
        for (RexNode condition : flinkMultiJoin.getJoinConditions()) {
            if (referencesField(condition, joinRowType, field, suffix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rewrites the parent project on top of the new multi-join.
     *
     * <p>A {@link List} (not a set) is used so that duplicate expressions such as
     * {@code SELECT a AS x, a AS y} stay aligned with their names.
     */
    private RelNode getNewProject(
            RelBuilder relBuilder,
            Project project,
            FlinkMultiJoin flinkMultiJoin,
            RelDataType relDataType,
            FlinkMultiJoin flinkMultiJoin2) {
        final List<RexNode> newProjects = new ArrayList<>();
        final List<String> newNames = new ArrayList<>();
        for (Pair<RexNode, String> pair : project.getNamedProjects()) {
            newProjects.add(
                    transformRexNode(pair.getKey(), flinkMultiJoin.getRowType(), relDataType));
            newNames.add(pair.getValue());
        }
        return relBuilder.push(flinkMultiJoin2).projectNamed(newProjects, newNames, true).build();
    }

    /**
     * Rewrites every join condition of the old multi-join against the new row type.
     *
     * <p>Null conditions are kept at their position: {@code onMatch} looks conditions up by input
     * index, so dropping them would shift the remaining conditions onto the wrong inputs.
     */
    private List<RexNode> getNewConditionsFromOldJoin(
            FlinkMultiJoin flinkMultiJoin, RelDataType relDataType) {
        final List<RexNode> newConditions = new ArrayList<>();
        for (RexNode condition : flinkMultiJoin.getJoinConditions()) {
            newConditions.add(
                    condition == null
                            ? null
                            : transformRexNode(condition, flinkMultiJoin.getRowType(), relDataType));
        }
        return newConditions;
    }

    /**
     * Derives the row type of a multi-join over the given inputs by folding
     * {@link SqlValidatorUtil#deriveJoinRowType} from left to right. Requires at least two inputs.
     */
    private static RelDataType getNewRecordType(List<RelNode> list, FlinkMultiJoin flinkMultiJoin) {
        RelDataType rowType =
                SqlValidatorUtil.deriveJoinRowType(
                        list.get(0).getRowType(),
                        list.get(1).getRowType(),
                        flinkMultiJoin.getJoinType(),
                        flinkMultiJoin.getCluster().getTypeFactory(),
                        null,
                        Collections.emptyList());
        for (int i = 2; i < list.size(); i++) {
            rowType =
                    SqlValidatorUtil.deriveJoinRowType(
                            rowType,
                            list.get(i).getRowType(),
                            flinkMultiJoin.getJoinType(),
                            flinkMultiJoin.getCluster().getTypeFactory(),
                            null,
                            Collections.emptyList());
        }
        return rowType;
    }

    /**
     * Rewrites input references in {@code rexNode} from the old row type to the new one, recursing
     * into calls.
     *
     * @throws IllegalStateException if the node contains something other than calls and input
     *     references, or if a referenced field has no match in the new row type
     */
    private RexNode transformRexNode(
            RexNode rexNode, RelDataType relDataType, RelDataType relDataType2) {
        if (rexNode instanceof RexCall) {
            final RexCall rexCall = (RexCall) rexNode;
            final List<RexNode> newOperands =
                    rexCall.getOperands().stream()
                            .map(operand -> transformRexNode(operand, relDataType, relDataType2))
                            .collect(Collectors.toList());
            return rexCall.clone(rexCall.getType(), newOperands);
        }
        if (!(rexNode instanceof RexInputRef)) {
            throw new IllegalStateException(
                    "Unsupported rex node " + rexNode.getClass().getSimpleName());
        }
        final RelDataTypeField oldField =
                relDataType.getFieldList().get(((RexInputRef) rexNode).getIndex());
        final RelDataTypeField newField = getFieldFromDataType(oldField, relDataType2);
        if (newField == null) {
            throw new IllegalStateException("Field not found in new output row type.");
        }
        return RexInputRef.of(newField.getIndex(), relDataType2);
    }

    /** Finds the field of {@code relDataType} with the same name and type, or {@code null}. */
    private RelDataTypeField getFieldFromDataType(
            RelDataTypeField relDataTypeField, RelDataType relDataType) {
        for (RelDataTypeField candidate : relDataType.getFieldList()) {
            if (candidate.getName().equals(relDataTypeField.getName())
                    && candidate.getType().equals(relDataTypeField.getType())) {
                return candidate;
            }
        }
        return null;
    }

    /** Returns whether any expression of the project references the given field. */
    private boolean checkFieldInProject(
            Project project, RelDataType relDataType, RelDataTypeField relDataTypeField, String str) {
        for (RexNode rexNode : project.getProjects()) {
            if (referencesField(rexNode, relDataType, relDataTypeField, str)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code rexNode} is, or (recursively) contains, a reference to the field. */
    private boolean referencesField(
            RexNode rexNode, RelDataType relDataType, RelDataTypeField relDataTypeField, String str) {
        return isRexNodeEqualToField(relDataType, relDataTypeField, rexNode, str)
                || checkFieldInRexCall(rexNode, relDataType, relDataTypeField, str);
    }

    /**
     * Returns whether {@code rexNode} is a call with an operand that (recursively) references the
     * field. Returns {@code false} for non-calls, including {@code null}.
     */
    private boolean checkFieldInRexCall(
            RexNode rexNode, RelDataType relDataType, RelDataTypeField relDataTypeField, String str) {
        if (!(rexNode instanceof RexCall)) {
            return false;
        }
        for (RexNode operand : ((RexCall) rexNode).getOperands()) {
            if (referencesField(operand, relDataType, relDataTypeField, str)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether {@code rexNode} is an input reference into {@code relDataType} with the
     * field's SQL type name and whose target field is named {@code field name + str}.
     */
    private static boolean isRexNodeEqualToField(
            RelDataType relDataType, RelDataTypeField relDataTypeField, RexNode rexNode, String str) {
        if (!(rexNode instanceof RexInputRef)) {
            return false;
        }
        if (rexNode.getType().getSqlTypeName() != relDataTypeField.getType().getSqlTypeName()) {
            return false;
        }
        final String refName = relDataType.getFieldNames().get(((RexInputRef) rexNode).getIndex());
        return refName.equals(relDataTypeField.getName() + str);
    }
}
