/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.naivebayes.compound;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;

/**
 * Naive Bayes model that combines a Gaussian sub-model and a discrete sub-model.
 *
 * <p>Either sub-model may be absent ({@code null}); only the present ones contribute to a prediction.
 * Each sub-model sees the input vector with its own configured set of feature ids removed, so the
 * two sub-models can be trained on different subsets of the same feature vector.
 *
 * <p>Instances are configured via the fluent {@code with*} methods.
 */
public class CompoundNaiveBayesModel
implements IgniteModel<Vector, Double>,
Exportable<CompoundNaiveBayesModel>,
DeployableObject {
    private static final long serialVersionUID = -5045925321135798960L;
    /** Prior probability of each class; index-aligned with {@link #labels}. */
    private double[] priorProbabilities;
    /** Class label values; {@code labels[i]} is returned when class {@code i} wins. */
    private double[] labels;
    /** Optional Gaussian sub-model ({@code null} if not used). */
    private GaussianNaiveBayesModel gaussianModel;
    /** Feature ids removed from the input before it is passed to {@link #gaussianModel}. */
    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();
    /** Optional discrete sub-model ({@code null} if not used). */
    private DiscreteNaiveBayesModel discreteModel;
    /** Feature ids removed from the input before it is passed to {@link #discreteModel}. */
    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();

    /** {@inheritDoc} */
    @Override
    public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /**
     * Predicts the label for the given feature vector.
     *
     * <p>Starts from {@code Math.log} of each class prior, adds the per-class values returned by
     * {@code probabilityPowers} of every present sub-model (each applied to the vector with that
     * sub-model's skipped features removed), and returns the label whose total is largest. Ties are
     * resolved in favour of the lowest class index.
     *
     * @param vector Input feature vector.
     * @return Label of the highest-scoring class.
     */
    @Override
    public Double predict(Vector vector) {
        double[] probabilityPowers = new double[this.priorProbabilities.length];
        for (int i = 0; i < this.priorProbabilities.length; ++i) {
            probabilityPowers[i] = Math.log(this.priorProbabilities[i]);
        }
        if (this.discreteModel != null) {
            probabilityPowers = CompoundNaiveBayesModel.sum(probabilityPowers,
                this.discreteModel.probabilityPowers(CompoundNaiveBayesModel.skipFeatures(vector, this.discreteFeatureIdsToSkip)));
        }
        if (this.gaussianModel != null) {
            probabilityPowers = CompoundNaiveBayesModel.sum(probabilityPowers,
                this.gaussianModel.probabilityPowers(CompoundNaiveBayesModel.skipFeatures(vector, this.gaussianFeatureIdsToSkip)));
        }
        // Argmax; strict '>' keeps the first index on ties.
        int maxLbIdx = 0;
        for (int i = 1; i < probabilityPowers.length; ++i) {
            if (probabilityPowers[i] > probabilityPowers[maxLbIdx]) {
                maxLbIdx = i;
            }
        }
        return this.labels[maxLbIdx];
    }

    /** @return Gaussian sub-model, or {@code null} if none was set. */
    public GaussianNaiveBayesModel getGaussianModel() {
        return this.gaussianModel;
    }

    /** @return Discrete sub-model, or {@code null} if none was set. */
    public DiscreteNaiveBayesModel getDiscreteModel() {
        return this.discreteModel;
    }

    /**
     * Sets the class prior probabilities (a defensive copy is stored).
     *
     * @param priorProbabilities Prior probability per class, index-aligned with the labels.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withPriorProbabilities(double[] priorProbabilities) {
        this.priorProbabilities = (double[])priorProbabilities.clone();
        return this;
    }

    /**
     * Sets the class labels (a defensive copy is stored).
     *
     * @param labels Label value per class, index-aligned with the prior probabilities.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withLabels(double[] labels) {
        this.labels = (double[])labels.clone();
        return this;
    }

    /**
     * @param gaussianModel Gaussian sub-model, or {@code null} to disable it.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withGaussianModel(GaussianNaiveBayesModel gaussianModel) {
        this.gaussianModel = gaussianModel;
        return this;
    }

    /**
     * @param discreteModel Discrete sub-model, or {@code null} to disable it.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withDiscreteModel(DiscreteNaiveBayesModel discreteModel) {
        this.discreteModel = discreteModel;
        return this;
    }

    /**
     * Sets the feature ids removed from the input before it reaches the Gaussian sub-model.
     * The collection is stored by reference, not copied.
     *
     * @param gaussianFeatureIdsToSkip Zero-based feature ids to skip.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
        this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
        return this;
    }

    /**
     * Sets the feature ids removed from the input before it reaches the discrete sub-model.
     * The collection is stored by reference, not copied.
     *
     * @param discreteFeatureIdsToSkip Zero-based feature ids to skip.
     * @return {@code this} for chaining.
     */
    public CompoundNaiveBayesModel withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
        this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
        return this;
    }

    /**
     * Element-wise sum of two equally sized arrays.
     *
     * @param arr1 First array.
     * @param arr2 Second array (must have the same length as {@code arr1}).
     * @return New array with {@code arr1[i] + arr2[i]}.
     */
    private static double[] sum(double[] arr1, double[] arr2) {
        assert (arr1.length == arr2.length);
        double[] result = new double[arr1.length];
        for (int i = 0; i < arr1.length; ++i) {
            result[i] = arr1[i] + arr2[i];
        }
        return result;
    }

    /**
     * Returns a copy of {@code vector} with the given feature positions removed, preserving the
     * order of the remaining features.
     *
     * <p>Only distinct ids in {@code [0, vector.size())} are removed; duplicates, {@code null}
     * elements and out-of-range ids are ignored. A {@code null} collection skips nothing. The
     * result length is therefore always exactly {@code vector.size()} minus the number of features
     * actually removed.
     *
     * @param vector Source vector.
     * @param featureIdsToSkip Zero-based feature ids to drop.
     * @return New vector containing the kept features.
     */
    private static Vector skipFeatures(Vector vector, Collection<Integer> featureIdsToSkip) {
        int size = vector.size();
        // Mask built once: O(n + m) instead of a Collection.contains() lookup per feature, and
        // the removed count only reflects ids that really match a position.
        boolean[] skipped = new boolean[size];
        int skippedCnt = 0;
        if (featureIdsToSkip != null) {
            for (Integer id : featureIdsToSkip) {
                if (id != null && id >= 0 && id < size && !skipped[id]) {
                    skipped[id] = true;
                    ++skippedCnt;
                }
            }
        }
        double[] newFeaturesValues = new double[size - skippedCnt];
        int index = 0;
        for (int j = 0; j < size; ++j) {
            if (!skipped[j]) {
                newFeaturesValues[index++] = vector.get(j);
            }
        }
        return VectorUtils.of(newFeaturesValues);
    }

    /**
     * Returns the sub-models as dependencies, discrete first, then Gaussian.
     *
     * <p>NOTE(review): an absent sub-model appears here as a {@code null} element; confirm the
     * deployment code that consumes this list tolerates {@code null}s.
     *
     * @return Fixed-size list {@code [discreteModel, gaussianModel]}.
     */
    @Override
    public List<Object> getDependencies() {
        return Arrays.asList(this.discreteModel, this.gaussianModel);
    }
}

