/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNModelFormat;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;

public class KNNClassificationModel
extends NNClassificationModel
implements Exportable<KNNModelFormat> {
    private static final long serialVersionUID = -127386523291350345L;
    private List<Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>> datasets = new ArrayList<Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>>();

    public KNNClassificationModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        if (dataset != null) {
            this.datasets.add(dataset);
        }
    }

    @Override
    public Double apply(Vector v) {
        if (!this.datasets.isEmpty()) {
            List<LabeledVector> neighbors = this.findKNearestNeighbors(v);
            return this.classify(neighbors, v, this.stgy);
        }
        throw new IllegalStateException("The train kNN dataset is null");
    }

    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
        KNNModelFormat mdlData = new KNNModelFormat(this.k, this.distanceMeasure, this.stgy);
        exporter.save(mdlData, path);
    }

    protected List<LabeledVector> findKNearestNeighbors(Vector v) {
        List<LabeledVector> neighborsFromPartitions = this.datasets.stream().flatMap(dataset -> this.findKNearestNeighborsInDataset(v, (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>)dataset).stream()).collect(Collectors.toList());
        LabeledVectorSet<Double, LabeledVector> neighborsToFilter = this.buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
        return Arrays.asList(this.getKClosestVectors(neighborsToFilter, this.getDistances(v, neighborsToFilter)));
    }

    private List<LabeledVector> findKNearestNeighborsInDataset(Vector v, Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        List neighborsFromPartitions = (List)dataset.compute(data -> {
            TreeMap<Double, Set<Integer>> distanceIdxPairs = this.getDistances(v, (LabeledVectorSet<Double, LabeledVector>)data);
            return Arrays.asList(this.getKClosestVectors((LabeledVectorSet<Double, LabeledVector>)data, distanceIdxPairs));
        }, (a, b) -> {
            if (a == null) {
                return b == null ? new ArrayList() : b;
            }
            if (b == null) {
                return a;
            }
            return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
        });
        if (neighborsFromPartitions == null) {
            return Collections.emptyList();
        }
        LabeledVectorSet<Double, LabeledVector> neighborsToFilter = this.buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
        return Arrays.asList(this.getKClosestVectors(neighborsToFilter, this.getDistances(v, neighborsToFilter)));
    }

    private double classify(List<LabeledVector> neighbors, Vector v, NNStrategy stgy) {
        HashMap<Double, Double> clsVotes = new HashMap<Double, Double>();
        for (LabeledVector neighbor : neighbors) {
            double clsLb = (Double)neighbor.label();
            double distance = this.distanceMeasure.compute(v, (Vector)neighbor.features());
            if (clsVotes.containsKey(clsLb)) {
                double clsVote = (Double)clsVotes.get(clsLb);
                clsVotes.put(clsLb, clsVote += this.getClassVoteForVector(stgy, distance));
                continue;
            }
            double val = this.getClassVoteForVector(stgy, distance);
            clsVotes.put(clsLb, val);
        }
        return this.getClassWithMaxVotes(clsVotes);
    }

    public void copyStateFrom(KNNClassificationModel mdl) {
        this.copyParametersFrom(mdl);
        this.datasets.addAll(mdl.datasets);
    }
}

