/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;

/**
 * Base class that builds, for every tree node still being learned, a set of per-feature
 * impurity computers (histograms) from a bootstrapped dataset.
 *
 * <p>Statistics are first collected independently on each dataset partition and then merged
 * pairwise into a single map keyed by {@link NodeId}. Subclasses decide which concrete
 * impurity computer is used by implementing
 * {@link #createImpurityComputerForFeature(int, BucketMeta)}.
 *
 * @param <S> concrete impurity computer type accumulated per feature.
 */
public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<BootstrappedVector, S>>
implements Serializable {
    private static final long serialVersionUID = -4984067145908187508L;

    /**
     * Computes impurity statistics for the given nodes over the whole dataset.
     *
     * @param roots        tree roots; the element at index {@code i} is used for sample {@code i}
     *                     of every bootstrapped vector.
     * @param histMeta     bucket metadata keyed by feature id.
     * @param nodesToLearn nodes for which statistics are collected; vectors routed to any other
     *                     node are ignored.
     * @param dataset      dataset to run the computation on.
     * @return per-node statistics as produced by {@code dataset.compute}; whether this can be
     *         {@code null} (e.g. for a dataset without partitions) depends on the {@link Dataset}
     *         implementation and is not visible here.
     */
    public Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatistics(ArrayList<TreeRoot> roots, Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> nodesToLearn, Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        return dataset.compute(
            x -> aggregateImpurityStatisticsOnPartition(x, roots, histMeta, nodesToLearn),
            this::reduceImpurityStatistics);
    }

    /**
     * Collects impurity statistics from a single partition.
     *
     * @param dataset  partition whose vectors are scanned.
     * @param roots    tree roots, indexed by sample id.
     * @param histMeta bucket metadata keyed by feature id.
     * @param part     nodes to collect statistics for.
     * @return a map containing one (possibly empty) {@link NodeImpurityHistograms} per key of
     *         {@code part}.
     */
    private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(BootstrappedDatasetPartition dataset, ArrayList<TreeRoot> roots, Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {
        // Pre-create an empty accumulator for every node that is being learned.
        Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
            .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

        dataset.forEach(vector -> {
            for (int sampleId = 0; sampleId < vector.counters().length; ++sampleId) {
                // A zero counter: skip this sample for this vector
                // (presumably the vector was not drawn into the sample - confirm against BootstrappedVector).
                if (vector.counters()[sampleId] == 0) {
                    continue;
                }

                TreeRoot root = roots.get(sampleId);
                NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

                // The vector is routed to a node that is not in the set being learned.
                if (!part.containsKey(key)) {
                    continue;
                }

                // Present by construction: res has exactly the keys of part.
                NodeImpurityHistograms<S> statistics = res.get(key);

                // Lambdas can only capture effectively final locals, so copy the loop index.
                final int sample = sampleId;
                for (Integer featureId : root.getUsedFeatures()) {
                    // Lazily create the computer the first time a feature is seen for this node.
                    S impurityComputer = statistics.perFeatureStatistics.computeIfAbsent(
                        featureId,
                        id -> createImpurityComputerForFeature(sample, histMeta.get(id)));
                    impurityComputer.addElement(vector);
                }
            }
        });
        return res;
    }

    /**
     * Merges two partial results. A {@code null} argument is treated as "no data", in which case
     * the other argument is returned as is; otherwise neither input map is modified.
     *
     * @param left  first partial result, may be {@code null}.
     * @param right second partial result, may be {@code null}.
     * @return merged statistics; entries present on both sides are combined with
     *         {@link NodeImpurityHistograms#plus(NodeImpurityHistograms)}.
     */
    private Map<NodeId, NodeImpurityHistograms<S>> reduceImpurityStatistics(Map<NodeId, NodeImpurityHistograms<S>> left, Map<NodeId, NodeImpurityHistograms<S>> right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }

        Map<NodeId, NodeImpurityHistograms<S>> res = new HashMap<>(left);
        right.forEach((key, rightVal) -> res.merge(key, rightVal, (leftVal, val) -> leftVal.plus(val)));
        return res;
    }

    /**
     * Creates the impurity computer used to accumulate statistics for one feature.
     *
     * @param var1 sample id (index into the vector's counters and into the list of tree roots).
     * @param var2 bucket metadata looked up for the feature; {@code null} if the metadata map has
     *             no entry for that feature.
     * @return new impurity computer instance.
     */
    protected abstract S createImpurityComputerForFeature(int var1, BucketMeta var2);

    /**
     * Impurity computers of a single tree node, keyed by feature id.
     *
     * @param <S> concrete impurity computer type.
     */
    public static class NodeImpurityHistograms<S extends ImpurityComputer<BootstrappedVector, S>>
    implements Serializable {
        private static final long serialVersionUID = 2700045747590421768L;
        private final NodeId nodeId;
        private final Map<Integer, S> perFeatureStatistics = new HashMap<Integer, S>();

        /**
         * Creates an empty set of histograms for the given node.
         *
         * @param nodeId id of the node these statistics belong to.
         */
        public NodeImpurityHistograms(NodeId nodeId) {
            this.nodeId = nodeId;
        }

        /**
         * Returns a new instance holding the combination of this and {@code other}; neither
         * operand's map is modified. Features present on only one side are carried over by
         * reference, features present on both sides are combined via the computer's
         * {@code plus}.
         *
         * @param other statistics of the same node (checked with an {@code assert}).
         * @return combined statistics.
         */
        public NodeImpurityHistograms<S> plus(NodeImpurityHistograms<S> other) {
            assert (this.nodeId.equals(other.nodeId));
            NodeImpurityHistograms<S> res = new NodeImpurityHistograms<S>(this.nodeId);
            this.addTo(this.perFeatureStatistics, res.perFeatureStatistics);
            this.addTo(other.perFeatureStatistics, res.perFeatureStatistics);
            return res;
        }

        /**
         * Adds every entry of {@code from} into {@code to}, combining values on key collision.
         *
         * @param from source map, left unchanged.
         * @param to   destination map, updated in place.
         */
        private void addTo(Map<Integer, S> from, Map<Integer, S> to) {
            from.forEach((key, hist) -> to.merge(key, hist, (acc, h) -> acc.plus(h)));
        }

        /**
         * @return id of the node these statistics belong to.
         */
        public NodeId getNodeId() {
            return this.nodeId;
        }

        /**
         * Asks every per-feature computer for its best split and picks the one with the lowest
         * impurity.
         *
         * @return the split with minimal {@code getImpurity()}, or empty if no computer proposes
         *         a split (including when no statistics were collected).
         */
        public Optional<NodeSplit> findBestSplit() {
            return this.perFeatureStatistics.values().stream()
                .map(hist -> hist.findBestSplit())
                .filter(Optional::isPresent)
                .map(Optional::get)
                .min(Comparator.comparingDouble(NodeSplit::getImpurity));
        }
    }
}

