/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.DecisionTree;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
import org.apache.ignite.ml.tree.leaf.MostCommonDecisionTreeLeafBuilder;

/**
 * Decision tree trainer for classification.
 *
 * <p>Split quality is evaluated with {@link GiniImpurityMeasure} (via
 * {@link GiniImpurityMeasureCalculator}). Leaf values are produced by
 * {@link MostCommonDecisionTreeLeafBuilder}.
 */
public class DecisionTreeClassificationTrainer
extends DecisionTree<GiniImpurityMeasure> {
    /**
     * Creates a trainer that uses no step function compressor ({@code null} is forwarded).
     *
     * @param maxDeep Maximal tree depth, forwarded to {@link DecisionTree}.
     * @param minImpurityDecrease Minimal impurity decrease, forwarded to {@link DecisionTree}.
     */
    public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease) {
        this(maxDeep, minImpurityDecrease, null);
    }

    /**
     * Creates a trainer with the given step function compressor.
     *
     * @param maxDeep Maximal tree depth, forwarded to {@link DecisionTree}.
     * @param minImpurityDecrease Minimal impurity decrease, forwarded to {@link DecisionTree}.
     * @param compressor Step function compressor forwarded to {@link DecisionTree}; may be
     *     {@code null} (the two-argument constructor passes {@code null}).
     */
    public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<GiniImpurityMeasure> compressor) {
        super(maxDeep, minImpurityDecrease, compressor, new MostCommonDecisionTreeLeafBuilder());
    }

    /**
     * Builds a Gini impurity calculator for the labels present in the dataset.
     *
     * <p>Distinct labels are collected per partition and then merged. Partitions whose
     * {@code getLabels()} returns {@code null} are skipped. Each distinct label is mapped to a
     * dense index {@code 0..k-1} in ascending label order. This keeps the mapping independent
     * of the order in which partition results are merged. If no partition provides labels,
     * an empty encoder is passed to the calculator.
     *
     * @param dataset Dataset to collect labels from.
     * @return Gini impurity calculator built from the label encoder.
     */
    @Override
    ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset) {
        Set<Double> labels = dataset.compute(part -> {
            if (part.getLabels() == null) {
                return null;
            }
            Set<Double> partLabels = new HashSet<>();
            for (double lb : part.getLabels()) {
                partLabels.add(lb);
            }
            return partLabels;
        }, (a, b) -> {
            // Either side is null when a partition (or a merged group of partitions) had no labels.
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            // Safe to mutate: each non-null set is created fresh by the map step above.
            a.addAll(b);
            return a;
        });

        HashMap<Double, Integer> encoder = new HashMap<>();
        if (labels != null) {
            // Ascending order gives a reproducible label -> index mapping.
            int idx = 0;
            for (Double lb : new TreeSet<>(labels)) {
                encoder.put(lb, idx++);
            }
        }
        return new GiniImpurityMeasureCalculator(encoder);
    }
}

