/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.classification;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.classification.KNNModelFormat;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.jetbrains.annotations.NotNull;

/**
 * k-nearest-neighbours classification model.
 *
 * <p>The model keeps a reference to the training dataset and computes each prediction at inference
 * time. Each partition of the dataset contributes its own {@code k} closest vectors. The merged
 * candidates are then filtered again down to the global {@code k} closest. Finally, a class label
 * is chosen by (optionally distance-weighted) voting.
 *
 * <p>Defaults: {@code k = 5}, {@link EuclideanDistance}, {@link KNNStrategy#SIMPLE}.
 *
 * @param <K> Type of a key in upstream data (not used directly by this class).
 * @param <V> Type of a value in upstream data (not used directly by this class).
 */
public class KNNClassificationModel<K, V>
implements Model<Vector, Double>,
Exportable<KNNModelFormat> {
    private static final long serialVersionUID = -127386523291350345L;

    /** Number of neighbours taken into account when classifying. */
    protected int k = 5;

    /** Distance measure used both for neighbour search and for weighted voting. */
    protected DistanceMeasure distanceMeasure = new EuclideanDistance();

    /** Voting strategy. */
    protected KNNStrategy stgy = KNNStrategy.SIMPLE;

    /** Training dataset queried on every prediction. */
    private Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset;

    /**
     * Creates a model backed by the given training dataset.
     *
     * @param dataset Training dataset; if {@code null}, {@link #apply(Vector)} will fail.
     */
    public KNNClassificationModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
        this.dataset = dataset;
    }

    /**
     * Predicts the class label of the given observation.
     *
     * @param v Observation to classify.
     * @return Label of the class that received the most votes among the nearest neighbours.
     * @throws IllegalStateException If the training dataset is {@code null} or yields no neighbours.
     */
    @Override
    public Double apply(Vector v) {
        if (dataset == null)
            throw new IllegalStateException("The train kNN dataset is null");

        return classify(findKNearestNeighbors(v), v, stgy);
    }

    /** {@inheritDoc} */
    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
        KNNModelFormat mdlData = new KNNModelFormat(k, distanceMeasure, stgy);
        exporter.save(mdlData, path);
    }

    /**
     * Sets the number of neighbours to consider.
     *
     * @param k Number of neighbours.
     * @return This model, for chaining.
     */
    public KNNClassificationModel<K, V> withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Sets the voting strategy.
     *
     * @param stgy Voting strategy.
     * @return This model, for chaining.
     */
    public KNNClassificationModel<K, V> withStrategy(KNNStrategy stgy) {
        this.stgy = stgy;
        return this;
    }

    /**
     * Sets the distance measure.
     *
     * @param distanceMeasure Distance measure.
     * @return This model, for chaining.
     */
    public KNNClassificationModel<K, V> withDistanceMeasure(DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        return this;
    }

    /**
     * Finds up to {@code k} training vectors closest to {@code v}.
     *
     * <p>Each partition first returns its local {@code k} closest vectors. The union of those
     * candidates is then reduced to the global {@code k} closest.
     *
     * @param v Observation to search around.
     * @return Closest training vectors (possibly fewer than {@code k}; empty if nothing was found).
     */
    protected List<LabeledVector> findKNearestNeighbors(Vector v) {
        List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
            TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data);
            return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
        }, (a, b) -> {
            // Either side may be null (e.g. the reduction identity); never dereference it.
            if (a == null)
                return b;
            if (b == null)
                return a;
            return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
        });

        if (neighborsFromPartitions == null || neighborsFromPartitions.isEmpty())
            return Collections.emptyList();

        LabeledDataset<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
        return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
    }

    /**
     * Wraps a list of labeled vectors into a {@link LabeledDataset}.
     *
     * @param neighborsFromPartitions Candidate neighbours collected from all partitions.
     * @return Dataset containing the same vectors in the same order.
     */
    private LabeledDataset<Double, LabeledVector> buildLabeledDatasetOnListOfVectors(List<LabeledVector> neighborsFromPartitions) {
        return new LabeledDataset<>(neighborsFromPartitions.toArray(new LabeledVector[0]));
    }

    /**
     * Selects up to {@code k} rows of {@code trainingData} in ascending order of distance.
     *
     * <p>Only rows present in {@code distanceIdxPairs} are considered. {@link #getDistances} leaves
     * out {@code null} rows, so {@code null} is never returned. Ties at equal distance are taken in
     * the iteration order of the corresponding index set.
     *
     * @param trainingData Dataset the indices refer to.
     * @param distanceIdxPairs Row indices grouped by their distance to the query point.
     * @return At most {@code k} closest rows (fewer if the dataset has fewer non-null rows).
     */
    @NotNull
    private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> trainingData,
        TreeMap<Double, Set<Integer>> distanceIdxPairs) {
        return distanceIdxPairs.values().stream()
            .flatMap(Set::stream)
            // A non-positive k yields no neighbours instead of a NegativeArraySizeException.
            .limit(Math.max(k, 0))
            .map(idx -> (LabeledVector)trainingData.getRow(idx))
            .toArray(LabeledVector[]::new);
    }

    /**
     * Computes the distance from {@code v} to every non-null row of {@code trainingData}.
     *
     * @param v Query point.
     * @param trainingData Rows to measure against.
     * @return Row indices grouped by distance, sorted by ascending distance.
     */
    @NotNull
    private TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledDataset<Double, LabeledVector> trainingData) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();

        for (int i = 0; i < trainingData.rowSize(); i++) {
            LabeledVector labeledVector = (LabeledVector)trainingData.getRow(i);

            if (labeledVector == null)
                continue;

            double distance = distanceMeasure.compute(v, (Vector)labeledVector.features());
            putDistanceIdxPair(distanceIdxPairs, i, distance);
        }

        return distanceIdxPairs;
    }

    /**
     * Registers row index {@code i} under the given distance.
     *
     * @param distanceIdxPairs Distance-to-indices map to update.
     * @param i Row index.
     * @param distance Distance of that row to the query point.
     */
    private void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
        distanceIdxPairs.computeIfAbsent(distance, d -> new HashSet<>()).add(i);
    }

    /**
     * Chooses a class label by voting among the given neighbours.
     *
     * @param neighbors Nearest neighbours of {@code v}.
     * @param v Observation being classified.
     * @param stgy Voting strategy.
     * @return Label with the highest accumulated vote.
     * @throws IllegalStateException If {@code neighbors} is empty.
     */
    private double classify(List<LabeledVector> neighbors, Vector v, KNNStrategy stgy) {
        Map<Double, Double> clsVotes = new HashMap<>();

        for (LabeledVector neighbor : neighbors) {
            double distance = distanceMeasure.compute(v, (Vector)neighbor.features());
            clsVotes.merge((Double)neighbor.label(), getClassVoteForVector(stgy, distance), Double::sum);
        }

        return getClassWithMaxVotes(clsVotes);
    }

    /**
     * Returns the label with the largest vote total.
     *
     * @param clsVotes Accumulated votes per label.
     * @return Winning label.
     * @throws IllegalStateException If there are no votes at all.
     */
    private double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
        if (clsVotes.isEmpty())
            throw new IllegalStateException("Unable to classify: no neighbors were found in the train kNN dataset");

        return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    /**
     * Computes the vote a single neighbour contributes.
     *
     * <p>Under {@link KNNStrategy#WEIGHTED} the vote is {@code 1 / distance}. An exact match
     * (distance {@code 0}) therefore yields {@code +Infinity} and dominates the vote. Any other
     * strategy gives every neighbour a vote of {@code 1}.
     *
     * @param stgy Voting strategy.
     * @param distance Distance from the neighbour to the observation.
     * @return Vote weight.
     */
    private double getClassVoteForVector(KNNStrategy stgy, double distance) {
        return KNNStrategy.WEIGHTED.equals(stgy) ? 1.0 / distance : 1.0;
    }

    /** {@inheritDoc} */
    @Override
    public int hashCode() {
        int res = 1;

        res = res * 37 + k;
        res = res * 37 + distanceMeasure.hashCode();
        res = res * 37 + stgy.hashCode();

        return res;
    }

    /**
     * Two models are equal when they share {@code k}, distance measure and strategy. The training
     * dataset is not part of equality.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;

        if (obj == null || getClass() != obj.getClass())
            return false;

        KNNClassificationModel<?, ?> that = (KNNClassificationModel<?, ?>)obj;

        return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy);
    }
}

