/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.statistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;

/**
 * Template for computing the values stored in the leaves of already-built trees.
 *
 * <p>Every vector of the dataset is routed through every tree; per-leaf statistics of type {@code T}
 * are accumulated inside each partition, merged across partitions and finally converted into a
 * {@code double} value that is written into the corresponding leaf. Subclasses define what the
 * statistic is and how it is created, updated, merged and turned into a leaf value.
 *
 * @param <T> Type of the per-leaf statistics aggregator.
 */
public abstract class LeafValuesComputer<T> implements Serializable {
    /** Serial version uid. */
    private static final long serialVersionUID = -429848953091775832L;

    /**
     * Collects per-leaf statistics over the whole dataset and sets a value on every leaf for which
     * a statistic was produced. Leaves that received no statistic are left untouched.
     *
     * @param roots Roots of the built trees.
     * @param dataset Dataset the statistics are computed on.
     */
    public void setValuesForLeaves(ArrayList<TreeRoot> roots, Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        Map<NodeId, TreeNode> leafs = roots.stream()
            .flatMap(root -> root.getLeafs().stream())
            .collect(Collectors.toMap(TreeNode::getId, Function.identity()));

        Map<NodeId, T> stats = dataset.compute(
            data -> computeLeafsStatisticsInPartition(roots, leafs, data),
            this::mergeLeafStatistics
        );

        // The reducer below explicitly tolerates null partial results, so the overall result is
        // presumably null when there was nothing to reduce; treat that as "no statistics".
        if (stats == null) {
            return;
        }

        leafs.forEach((id, leaf) -> {
            T stat = stats.get(id);
            if (stat != null) {
                leaf.setVal(computeLeafValue(stat));
            }
        });
    }

    /**
     * Routes every vector of one partition through every tree and accumulates the statistics of
     * the node each vector ends up in.
     *
     * @param roots Roots of the built trees.
     * @param leafs All known leaves keyed by their id.
     * @param data Partition to process.
     * @return Statistics keyed by leaf id for the leaves reached by at least one vector.
     * @throws IllegalStateException If a tree routes a vector to a node that is not in {@code leafs}.
     */
    private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<TreeRoot> roots, Map<NodeId, TreeNode> leafs, BootstrappedDatasetPartition data) {
        Map<NodeId, T> res = new HashMap<>();

        for (int treeIdx = 0; treeIdx < roots.size(); treeIdx++) {
            // The tree index doubles as the sample id handed to the aggregator callbacks.
            final int sampleId = treeIdx;
            // Looked up once per tree instead of once per vector.
            final TreeNode rootNode = roots.get(treeIdx).getRootNode();

            data.forEach(vec -> {
                NodeId leafId = rootNode.predictNextNodeKey(vec.features());
                if (!leafs.containsKey(leafId)) {
                    throw new IllegalStateException("Predicted node is not a known leaf [treeIdx=" + sampleId
                        + ", nodeId=" + leafId + ']');
                }

                T aggregator = res.computeIfAbsent(leafId, key -> createLeafStatsAggregator(sampleId));
                addElementToLeafStatistic(aggregator, vec, sampleId);
            });
        }

        return res;
    }

    /**
     * Merges two partial per-leaf statistics maps. Entries of {@code right} are folded into
     * {@code left}, which is modified in place and returned.
     *
     * @param left First partial result, may be {@code null}.
     * @param right Second partial result, may be {@code null}.
     * @return The other argument if one of them is {@code null}, otherwise {@code left} after merging.
     */
    private Map<NodeId, T> mergeLeafStatistics(Map<NodeId, T> left, Map<NodeId, T> right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }

        // Only keys present in right can change left, so there is no need to build the key union.
        for (Map.Entry<NodeId, T> entry : right.entrySet()) {
            NodeId key = entry.getKey();
            if (left.containsKey(key)) {
                left.put(key, mergeLeafStats(left.get(key), entry.getValue()));
            }
            else {
                left.put(key, entry.getValue());
            }
        }

        return left;
    }

    /**
     * Adds one vector to the statistics of the leaf it was routed to.
     *
     * @param leafStatAggr Aggregator of the target leaf.
     * @param vec Vector that reached the leaf.
     * @param sampleId Index of the tree (in the list of roots) that routed the vector.
     */
    protected abstract void addElementToLeafStatistic(T leafStatAggr, BootstrappedVector vec, int sampleId);

    /**
     * Merges two statistics collected for the same leaf.
     *
     * @param leafStatAggr1 First statistic.
     * @param leafStatAggr2 Second statistic.
     * @return Merged statistic.
     */
    protected abstract T mergeLeafStats(T leafStatAggr1, T leafStatAggr2);

    /**
     * Creates an empty statistics aggregator for a leaf.
     *
     * @param sampleId Index of the tree (in the list of roots) the leaf was reached through.
     * @return New aggregator.
     */
    protected abstract T createLeafStatsAggregator(int sampleId);

    /**
     * Converts the accumulated statistics of a leaf into the value to store in that leaf.
     *
     * @param stat Accumulated, non-null statistics of the leaf.
     * @return Leaf value.
     */
    protected abstract double computeLeafValue(T stat);
}

