/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;

/**
 * Impurity histogram for a single feature that scores candidate splits with the Gini criterion.
 *
 * <p>One {@link CountersHistogram} is kept per label, at the index that {@code lblMapping} assigns
 * to that label. The same {@code bucketIds} set instance is handed to every per-label histogram.
 */
public class GiniHistogram
extends ImpurityHistogram
implements ImpurityComputer<BootstrappedVector, GiniHistogram> {
    private static final long serialVersionUID = 5780670356098827667L;
    /** Bucketing metadata of the feature this histogram is built for. */
    private final BucketMeta bucketMeta;
    /** Sample id forwarded to each per-label {@link CountersHistogram}. */
    private final int sampleId;
    /** Per-label histograms; element {@code i} belongs to the label mapped to index {@code i}. */
    private final ArrayList<ObjectHistogram<BootstrappedVector>> hists;
    /** Maps a label value to its index in {@link #hists}. */
    private final Map<Double, Integer> lblMapping;
    /**
     * Sorted set of bucket ids. The same instance is passed to each per-label histogram,
     * presumably so they can record the buckets they see (confirm in CountersHistogram).
     */
    private final Set<Integer> bucketIds;

    /**
     * Creates an empty histogram with one counters histogram per label.
     *
     * @param sampleId sample id forwarded to the per-label histograms.
     * @param lblMapping label value to dense index mapping; stored by reference, not copied.
     * @param bucketMeta bucketing metadata; its feature id becomes this histogram's feature id.
     */
    public GiniHistogram(int sampleId, Map<Double, Integer> lblMapping, BucketMeta bucketMeta) {
        super(bucketMeta.getFeatureMeta().getFeatureId());
        this.hists = new ArrayList<>(lblMapping.size());
        this.sampleId = sampleId;
        this.bucketMeta = bucketMeta;
        this.lblMapping = lblMapping;
        this.bucketIds = new TreeSet<Integer>();
        for (int i = 0; i < lblMapping.size(); ++i) {
            this.hists.add(new CountersHistogram(this.bucketIds, bucketMeta, this.featureId, sampleId));
        }
    }

    /**
     * Adds the vector to the histogram of its label.
     *
     * <p>A label missing from {@code lblMapping} yields a null index, which fails with a
     * {@link NullPointerException} when unboxed for {@code hists.get}.
     *
     * @param vector vector to account for.
     */
    @Override
    public void addElement(BootstrappedVector vector) {
        Integer lblId = this.lblMapping.get(vector.label());
        this.hists.get(lblId).addElement(vector);
    }

    /**
     * Not supported by this histogram.
     *
     * @throws IllegalStateException always.
     */
    @Override
    public Optional<Double> getValue(Integer bucketId) {
        throw new IllegalStateException("Gini histogram doesn't support 'getValue' method");
    }

    /**
     * Merges this histogram with {@code other}.
     *
     * @param other histogram to add. It is presumably built for the same feature and label mapping.
     *     It must have at least as many per-label histograms as this one, because it is indexed
     *     by this instance's histogram count.
     * @return a new histogram sharing this instance's label mapping. Its bucket set is the union of
     *     both operands' buckets, and its {@code i}-th per-label histogram is
     *     {@code this.hists[i].plus(other.hists[i])}.
     */
    @Override
    public GiniHistogram plus(GiniHistogram other) {
        GiniHistogram res = new GiniHistogram(this.sampleId, this.lblMapping, this.bucketMeta);
        res.bucketIds.addAll(this.bucketIds);
        res.bucketIds.addAll(other.bucketIds);
        for (int i = 0; i < this.hists.size(); ++i) {
            // NOTE(review): this replaces the histogram that res's constructor bound to
            // res.bucketIds. Whether the summed histogram shares res.bucketIds depends on
            // ObjectHistogram.plus; confirm this before feeding res more elements.
            res.hists.set(i, this.hists.get(i).plus(other.hists.get(i)));
        }
        return res;
    }

    /**
     * Finds the bucket boundary that gives the best Gini split for this feature.
     *
     * <p>Candidate buckets are visited in ascending id order, because {@code bucketIds} is a
     * {@link TreeSet}. Each label's distribution function is treated as a cumulative count, and
     * its last entry is used as the label's total. At each bucket:
     * <ul>
     *   <li>The left-side count of a label is its distribution value at that bucket. When the label
     *       has no entry there, the last value seen is carried forward.</li>
     *   <li>The right-side count is the label's total minus the left-side count.</li>
     * </ul>
     *
     * <p>The score is {@code -(sum_c L_c^2 / L + sum_c R_c^2 / R)}, where {@code L} and {@code R}
     * are the side totals. Minimizing it is equivalent to minimizing the sample-weighted Gini
     * impurity, which differs only by the constant total sample count. Ties go to the later bucket.
     *
     * @return empty if there are fewer than two buckets. Otherwise, the result of
     *     {@code checkAndReturnSplitValue} for the best bucket, value and score.
     */
    @Override
    public Optional<NodeSplit> findBestSplit() {
        if (this.bucketIds.size() < 2) {
            return Optional.empty();
        }
        int lblCnt = this.lblMapping.size();
        List<TreeMap<Integer, Double>> countersDistribPerCls = this.hists.stream()
            .map(ObjectHistogram::computeDistributionFunction)
            .collect(Collectors.toList());
        double[] totalSampleCntPerLb = countersDistribPerCls.stream()
            .mapToDouble(x -> x.isEmpty() ? 0.0 : x.lastEntry().getValue())
            .toArray();
        // Left-side count per label at the current bucket, carried across buckets where a label
        // has no entry. This relies on the ascending iteration order of bucketIds.
        double[] leftCntPerLb = new double[lblCnt];

        double bestImpurity = Double.POSITIVE_INFINITY;
        double bestSplitVal = Double.NEGATIVE_INFINITY;
        int bestBucketId = -1;
        for (Integer bucketId : this.bucketIds) {
            double totalToLeftCnt = 0.0;
            double totalToRightCnt = 0.0;
            for (int lbId = 0; lbId < lblCnt; ++lbId) {
                Double cumulative = countersDistribPerCls.get(lbId).get(bucketId);
                if (cumulative != null) {
                    leftCntPerLb[lbId] = cumulative;
                }
                totalToLeftCnt += leftCntPerLb[lbId];
                totalToRightCnt += totalSampleCntPerLb[lbId] - leftCntPerLb[lbId];
            }

            double leftImpurity = 0.0;
            double rightImpurity = 0.0;
            for (int lbId = 0; lbId < lblCnt; ++lbId) {
                double toLeftCnt = leftCntPerLb[lbId];
                // Assuming non-negative counts, the > 0 guards also prevent division by an empty
                // side's zero total.
                if (toLeftCnt > 0.0) {
                    leftImpurity += Math.pow(toLeftCnt, 2.0) / totalToLeftCnt;
                }
                double toRightCnt = totalSampleCntPerLb[lbId] - toLeftCnt;
                if (toRightCnt > 0.0) {
                    rightImpurity += Math.pow(toRightCnt, 2.0) / totalToRightCnt;
                }
            }

            double impurityInBucket = -(leftImpurity + rightImpurity);
            if (impurityInBucket <= bestImpurity) {
                bestImpurity = impurityInBucket;
                bestSplitVal = this.bucketMeta.bucketIdToValue(bucketId);
                bestBucketId = bucketId;
            }
        }
        return this.checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity);
    }

    /**
     * Returns the live bucket-id set backing this histogram. It is not a copy.
     *
     * @return the bucket ids, in ascending order.
     */
    @Override
    public Set<Integer> buckets() {
        return this.bucketIds;
    }

    /**
     * Returns the per-label histogram for {@code lbl}.
     *
     * @param lbl label value; must be present in the label mapping.
     * @return the histogram at that label's mapped index.
     */
    ObjectHistogram<BootstrappedVector> getHistForLabel(Double lbl) {
        return this.hists.get(this.lblMapping.get(lbl));
    }

    /**
     * Checks structural equality with {@code other}. The two histograms are equal when all of the
     * following hold:
     * <ul>
     *   <li>they have the same bucket-id set;</li>
     *   <li>they have the same label mapping;</li>
     *   <li>each per-label histogram reports {@code isEqualTo} its counterpart.</li>
     * </ul>
     *
     * @param other histogram to compare with.
     * @return true if all of the above hold.
     */
    @Override
    public boolean isEqualTo(GiniHistogram other) {
        // Set.equals is symmetric. The former union-size check only proved that other's buckets
        // were a subset of ours, so a.isEqualTo(b) could differ from b.isEqualTo(a).
        if (!this.bucketIds.equals(other.bucketIds)) {
            return false;
        }
        if (!this.lblMapping.equals(other.lblMapping)) {
            return false;
        }
        for (Double lbl : this.lblMapping.keySet()) {
            if (!this.getHistForLabel(lbl).isEqualTo(other.getHistForLabel(lbl))) {
                return false;
            }
        }
        return true;
    }
}

