/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.catboost;

import ai.catboost.CatBoostError;
import ai.catboost.CatBoostModel;
import java.util.Objects;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Ignite {@link Model} adapter that scores a {@link NamedVector} with a wrapped {@link CatBoostModel}
 * and converts the raw score into a value in [0, 1] via the logistic (sigmoid) function.
 */
public class CatboostClassificationModel
implements Model<NamedVector, Double> {
    private static final Logger logger = LoggerFactory.getLogger(CatboostClassificationModel.class);

    /** Wrapped CatBoost model; released by {@link #close()}. */
    private final CatBoostModel model;

    /**
     * Creates a wrapper around the given CatBoost model.
     *
     * @param model CatBoost model to delegate predictions to; must not be {@code null}.
     * @throws NullPointerException if {@code model} is {@code null}.
     */
    public CatboostClassificationModel(CatBoostModel model) {
        this.model = Objects.requireNonNull(model, "model");
    }

    /**
     * Scores one input vector.
     *
     * <p>Values are looked up in {@code input} by the model's feature names, in the order the
     * model reports them. The raw score at position (0, 0) of the prediction result is then
     * mapped through {@code 1 / (1 + exp(-score))}.
     *
     * @param input named feature values; every feature name of the model is looked up in it.
     * @return sigmoid of the model's raw score.
     * @throws RuntimeException wrapping the {@link CatBoostError} if CatBoost fails to predict.
     */
    public Double predict(NamedVector input) {
        String[] featureNames = this.model.getFeatureNames();

        // Size by the model's feature count, not input.size(): the array is filled per feature
        // name, so a differently-sized input must not overflow or pad the feature array.
        float[] floatInput = new float[featureNames.length];
        for (int i = 0; i < featureNames.length; i++) {
            floatInput[i] = (float) input.get(featureNames[i]);
        }

        try {
            // NOTE(review): in the CatBoost Java API the second argument of
            // predict(float[], String[]) appears to be categorical feature *values*, not feature
            // names. Passing the names here is kept for compatibility; confirm the model has no
            // categorical features.
            double rawScore = this.model.predict(floatInput, featureNames).get(0, 0);
            return 1.0 / (1.0 + Math.exp(-rawScore));
        }
        catch (CatBoostError e) {
            // Keep the original message but preserve the cause for diagnostics.
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Releases the wrapped CatBoost model.
     *
     * <p>A {@link CatBoostError} raised while closing is logged, with its stack trace, and not
     * rethrown.
     */
    public void close() {
        try {
            this.model.close();
        }
        catch (CatBoostError e) {
            logger.error("Failed to close CatBoost model: {}", e.getMessage(), e);
        }
    }
}

