/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.knn.ann.KNNModelFormat;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;
import org.jetbrains.annotations.NotNull;

/**
 * Base class for nearest-neighbour classification models.
 *
 * <p>Holds the hyper-parameters shared by such models: the number of neighbours {@code k}, the
 * {@link DistanceMeasure}, and whether votes are distance-weighted. It also provides helpers to rank
 * training vectors by distance, pick the closest ones and tally class votes. Subclasses implement
 * prediction and {@link #saveModel(Exporter, Object)}.
 */
public abstract class NNClassificationModel
    implements IgniteModel<Vector, Double>, Exportable<KNNModelFormat>, DeployableObject {
    /** Number of nearest neighbours to take into account (5 by default). */
    protected int k = 5;

    /** Distance measure used to rank training vectors (Euclidean by default). */
    protected DistanceMeasure distanceMeasure = new EuclideanDistance();

    /** Whether a neighbour's vote is weighted by the inverse of its distance. */
    protected boolean weighted;

    /**
     * Sets the number of neighbours used for classification.
     *
     * @param k Number of neighbours.
     * @return This model, for call chaining.
     */
    public NNClassificationModel withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Sets whether votes are weighted by inverse distance.
     *
     * @param weighted {@code true} to weight each vote by {@code 1 / distance}, {@code false} for one vote each.
     * @return This model, for call chaining.
     */
    public NNClassificationModel withWeighted(boolean weighted) {
        this.weighted = weighted;
        return this;
    }

    /**
     * Sets the distance measure used to compare vectors.
     *
     * @param distanceMeasure Distance measure; must not be {@code null}.
     * @return This model, for call chaining.
     * @throws NullPointerException If {@code distanceMeasure} is {@code null}.
     */
    public NNClassificationModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
        // Fail fast: a null measure would otherwise surface later as an NPE in hashCode/equals/toString/predict.
        this.distanceMeasure = Objects.requireNonNull(distanceMeasure, "distanceMeasure");
        return this;
    }

    /**
     * Wraps a list of labeled vectors into a {@link LabeledVectorSet}, preserving list order.
     *
     * @param neighborsFromPartitions Vectors to wrap.
     * @return New labeled vector set backed by an array copy of the list.
     */
    protected LabeledVectorSet<LabeledVector> buildLabeledDatasetOnListOfVectors(
        List<LabeledVector> neighborsFromPartitions) {
        return new LabeledVectorSet<>(neighborsFromPartitions.toArray(new LabeledVector[0]));
    }

    /**
     * Selects up to {@code k} training vectors closest to the query.
     *
     * <p>If the training data has at most {@code k} rows, every non-null row is returned in index order, and
     * {@code distanceIdxPairs} is not consulted. Otherwise row indices are taken from {@code distanceIdxPairs}
     * in ascending distance order until {@code k} vectors are collected. Indices sharing a distance are taken in
     * the iteration order of their set. If the map holds fewer than {@code k} indices in total (for example
     * because {@link #getDistances} skipped null rows), all of them are returned rather than failing.
     *
     * @param trainingData Training vectors.
     * @param distanceIdxPairs Distance to the set of row indices at that distance, as built by {@link #getDistances}.
     * @return Closest vectors; the array may be shorter than {@code k}.
     */
    @NotNull
    protected LabeledVector[] getKClosestVectors(LabeledVectorSet<LabeledVector> trainingData,
        TreeMap<Double, Set<Integer>> distanceIdxPairs) {
        List<LabeledVector> res = new ArrayList<>();

        if (trainingData.rowSize() <= k) {
            for (int i = 0; i < trainingData.rowSize(); i++) {
                LabeledVector row = (LabeledVector)trainingData.getRow(i);

                // getDistances skips null rows; do the same here so callers never receive null neighbours.
                if (row != null) {
                    res.add(row);
                }
            }

            return res.toArray(new LabeledVector[0]);
        }

        // TreeMap.values() iterates in ascending key (distance) order, so the first k indices are the nearest.
        for (Set<Integer> idxs : distanceIdxPairs.values()) {
            for (Integer idx : idxs) {
                if (res.size() >= k) {
                    return res.toArray(new LabeledVector[0]);
                }

                res.add((LabeledVector)trainingData.getRow(idx));
            }
        }

        // Map exhausted before k vectors were collected: return what was found.
        return res.toArray(new LabeledVector[0]);
    }

    /**
     * Computes the distance from {@code v} to every non-null training row.
     *
     * @param v Query vector.
     * @param trainingData Training vectors.
     * @return Map from distance to the set of row indices at that distance, sorted by ascending distance.
     */
    @NotNull
    protected TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledVectorSet<LabeledVector> trainingData) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();

        for (int i = 0; i < trainingData.rowSize(); i++) {
            LabeledVector labeledVector = (LabeledVector)trainingData.getRow(i);

            if (labeledVector == null) {
                continue;
            }

            double distance = distanceMeasure.compute(v, (Vector)labeledVector.features());
            putDistanceIdxPair(distanceIdxPairs, i, distance);
        }

        return distanceIdxPairs;
    }

    /**
     * Records that row {@code i} lies at {@code distance}, creating the index set for that distance if needed.
     *
     * @param distanceIdxPairs Map from distance to row indices; modified in place.
     * @param i Row index.
     * @param distance Distance of the row to the query.
     */
    protected void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
        distanceIdxPairs.computeIfAbsent(distance, d -> new HashSet<>()).add(i);
    }

    /**
     * Returns the class label with the largest accumulated vote.
     *
     * @param clsVotes Map from class label to its total vote.
     * @return Label of the class with the maximum vote.
     * @throws java.util.NoSuchElementException If {@code clsVotes} is empty.
     */
    protected double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
        return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    /**
     * Computes a single neighbour's vote.
     *
     * @param weighted Whether to weight by inverse distance.
     * @param distance Distance from the query to the neighbour.
     * @return {@code 1 / distance} when weighted (positive infinity for a zero distance), otherwise {@code 1}.
     */
    protected double getClassVoteForVector(boolean weighted, double distance) {
        return weighted ? 1.0 / distance : 1.0;
    }

    /**
     * Returns the configured distance measure.
     *
     * @return Distance measure.
     */
    public DistanceMeasure getDistanceMeasure() {
        return distanceMeasure;
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + k;
        res = res * 37 + distanceMeasure.hashCode();
        res = res * 37 + Boolean.hashCode(weighted);
        return res;
    }

    /** Two models are equal when they have the same runtime class, {@code k}, distance measure and weighting. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }

        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }

        NNClassificationModel that = (NNClassificationModel)obj;

        return k == that.k && distanceMeasure.equals(that.distanceMeasure) && weighted == that.weighted;
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toString(false);
    }

    /** {@inheritDoc} */
    @Override
    public String toString(boolean pretty) {
        return ModelTrace.builder("KNNClassificationModel", pretty)
            .addField("k", String.valueOf(k))
            .addField("measure", distanceMeasure.getClass().getSimpleName())
            .addField("weighted", String.valueOf(weighted))
            .toString();
    }

    /**
     * Copies {@code k}, the distance measure (by reference) and the weighting flag from another model.
     *
     * @param mdl Source model.
     */
    protected void copyParametersFrom(NNClassificationModel mdl) {
        this.k = mdl.k;
        this.distanceMeasure = mdl.distanceMeasure;
        this.weighted = mdl.weighted;
    }

    /**
     * Exports this model via the given exporter.
     *
     * @param exporter Exporter receiving the {@link KNNModelFormat} representation.
     * @param path Exporter-specific target argument.
     * @param <P> Type of the exporter's target argument.
     */
    @Override
    public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path);

    /**
     * Returns objects this model depends on for deployment.
     *
     * @return Singleton list containing the distance measure.
     */
    @Override
    public List<Object> getDependencies() {
        return Collections.singletonList(distanceMeasure);
    }
}

