/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.kmeans;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.clustering.kmeans.ClusterizationModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.inference.json.JSONModel;
import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.util.ModelTrace;

/**
 * K-means clustering model: an array of cluster centers plus the {@link DistanceMeasure} used to
 * assign a vector to its nearest center.
 * <p>
 * Instances are mutable through {@link #withCentroids} and {@link #withDistanceMeasure}; no
 * synchronization is performed.
 */
public final class KMeansModel
    implements ClusterizationModel<Vector, Integer>,
    Exportable<KMeansModelFormat>,
    JSONWritable,
    DeployableObject {
    /** Cluster centers; the index of a center is the cluster id returned by {@link #predict}. */
    private Vector[] centers;

    /** Distance used by {@link #predict}; Euclidean unless replaced via constructor or setter. */
    private DistanceMeasure distanceMeasure = new EuclideanDistance();

    /**
     * Creates a model from the given centers and distance measure.
     *
     * @param centers Cluster centers; the array is stored by reference, not copied.
     * @param distanceMeasure Distance measure used to find the nearest center.
     */
    public KMeansModel(Vector[] centers, DistanceMeasure distanceMeasure) {
        this.centers = centers;
        this.distanceMeasure = distanceMeasure;
    }

    /**
     * Creates an unpopulated model (no centers, Euclidean distance); used by JSON import,
     * which fills it via the {@code with*} methods.
     */
    private KMeansModel() {
    }

    /**
     * @return Distance measure used by this model.
     */
    public DistanceMeasure distanceMeasure() {
        return distanceMeasure;
    }

    /**
     * @return Number of cluster centers.
     */
    @Override
    public int amountOfClusters() {
        return centers.length;
    }

    /**
     * Replaces the cluster centers.
     *
     * @param centers New centers; the array is stored by reference, not copied.
     * @return This model, for chaining.
     */
    public KMeansModel withCentroids(Vector[] centers) {
        this.centers = centers;
        return this;
    }

    /**
     * Replaces the distance measure.
     *
     * @param distanceMeasure New distance measure.
     * @return This model, for chaining.
     */
    public KMeansModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        return this;
    }

    /**
     * @return A copy of the centers array. The array is copied; the contained vectors are shared.
     */
    public Vector[] centers() {
        return Arrays.copyOf(centers, centers.length);
    }

    /**
     * Returns the index of the center nearest to {@code vec} under {@link #distanceMeasure()}.
     * <p>
     * Ties go to the lowest index because the comparison is strict. Returns {@code -1} when no
     * center has a distance strictly below {@code +Infinity}. That includes the cases of no
     * centers at all and all distances being NaN.
     *
     * @param vec Vector to classify.
     * @return Cluster index in {@code [0, amountOfClusters())}, or {@code -1}.
     */
    @Override
    public Integer predict(Vector vec) {
        int res = -1;
        double minDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.length; ++i) {
            double curDist = distanceMeasure.compute(centers[i], vec);
            // NaN compares false here, so NaN distances never win.
            if (curDist < minDist) {
                minDist = curDist;
                res = i;
            }
        }
        return res;
    }

    /**
     * Exports the centers and distance measure as a {@link KMeansModelFormat} through the given
     * exporter.
     *
     * @param exporter Exporter that writes the model data.
     * @param path Destination understood by the exporter.
     * @param <P> Type of the destination.
     */
    @Override
    public <P> void saveModel(Exporter<KMeansModelFormat, P> exporter, P path) {
        KMeansModelFormat mdlData = new KMeansModelFormat(centers, distanceMeasure);
        exporter.save(mdlData, path);
    }

    /** {@inheritDoc} Null-safe with respect to the distance measure. */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + Objects.hashCode(distanceMeasure);
        res = res * 37 + Arrays.hashCode(centers);
        return res;
    }

    /** {@inheritDoc} Two models are equal when their distance measures and centers are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        KMeansModel that = (KMeansModel) obj;
        return Objects.equals(distanceMeasure, that.distanceMeasure)
            && Arrays.deepEquals(centers, that.centers);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toString(false);
    }

    /**
     * Renders the distance measure and each center (formatted with {@code "%.4f"}) via
     * {@link ModelTrace}.
     *
     * @param pretty Whether to use the pretty (multi-line) layout.
     * @return Text representation of the model.
     */
    @Override
    public String toString(boolean pretty) {
        List<String> centersList = Arrays.stream(centers)
            .map(x -> Tracer.asAscii(x, "%.4f", false))
            .collect(Collectors.toList());
        return ModelTrace.builder("KMeansModel", pretty)
            .addField("distance measure", String.valueOf(distanceMeasure))
            .addField("centroids", centersList)
            .toString();
    }

    /**
     * @return A singleton list containing this model's distance measure.
     */
    @Override
    public List<Object> getDependencies() {
        return Collections.singletonList(distanceMeasure);
    }

    /**
     * Creates the Jackson mapper shared by JSON import and export.
     * <p>
     * {@code FAIL_ON_EMPTY_BEANS} is disabled so that serializing a distance measure with no
     * discoverable properties writes an empty object instead of throwing. The setting has no
     * effect when reading.
     */
    private static ObjectMapper jsonMapper() {
        return new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
    }

    /**
     * Reads a model previously written by {@link #toJSON(Path)}.
     *
     * @param path File to read.
     * @return The loaded model, or {@code null} if reading failed with an {@link IOException}
     *     (the stack trace is printed to stderr).
     */
    public static KMeansModel fromJSON(Path path) {
        ObjectMapper mapper = jsonMapper();
        try {
            KMeansJSONExportModel exportModel = mapper.readValue(
                new File(path.toAbsolutePath().toString()), KMeansJSONExportModel.class);
            return exportModel.convert();
        }
        catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes this model (centers as {@code double[]} arrays plus the distance measure) as JSON.
     * An {@link IOException} is printed to stderr and not rethrown.
     *
     * @param path Destination file.
     */
    @Override
    public void toJSON(Path path) {
        ObjectMapper mapper = jsonMapper();
        try {
            // NOTE(review): the "ann_" uid prefix looks copied from another model's exporter;
            // kept unchanged for compatibility with existing files — confirm intent.
            KMeansJSONExportModel exportModel = new KMeansJSONExportModel(
                System.currentTimeMillis(),
                "ann_" + UUID.randomUUID().toString(),
                KMeansModel.class.getSimpleName());

            List<double[]> listOfCenters = new ArrayList<>(centers.length);
            for (Vector center : centers) {
                listOfCenters.add(center.asArray());
            }
            exportModel.mdlCenters = listOfCenters;
            exportModel.distanceMeasure = distanceMeasure;

            File file = new File(path.toAbsolutePath().toString());
            mapper.writeValue(file, exportModel);
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** JSON data-transfer shape of a {@link KMeansModel}. */
    public static class KMeansJSONExportModel
        extends JSONModel {
        /** Cluster centers as raw coordinate arrays, in cluster-index order. */
        public List<double[]> mdlCenters;

        /** Distance measure of the exported model; may be absent (null) in hand-written JSON. */
        public DistanceMeasure distanceMeasure;

        /**
         * @param timestamp Export time.
         * @param uid Export identifier.
         * @param modelClass Simple name of the exported model class.
         */
        public KMeansJSONExportModel(Long timestamp, String uid, String modelClass) {
            super(timestamp, uid, modelClass);
        }

        /** No-arg constructor used by Jackson during deserialization. */
        @JsonCreator
        public KMeansJSONExportModel() {
        }

        /**
         * Rebuilds a {@link KMeansModel} from the deserialized fields.
         * <p>
         * If {@link #distanceMeasure} is null, the model keeps its Euclidean default rather than
         * being given a null measure.
         *
         * @return The reconstructed model.
         */
        @Override
        public KMeansModel convert() {
            // Vector[] rather than DenseVector[]: VectorUtils.of is typed to return Vector, and a
            // narrower runtime array type would risk ArrayStoreException.
            Vector[] centers = new Vector[mdlCenters.size()];
            for (int i = 0; i < centers.length; ++i) {
                centers[i] = VectorUtils.of(mdlCenters.get(i));
            }

            KMeansModel mdl = new KMeansModel().withCentroids(centers);
            // Installing null here would make predict/hashCode/toString throw NPE.
            if (distanceMeasure != null) {
                mdl.withDistanceMeasure(distanceMeasure);
            }
            return mdl;
        }
    }
}

