/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.regressions.logistic;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Objects;
import java.util.UUID;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.inference.json.JSONModel;
import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/**
 * Logistic regression model: applies a sigmoid to the linear combination
 * {@code input . weights + intercept} and either returns that raw value or a
 * binary label obtained by comparing it against a threshold.
 *
 * <p>The {@code with*} methods mutate this instance and return it for chaining.
 */
public final class LogisticRegressionModel
    implements IgniteModel<Vector, Double>, Exportable<LogisticRegressionModel>, JSONWritable {
    /** Serialization version identifier. */
    private static final long serialVersionUID = -133984600091550776L;

    /** Models with at least this many weights are printed in the generic (non-formula) form. */
    private static final int FORMULA_TO_STRING_LIMIT = 20;

    /** Coefficients of the linear part; may be {@code null} until set via {@link #withWeights(Vector)}. */
    private Vector weights;

    /** Bias term added to the dot product. */
    private double intercept;

    /** When {@code true}, {@link #predict(Vector)} returns the raw sigmoid value instead of a 0/1 label. */
    private boolean isKeepingRawLabels;

    /** Decision boundary applied to the sigmoid value; defaults to 0.5. */
    private double threshold = 0.5;

    /** Creates an empty model; used by {@link LogisticRegressionJSONExportModel#convert()}. */
    private LogisticRegressionModel() {
    }

    /**
     * Creates a model with the given linear coefficients.
     *
     * @param weights Coefficient vector (stored by reference, not copied).
     * @param intercept Bias term.
     */
    public LogisticRegressionModel(Vector weights, double intercept) {
        this.weights = weights;
        this.intercept = intercept;
    }

    /**
     * Sets whether {@link #predict(Vector)} returns the raw sigmoid value.
     *
     * @param isKeepingRawLabels {@code true} to return raw values, {@code false} to return 0/1 labels.
     * @return This model.
     */
    public LogisticRegressionModel withRawLabels(boolean isKeepingRawLabels) {
        this.isKeepingRawLabels = isKeepingRawLabels;
        return this;
    }

    /**
     * Sets the decision threshold.
     *
     * @param threshold Value the sigmoid output must strictly exceed to yield label 1.0.
     * @return This model.
     */
    public LogisticRegressionModel withThreshold(double threshold) {
        this.threshold = threshold;
        return this;
    }

    /**
     * Sets the coefficient vector.
     *
     * @param weights Coefficient vector (stored by reference, not copied).
     * @return This model.
     */
    public LogisticRegressionModel withWeights(Vector weights) {
        this.weights = weights;
        return this;
    }

    /**
     * Sets the bias term.
     *
     * @param intercept Bias term.
     * @return This model.
     */
    public LogisticRegressionModel withIntercept(double intercept) {
        this.intercept = intercept;
        return this;
    }

    /** @return {@code true} if predictions are raw sigmoid values rather than 0/1 labels. */
    public boolean isKeepingRawLabels() {
        return this.isKeepingRawLabels;
    }

    /** @return The decision threshold. */
    public double threshold() {
        return this.threshold;
    }

    /** @return The coefficient vector held by this model (the internal reference, not a copy). */
    public Vector weights() {
        return this.weights;
    }

    /** @return The bias term. */
    public double intercept() {
        return this.intercept;
    }

    /**
     * Computes {@code sigmoid(input . weights + intercept)}.
     *
     * @param input Feature vector.
     * @return The raw sigmoid value when raw labels are kept; otherwise {@code 1.0} if the
     *     sigmoid value is strictly greater than the threshold and {@code 0.0} if not.
     */
    @Override
    public Double predict(Vector input) {
        double res = sigmoid(input.dot(this.weights) + this.intercept);
        if (this.isKeepingRawLabels) {
            return res;
        }
        return res - this.threshold > 0.0 ? 1.0 : 0.0;
    }

    /**
     * Logistic function {@code 1 / (1 + e^-z)}.
     *
     * @param z Argument.
     * @return Value of the logistic function at {@code z}.
     */
    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Hands this model to the given exporter.
     *
     * @param exporter Exporter that performs the save.
     * @param path Destination passed through to the exporter.
     * @param <P> Type of the destination.
     */
    @Override
    public <P> void saveModel(Exporter<LogisticRegressionModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /**
     * Two models are equal when intercept, threshold, raw-label flag and weights are all equal.
     *
     * @param o Object to compare with.
     * @return {@code true} if {@code o} is a {@code LogisticRegressionModel} with equal state.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        LogisticRegressionModel mdl = (LogisticRegressionModel) o;
        return Double.compare(mdl.intercept, this.intercept) == 0
            && Double.compare(mdl.threshold, this.threshold) == 0
            && mdl.isKeepingRawLabels == this.isKeepingRawLabels
            && Objects.equals(this.weights, mdl.weights);
    }

    /** @return Hash code over the same fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return Objects.hash(this.weights, this.intercept, this.isKeepingRawLabels, this.threshold);
    }

    /**
     * Renders small models as a signed formula such as {@code -2.0000*x0 + 3.0000*x1 - 1.0000};
     * models with {@value #FORMULA_TO_STRING_LIMIT} or more weights, or with no weights set,
     * are rendered in the generic {@code LogisticRegressionModel [weights=..., intercept=...]} form.
     *
     * @return Human-readable description of the model.
     */
    @Override
    public String toString() {
        // The no-arg constructor leaves weights unset; fall back to the generic form instead of failing.
        if (this.weights == null || this.weights.size() >= FORMULA_TO_STRING_LIMIT) {
            return "LogisticRegressionModel [weights=" + this.weights + ", intercept=" + this.intercept + ']';
        }
        StringBuilder builder = new StringBuilder();
        int size = this.weights.size();
        for (int i = 0; i < size; ++i) {
            appendSignedTerm(builder, this.weights.get(i));
            builder.append("*x").append(i);
        }
        appendSignedTerm(builder, this.intercept);
        return builder.toString();
    }

    /**
     * Appends {@code val} formatted with four decimals, using the term's own sign: a bare leading
     * {@code -} for a negative first term, otherwise {@code " + "} / {@code " - "} as a separator.
     *
     * @param builder Target buffer; its emptiness determines whether this is the first term.
     * @param val Coefficient to append.
     */
    private static void appendSignedTerm(StringBuilder builder, double val) {
        boolean negative = val < 0.0;
        if (builder.length() == 0) {
            if (negative) {
                builder.append('-');
            }
        }
        else {
            builder.append(negative ? " - " : " + ");
        }
        builder.append(String.format("%.4f", Math.abs(val)));
    }

    /**
     * Same as {@link #toString()}; the {@code pretty} flag is ignored.
     *
     * @param pretty Ignored.
     * @return Human-readable description of the model.
     */
    @Override
    public String toString(boolean pretty) {
        return this.toString();
    }

    /**
     * Reads a model from a JSON file.
     *
     * @param path Location of the JSON file.
     * @return The loaded model, or {@code null} if an {@link IOException} occurred while reading
     *     (the stack trace is printed in that case).
     */
    public static LogisticRegressionModel fromJSON(Path path) {
        ObjectMapper mapper = new ObjectMapper();
        try {
            LogisticRegressionJSONExportModel exportModel =
                mapper.readValue(new File(path.toAbsolutePath().toString()), LogisticRegressionJSONExportModel.class);
            return exportModel.convert();
        }
        catch (IOException e) {
            // Best-effort contract kept for existing callers: report the failure and signal it with null.
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes this model to a JSON file. An {@link IOException} during writing is not propagated;
     * its stack trace is printed instead.
     *
     * @param path Location of the JSON file to write.
     */
    @Override
    public void toJSON(Path path) {
        ObjectMapper mapper = new ObjectMapper();
        try {
            LogisticRegressionJSONExportModel exportModel = new LogisticRegressionJSONExportModel(
                System.currentTimeMillis(),
                "logReg_" + UUID.randomUUID().toString(),
                LogisticRegressionModel.class.getSimpleName());
            exportModel.intercept = this.intercept;
            exportModel.isKeepingRawLabels = this.isKeepingRawLabels;
            exportModel.threshold = this.threshold;
            exportModel.weights = this.weights.asArray();
            File file = new File(path.toAbsolutePath().toString());
            mapper.writeValue(file, exportModel);
        }
        catch (IOException e) {
            // Best-effort contract kept for existing callers: report the failure without throwing.
            e.printStackTrace();
        }
    }

    /** JSON-serializable mirror of {@link LogisticRegressionModel} with the weights held as a plain array. */
    public static class LogisticRegressionJSONExportModel extends JSONModel {
        /** Coefficients of the linear part. */
        public double[] weights;

        /** Bias term. */
        public double intercept;

        /** Whether the model returns raw sigmoid values. */
        public boolean isKeepingRawLabels;

        /** Decision threshold; defaults to 0.5. */
        public double threshold = 0.5;

        /**
         * Creates an export model with the given metadata.
         *
         * @param timestamp Timestamp passed to the {@code JSONModel} constructor.
         * @param uid Identifier passed to the {@code JSONModel} constructor.
         * @param modelClass Model class name passed to the {@code JSONModel} constructor.
         */
        public LogisticRegressionJSONExportModel(Long timestamp, String uid, String modelClass) {
            super(timestamp, uid, modelClass);
        }

        /** Creates an empty export model; annotated as the creator used for JSON deserialization. */
        @JsonCreator
        public LogisticRegressionJSONExportModel() {
        }

        /** @return String listing weights, intercept, raw-label flag and threshold. */
        @Override
        public String toString() {
            return "LogisticRegressionJSONExportModel{weights=" + Arrays.toString(this.weights)
                + ", intercept=" + this.intercept
                + ", isKeepingRawLabels=" + this.isKeepingRawLabels
                + ", threshold=" + this.threshold + '}';
        }

        /**
         * Builds a {@link LogisticRegressionModel} from the fields of this export model.
         *
         * @return New model carrying the same weights, intercept, raw-label flag and threshold.
         */
        @Override
        public LogisticRegressionModel convert() {
            LogisticRegressionModel logRegMdl = new LogisticRegressionModel();
            logRegMdl.withWeights(VectorUtils.of(this.weights));
            logRegMdl.withIntercept(this.intercept);
            logRegMdl.withRawLabels(this.isKeepingRawLabels);
            logRegMdl.withThreshold(this.threshold);
            return logRegMdl;
        }
    }
}

