/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.h2o;

import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.util.Objects;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.primitives.vector.NamedVector;

/**
 * Ignite ML {@link Model} that delegates inference to an H2O MOJO model through H2O's
 * {@link EasyPredictModelWrapper}.
 *
 * <p>For every call to {@link #predict(NamedVector)} the input vector is copied into an H2O
 * {@link RowData} (one entry per vector key). The H2O prediction result is then reduced to a
 * single {@code Double}.
 */
public class H2OMojoModel implements Model<NamedVector, Double> {
    /** H2O wrapper that performs the actual scoring. */
    private final EasyPredictModelWrapper easyPredict;

    /**
     * Creates a model backed by the given H2O prediction wrapper.
     *
     * @param easyPredict H2O easy-predict wrapper used for scoring; must not be {@code null}.
     * @throws NullPointerException if {@code easyPredict} is {@code null}.
     */
    public H2OMojoModel(EasyPredictModelWrapper easyPredict) {
        // Fail fast here instead of with an unexplained NPE on the first predict() call.
        this.easyPredict = Objects.requireNonNull(easyPredict, "easyPredict");
    }

    /**
     * Scores a single input with the wrapped H2O model.
     *
     * <p>The returned value depends on the kind of prediction H2O produces:
     * <ul>
     *   <li>binomial, multinomial and ordinal predictions: the {@code labelIndex} field;</li>
     *   <li>regression predictions: the {@code value} field;</li>
     *   <li>clustering predictions: the {@code cluster} field.</li>
     * </ul>
     *
     * @param input Named feature vector; every key/value pair is copied into the H2O row.
     * @return Raw prediction value as described above.
     * @throws NullPointerException if {@code input} is {@code null}.
     * @throws RuntimeException wrapping the {@link PredictException} if H2O fails to score the row.
     * @throws UnsupportedOperationException if H2O returns a prediction type not listed above.
     */
    public Double predict(NamedVector input) {
        Objects.requireNonNull(input, "input");
        RowData rowData = toRowData(input);
        try {
            AbstractPrediction prediction = easyPredict.predict(rowData);
            return extractRawValue(prediction);
        }
        catch (PredictException e) {
            // Model.predict does not declare checked exceptions, so wrap while keeping the cause.
            throw new RuntimeException("H2O MOJO model failed to make a prediction: " + e.getMessage(), e);
        }
    }

    /**
     * Reduces an H2O prediction object to a single numeric value.
     *
     * @param prediction Prediction returned by the H2O wrapper.
     * @return Numeric value extracted from the prediction (integer fields are widened to double).
     * @throws UnsupportedOperationException if the prediction type is not one of the supported ones.
     */
    private static double extractRawValue(AbstractPrediction prediction) {
        if (prediction instanceof BinomialModelPrediction) {
            return ((BinomialModelPrediction)prediction).labelIndex;
        }
        if (prediction instanceof MultinomialModelPrediction) {
            return ((MultinomialModelPrediction)prediction).labelIndex;
        }
        if (prediction instanceof RegressionModelPrediction) {
            return ((RegressionModelPrediction)prediction).value;
        }
        if (prediction instanceof OrdinalModelPrediction) {
            return ((OrdinalModelPrediction)prediction).labelIndex;
        }
        if (prediction instanceof ClusteringModelPrediction) {
            return ((ClusteringModelPrediction)prediction).cluster;
        }
        throw new UnsupportedOperationException("Prediction " + prediction + " cannot be converted to a raw value.");
    }

    /**
     * Copies every key/value pair of the vector into a fresh H2O row.
     *
     * @param input Source vector.
     * @return New {@link RowData} containing one entry per key of {@code input}.
     */
    private static RowData toRowData(NamedVector input) {
        RowData row = new RowData();
        for (String key : input.getKeys()) {
            // The double returned by get(key) is autoboxed to Double when stored.
            row.put(key, input.get(key));
        }
        return row;
    }

    /** No-op: nothing is released when this model is closed. */
    public void close() {
    }
}

