/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.tree.TreeFilter;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;
import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;

/**
 * Base trainer that grows a binary decision tree by recursively choosing, over all feature columns,
 * the candidate split whose impurity measure is smallest according to {@code compareTo}. Subclasses
 * supply the {@link ImpurityMeasureCalculator} that turns dataset partitions into impurity step functions.
 *
 * @param <T> Impurity measure type used to evaluate candidate splits.
 */
abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> {
    /** Maximum depth; a node at this depth is always turned into a leaf. */
    private final int maxDeep;

    /** Impurity decrease a split must strictly exceed to be accepted. */
    private final double minImpurityDecrease;

    /** Compressor applied to per-partition step functions, or {@code null} to skip compression. */
    private final StepFunctionCompressor<T> compressor;

    /** Builder used to create leaf nodes. */
    private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;

    /**
     * Constructs a new decision tree trainer.
     *
     * @param maxDeep Maximum depth; nodes at this depth are always leaves.
     * @param minImpurityDecrease Impurity decrease a split must strictly exceed to be accepted.
     * @param compressor Compressor applied to per-partition step functions, or {@code null} for none.
     * @param decisionTreeLeafBuilder Builder used to create leaf nodes.
     */
    DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor,
        DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
        this.maxDeep = maxDeep;
        this.minImpurityDecrease = minImpurityDecrease;
        this.compressor = compressor;
        this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
    }

    /**
     * Builds a dataset from {@code datasetBuilder} and grows the tree starting from the root (depth 0),
     * with a filter that accepts every row.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Extracts the feature vector from a key/value pair.
     * @param lbExtractor Extracts the label from a key/value pair.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Root node of the trained tree.
     * @throws RuntimeException If training fails. Unchecked exceptions are propagated unchanged; checked
     *     exceptions (e.g. from closing the dataset) are wrapped.
     */
    @Override
    public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        // NOTE(review): builder constructors are kept exactly as emitted by the decompiler; if these
        // builder classes are generic, prefer the diamond operator.
        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
            new EmptyContextBuilder(), new DecisionTreeDataBuilder(featureExtractor, lbExtractor))) {
            return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
        }
        catch (RuntimeException e) {
            // Do not re-wrap unchecked exceptions: callers keep the original exception type.
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates the impurity calculator that is used for every split of the tree being trained.
     *
     * @param dataset Dataset the tree is trained on.
     * @return Impurity measure calculator.
     */
    abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);

    /**
     * Recursively grows the subtree for the rows accepted by {@code filter}.
     *
     * <p>A leaf is produced when {@code deep} reaches {@code maxDeep}, when the impurity computation
     * returns {@code null}, or when no candidate split decreases impurity by more than
     * {@code minImpurityDecrease}.</p>
     *
     * @param dataset Dataset.
     * @param filter Filter selecting the rows that reach this node.
     * @param deep Depth of this node (the root has depth 0).
     * @param impurityCalc Impurity measure calculator.
     * @return Root of the grown subtree.
     */
    private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int deep,
        ImpurityMeasureCalculator<T> impurityCalc) {
        if (deep >= maxDeep) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc);
        if (criterionFunctions == null) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        SplitPoint<T> splitPnt = calculateBestSplitPoint(criterionFunctions);
        if (splitPnt == null) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        // "Then" subtree gets rows with feature > threshold, "else" subtree rows with feature <= threshold.
        DecisionTreeNode thenNode = split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1,
            impurityCalc);
        DecisionTreeNode elseNode = split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1,
            impurityCalc);

        return new DecisionTreeConditionalNode(splitPnt.col, splitPnt.threshold, thenNode, elseNode);
    }

    /**
     * Computes, for every feature column, the impurity step function of the rows accepted by
     * {@code filter}. Each partition computes (and, if configured, compresses) its own functions; the
     * per-partition results are combined with {@link #reduce}.
     *
     * @param dataset Dataset.
     * @param filter Filter selecting the rows to consider.
     * @param impurityCalc Impurity measure calculator.
     * @return One step function per column, or {@code null} if no partition produced a result.
     */
    private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
        TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) {
        return dataset.compute(part -> {
            StepFunction<T>[] functions = impurityCalc.calculate(part.filter(filter));
            return compressor != null ? compressor.compress(functions) : functions;
        }, this::reduce);
    }

    /**
     * Finds the split with the smallest impurity (by {@code compareTo}) across all columns, considering
     * only splits whose decrease relative to the first step value ({@code values[0]}) strictly exceeds
     * {@code minImpurityDecrease}. The first and last steps are never used as split positions. On ties
     * the earliest candidate (lowest column, then smallest left size) is kept.
     *
     * @param criterionFunctions Impurity step function per column.
     * @return Best split point, or {@code null} if no candidate qualifies.
     */
    private SplitPoint<T> calculateBestSplitPoint(StepFunction<T>[] criterionFunctions) {
        SplitPoint<T> res = null;

        for (int col = 0; col < criterionFunctions.length; ++col) {
            StepFunction<T> criterionFunctionForCol = criterionFunctions[col];
            double[] arguments = criterionFunctionForCol.getX();
            T[] values = criterionFunctionForCol.getY();

            // Upper bound keeps leftSize + 1 a valid index for calculateThreshold.
            for (int leftSize = 1; leftSize < values.length - 1; ++leftSize) {
                boolean decreasesEnough = values[0].impurity() - values[leftSize].impurity() > minImpurityDecrease;
                if (decreasesEnough && (res == null || values[leftSize].compareTo(res.val) < 0)) {
                    res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
                }
            }
        }

        return res;
    }

    /**
     * Combines two per-column step function arrays element-wise via {@code StepFunction.add}.
     * A {@code null} operand is treated as the identity. Does not modify either input array.
     * Assumes both arrays have the same length (one entry per column) — TODO confirm with the calculator.
     *
     * @param a First array, or {@code null}.
     * @param b Second array, or {@code null}.
     * @return Combined array, or {@code null} if both operands are {@code null}.
     */
    private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) {
        if (a == null) {
            return b;
        }
        if (b == null) {
            return a;
        }

        StepFunction<T>[] res = Arrays.copyOf(a, a.length);
        for (int i = 0; i < res.length; ++i) {
            res[i] = res[i].add(b[i]);
        }

        return res;
    }

    /**
     * Returns the midpoint between the step argument at {@code leftSize} and the next one.
     *
     * @param arguments Step function arguments.
     * @param leftSize Split position; {@code leftSize + 1} must be a valid index.
     * @return Threshold value.
     */
    private double calculateThreshold(double[] arguments, int leftSize) {
        return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0;
    }

    /**
     * Narrows {@code filter} to rows whose split feature is strictly greater than the threshold.
     *
     * @param filter Parent filter.
     * @param splitPnt Split point.
     * @return Filter for the "then" subtree.
     */
    private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        // Capture primitives only, so the lambda does not retain the whole SplitPoint.
        int col = splitPnt.col;
        double threshold = splitPnt.threshold;

        return filter.and(f -> f[col] > threshold);
    }

    /**
     * Narrows {@code filter} to rows whose split feature is less than or equal to the threshold.
     *
     * @param filter Parent filter.
     * @param splitPnt Split point.
     * @return Filter for the "else" subtree.
     */
    private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        // Capture primitives only, so the lambda does not retain the whole SplitPoint.
        int col = splitPnt.col;
        double threshold = splitPnt.threshold;

        return filter.and(f -> f[col] <= threshold);
    }

    /**
     * Candidate split: the impurity it yields, the feature column it splits on, and the threshold.
     *
     * @param <T> Impurity measure type.
     */
    private static final class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable {
        /** Serial version UID. */
        private static final long serialVersionUID = -1758525953544425043L;

        /** Impurity obtained with this split. */
        private final T val;

        /** Index of the feature column to split on. */
        private final int col;

        /** Threshold separating the two subtrees. */
        private final double threshold;

        /**
         * Constructs a new split point.
         *
         * @param val Impurity obtained with this split.
         * @param col Feature column index.
         * @param threshold Threshold value.
         */
        SplitPoint(T val, int col, double threshold) {
            this.val = val;
            this.col = col;
            this.threshold = threshold;
        }
    }
}

