/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.xgboost;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
import org.apache.ignite.ml.tree.DecisionTreeNode;

public class XGModelComposition
implements IgniteModel<NamedVector, Double> {
    private static final long serialVersionUID = 6765344479174942051L;
    private final Map<String, Integer> dict;
    private ModelsComposition modelsComposition;

    public XGModelComposition(Map<String, Integer> dict, List<DecisionTreeNode> models) {
        this.dict = new HashMap<String, Integer>(dict);
        this.modelsComposition = new ModelsComposition(models, (PredictionsAggregator)new XGModelPredictionsAggregator());
    }

    public Double predict(NamedVector input) {
        return this.modelsComposition.predict(this.reencode(input));
    }

    public Map<String, Integer> getDict() {
        return Collections.unmodifiableMap(this.dict);
    }

    public ModelsComposition getModelsComposition() {
        return this.modelsComposition;
    }

    public void setModelsComposition(ModelsComposition modelsComposition) {
        this.modelsComposition = modelsComposition;
    }

    private Vector reencode(NamedVector vector) {
        SparseVector inputVector = new SparseVector(this.dict.size());
        for (int i = 0; i < this.dict.size(); ++i) {
            inputVector.set(i, Double.NaN);
        }
        for (String key : vector.getKeys()) {
            Integer idx = this.dict.get(key);
            if (idx == null) continue;
            inputVector.set(idx.intValue(), vector.get(key));
        }
        return inputVector;
    }

    private static class XGModelPredictionsAggregator
    implements PredictionsAggregator {
        private static final long serialVersionUID = 1274109586500815229L;

        private XGModelPredictionsAggregator() {
        }

        public Double apply(double[] predictions) {
            double res = 0.0;
            for (double prediction : predictions) {
                res += prediction;
            }
            return 1.0 / (1.0 + Math.exp(-res));
        }
    }
}

