/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
import org.apache.ignite.ml.tree.DecisionTreeLeafNode;
import org.apache.ignite.ml.tree.DecisionTreeModel;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.tree.TreeFilter;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;
import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;

/**
 * Base trainer that builds a {@link DecisionTreeModel} by recursively splitting a dataset on the
 * feature/threshold pair with the best impurity measure.
 *
 * <p>Subclasses supply the impurity calculator via
 * {@link #getImpurityMeasureCalculator(Dataset)}; leaf nodes are produced by the
 * {@link DecisionTreeLeafBuilder} passed to the constructor.
 *
 * @param <T> Type of impurity measure used to rank candidate splits.
 */
public abstract class DecisionTreeTrainer<T extends ImpurityMeasure<T>>
    extends SingleLabelDatasetTrainer<DecisionTreeModel> {
    /** Maximum tree depth: a node at this depth (or deeper) is always turned into a leaf. */
    int maxDeep;

    /** A split is considered only if it lowers the impurity by strictly more than this value. */
    double minImpurityDecrease;

    /** Optional compressor applied to per-partition step functions; {@code null} disables it. */
    StepFunctionCompressor<T> compressor;

    /** Builder used to create leaf nodes. */
    private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;

    /**
     * Flag forwarded to {@link DecisionTreeDataBuilder} when the dataset is built.
     * NOTE(review): presumably toggles an index structure inside the partition data - confirm
     * against {@code DecisionTreeDataBuilder}.
     */
    protected boolean usingIdx = true;

    /**
     * Creates a trainer.
     *
     * @param maxDeep Maximum tree depth.
     * @param minImpurityDecrease Minimal impurity decrease required to accept a split.
     * @param compressor Step function compressor, may be {@code null}.
     * @param decisionTreeLeafBuilder Builder of leaf nodes.
     */
    DecisionTreeTrainer(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor,
        DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
        this.maxDeep = maxDeep;
        this.minImpurityDecrease = minImpurityDecrease;
        this.compressor = compressor;
        this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
    }

    /**
     * Recursively appends a textual representation of the given subtree to {@code builder}.
     *
     * <p>A leaf is rendered as {@code "<then|else> return <value>"}; a conditional node as
     * {@code "[then |else ]if (x<col> > <threshold>)"} followed by its "then" and "else" subtrees.
     * Numbers are formatted with {@code "%.4f"}. A {@code null} node appends nothing.
     *
     * @param node Root of the subtree to print, may be {@code null}.
     * @param depth Depth of {@code node}; in pretty mode it is the number of leading tabs.
     * @param builder Destination of the output.
     * @param pretty If {@code true}, use tabs and line breaks; otherwise separate parts by spaces.
     * @param isThen Whether {@code node} is the "then" child of its parent.
     * @throws IllegalArgumentException If the node is neither a leaf nor a conditional node.
     */
    private static void printTree(DecisionTreeNode node, int depth, StringBuilder builder, boolean pretty,
        boolean isThen) {
        if (node == null) {
            return;
        }

        if (pretty) {
            builder.append(String.join("", Collections.nCopies(depth, "\t")));
        }

        if (node instanceof DecisionTreeLeafNode) {
            DecisionTreeLeafNode leaf = (DecisionTreeLeafNode)node;

            builder.append(String.format("%s return ", isThen ? "then" : "else"))
                .append(String.format("%.4f", leaf.getVal()));
        }
        else if (node instanceof DecisionTreeConditionalNode) {
            DecisionTreeConditionalNode cond = (DecisionTreeConditionalNode)node;

            // The root condition carries no "then"/"else" prefix.
            String prefix = depth == 0 ? "" : (isThen ? "then " : "else ");

            builder.append(String.format("%sif (x", prefix))
                .append(cond.getCol())
                .append(" > ")
                .append(String.format("%.4f", cond.getThreshold()))
                .append(pretty ? ")\n" : ") ");

            printTree(cond.getThenNode(), depth + 1, builder, pretty, true);
            builder.append(pretty ? "\n" : " ");
            printTree(cond.getElseNode(), depth + 1, builder, pretty, false);
        }
        else {
            throw new IllegalArgumentException("Unsupported decision tree node type: " + node.getClass().getName());
        }
    }

    /**
     * Builds a dataset of {@link DecisionTreeData} partitions from the given builder and
     * preprocessor, trains a tree on it and closes the dataset afterwards.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Preprocessor passed to the partition data builder.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained decision tree model.
     * @throws RuntimeException Wrapping any exception thrown while building, training or closing.
     */
    @Override
    public <K, V> DecisionTreeModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(preprocessor, usingIdx),
            learningEnvironment()
        )) {
            return fit(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>This implementation always returns {@code true}.
     */
    @Override
    public boolean isUpdateable(DecisionTreeModel mdl) {
        return true;
    }

    /**
     * Sets the learning environment builder, narrowing the return type to this trainer type.
     *
     * <p>NOTE(review): the cast assumes the superclass method returns an instance of this
     * trainer class; otherwise a {@link ClassCastException} is thrown.
     *
     * @param envBuilder Learning environment builder.
     * @return The result of the superclass call, cast to {@code DecisionTreeTrainer<T>}.
     */
    @SuppressWarnings("unchecked")
    @Override
    public DecisionTreeTrainer<T> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (DecisionTreeTrainer<T>)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The previous model is not used: the update is delegated to
     * {@code fit(datasetBuilder, preprocessor)}.
     */
    @Override
    protected <K, V> DecisionTreeModel updateModel(DecisionTreeModel mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        return fit(datasetBuilder, preprocessor);
    }

    /**
     * Trains a tree on an already built dataset, starting at depth 0 with a filter that accepts
     * every row.
     *
     * @param dataset Dataset of decision tree partition data.
     * @param <K> Unused; kept for signature compatibility.
     * @param <V> Unused; kept for signature compatibility.
     * @return Trained decision tree model.
     */
    public <K, V> DecisionTreeModel fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
        return new DecisionTreeModel(split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)));
    }

    /**
     * Builds the subtree for the rows accepted by {@code filter}.
     *
     * <p>A leaf is created when the maximum depth is reached, when the impurity computation
     * yields {@code null}, or when no acceptable split point exists; otherwise a conditional
     * node is created and both branches are built recursively.
     *
     * @param dataset Dataset.
     * @param filter Filter selecting the rows that belong to the current node.
     * @param deep Depth of the current node.
     * @param impurityCalc Impurity measure calculator.
     * @return Root of the built subtree.
     */
    private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int deep,
        ImpurityMeasureCalculator<T> impurityCalc) {
        if (deep >= maxDeep) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc, deep);

        if (criterionFunctions == null) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        SplitPoint<T> splitPnt = calculateBestSplitPoint(criterionFunctions);

        if (splitPnt == null) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        // The "then" branch is built before the "else" branch.
        DecisionTreeNode thenNode = split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc);
        DecisionTreeNode elseNode = split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc);

        return new DecisionTreeConditionalNode(splitPnt.col, splitPnt.threshold, thenNode, elseNode, null);
    }

    /**
     * Computes, per partition, the impurity step function of every column (compressed when a
     * compressor is configured) and merges the partition results with
     * {@link #reduce(StepFunction[], StepFunction[])}.
     *
     * @param dataset Dataset.
     * @param filter Filter selecting the rows that belong to the current node.
     * @param impurityCalc Impurity measure calculator.
     * @param depth Depth of the current node.
     * @return Result of the dataset computation; callers treat {@code null} as "no data".
     */
    private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
        TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) {
        return dataset.compute(
            part -> {
                StepFunction<T>[] impurity = impurityCalc.calculate(part, filter, depth);

                return compressor != null ? compressor.compress(impurity) : impurity;
            },
            this::reduce
        );
    }

    /**
     * Scans every interior point of every column's step function and picks the one with the
     * smallest impurity measure (by {@code compareTo}) among those whose impurity is lower than
     * the first value of that column by more than {@link #minImpurityDecrease}.
     *
     * @param criterionFunctions Impurity step functions, one per column.
     * @return Best split point or {@code null} if no point qualifies.
     */
    private SplitPoint<T> calculateBestSplitPoint(StepFunction<T>[] criterionFunctions) {
        SplitPoint<T> res = null;

        for (int col = 0; col < criterionFunctions.length; col++) {
            StepFunction<T> criterionFunctionForCol = criterionFunctions[col];

            double[] arguments = criterionFunctionForCol.getX();
            T[] values = criterionFunctionForCol.getY();

            // The first and the last points are never used as split candidates.
            for (int leftSize = 1; leftSize < values.length - 1; leftSize++) {
                boolean decreasesEnough = values[0].impurity() - values[leftSize].impurity() > minImpurityDecrease;

                if (decreasesEnough && (res == null || values[leftSize].compareTo(res.val) < 0)) {
                    res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
                }
            }
        }

        return res;
    }

    /**
     * Merges two per-partition results by adding the step functions column by column.
     *
     * @param a First array of step functions, may be {@code null}.
     * @param b Second array of step functions, may be {@code null}.
     * @return The other argument if one is {@code null}; otherwise a new array holding
     *      {@code a[i].add(b[i])}. Neither input array is modified.
     */
    private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) {
        if (a == null) {
            return b;
        }

        if (b == null) {
            return a;
        }

        StepFunction<T>[] res = Arrays.copyOf(a, a.length);

        for (int i = 0; i < res.length; i++) {
            res[i] = res[i].add(b[i]);
        }

        return res;
    }

    /**
     * Returns the midpoint between the argument at {@code leftSize} and the next one.
     *
     * @param arguments Arguments of a step function.
     * @param leftSize Index of the candidate split point; {@code leftSize + 1} must be a valid index.
     * @return Threshold value.
     */
    private double calculateThreshold(double[] arguments, int leftSize) {
        return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0;
    }

    /**
     * Narrows the filter to rows whose split column value is strictly greater than the threshold.
     *
     * @param filter Filter of the parent node.
     * @param splitPnt Split point.
     * @return Filter for the "then" branch.
     */
    private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        return filter.and(f -> f[splitPnt.col] > splitPnt.threshold);
    }

    /**
     * Narrows the filter to rows whose split column value is less than or equal to the threshold.
     *
     * @param filter Filter of the parent node.
     * @param splitPnt Split point.
     * @return Filter for the "else" branch.
     */
    private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        return filter.and(f -> f[splitPnt.col] <= splitPnt.threshold);
    }

    /**
     * Renders the tree rooted at {@code node} as text.
     *
     * @param node Root node, may be {@code null} (yields an empty string).
     * @param pretty If {@code true}, use tabs and line breaks; otherwise a single line.
     * @return Textual representation of the tree.
     */
    public static String printTree(DecisionTreeNode node, boolean pretty) {
        StringBuilder builder = new StringBuilder();

        printTree(node, 0, builder, pretty, false);

        return builder.toString();
    }

    /**
     * Returns the impurity measure calculator to use for the given dataset.
     *
     * @param var1 Dataset the tree is trained on.
     * @return Impurity measure calculator.
     */
    protected abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> var1);

    /**
     * Candidate split: the impurity measure at the split, the column index and the threshold.
     *
     * <p>NOTE(review): presumably serializable because the branch filters capture instances of
     * this class - confirm against {@code TreeFilter}.
     *
     * @param <T> Type of impurity measure.
     */
    private static class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable {
        private static final long serialVersionUID = -1758525953544425043L;

        /** Impurity measure at this split point. */
        private final T val;

        /** Index of the column to split on. */
        private final int col;

        /** Threshold the column value is compared with. */
        private final double threshold;

        /**
         * @param val Impurity measure at this split point.
         * @param col Index of the column to split on.
         * @param threshold Threshold the column value is compared with.
         */
        SplitPoint(T val, int col, double threshold) {
            this.val = val;
            this.col = col;
            this.threshold = threshold;
        }
    }
}

