/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.ann;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
import org.apache.ignite.ml.knn.ann.ProbableLabel;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer for {@link ANNClassificationModel}.
 *
 * <p>Training first clusters the data with a {@link KMeansTrainer} to obtain {@code k} centers, then
 * assigns every labeled row to its closest center and records, per center, how often each label
 * occurs. Each center paired with the resulting {@link ProbableLabel} (label to fraction of the
 * center's rows carrying that label) becomes one candidate of the produced model.
 */
public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClassificationModel> {
    /** Number of clusters requested from k-means; also the candidate count expected by {@link #checkState}. */
    private int k = 2;

    /** Value handed to {@link KMeansTrainer#withMaxIterations}. */
    private int maxIterations = 10;

    /** Value handed to {@link KMeansTrainer#withEpsilon}. */
    private double epsilon = 1.0E-4;

    /** Distance used both by k-means and when assigning rows to their closest center. */
    private DistanceMeasure distance = new EuclideanDistance();

    /** Value handed to {@link KMeansTrainer#withSeed}; left at the field default of 0 unless set. */
    private long seed;

    /**
     * Trains a new model from scratch by delegating to {@link #updateModel} without a previous model.
     *
     * @param datasetBuilder Builder of the training dataset.
     * @param featureExtractor Extracts the feature vector from an upstream key/value pair.
     * @param lbExtractor Extracts the label from an upstream key/value pair.
     * @param <K> Upstream key type.
     * @param <V> Upstream value type.
     * @return Freshly trained model.
     */
    @Override
    public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Builds a model, either from scratch ({@code mdl == null}) or on top of an existing one.
     *
     * <p>When a previous model is given, its candidate vectors are reused as centers (k-means is not
     * re-run) and the label statistics gathered from the new data are merged with the statistics
     * stored in that model. If no statistics could be gathered, the previous model is returned as is.
     *
     * @param mdl Previous model, or {@code null} to train from scratch.
     * @param datasetBuilder Builder of the training dataset.
     * @param featureExtractor Extracts the feature vector from an upstream key/value pair.
     * @param lbExtractor Extracts the label from an upstream key/value pair.
     * @param <K> Upstream key type.
     * @param <V> Upstream value type.
     * @return New model, or {@code mdl} itself when updating and no statistics were produced.
     */
    @Override
    protected <K, V> ANNClassificationModel updateModel(ANNClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Vector> centers;
        CentroidStat centroidStat;

        if (mdl != null) {
            centers = Arrays.stream(mdl.getCandidates().data()).map(x -> x.features()).collect(Collectors.toList());
            CentroidStat newStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
            if (newStat == null) {
                return mdl;
            }
            // merge() mutates and returns newStat; the statistics held by the old model are only read.
            centroidStat = newStat.merge(mdl.getCentroindsStat());
        } else {
            centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
            CentroidStat freshStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
            // getCentroidStat may yield null (the update branch above already accounts for that).
            // Fall back to empty statistics so candidate labels are built without dereferencing null.
            centroidStat = freshStat != null ? freshStat : new CentroidStat();
        }

        LabeledVectorSet<ProbableLabel, LabeledVector> candidates = buildLabelsForCandidates(centers, centroidStat);
        return new ANNClassificationModel(candidates, centroidStat);
    }

    /**
     * Tells whether a model matches this trainer's configuration: same distance measure and exactly
     * {@code k} candidates.
     * NOTE(review): presumably consulted by the base trainer before updating a model - confirm there.
     *
     * @param mdl Model to inspect.
     * @return {@code true} when both conditions hold.
     */
    @Override
    protected boolean checkState(ANNClassificationModel mdl) {
        return mdl.getDistanceMeasure().equals(this.distance) && mdl.getCandidates().rowSize() == this.k;
    }

    /**
     * Wraps every center into a {@link LabeledVector} labeled with its {@link ProbableLabel}.
     *
     * @param centers Cluster centers; the list index is the centroid index used in {@code centroidStat}.
     * @param centroidStat Label statistics per centroid index.
     * @return Candidate set with one row per center, in list order.
     */
    @NotNull
    private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers, CentroidStat centroidStat) {
        LabeledVector[] candidates = new LabeledVector[centers.size()];
        for (int i = 0; i < candidates.length; ++i) {
            candidates[i] = new LabeledVector<Vector, ProbableLabel>(centers.get(i), fillProbableLabel(i, centroidStat));
        }
        return new LabeledVectorSet<>(candidates);
    }

    /**
     * Runs k-means with this trainer's settings and returns the resulting centers.
     *
     * @param featureExtractor Extracts the feature vector from an upstream key/value pair.
     * @param lbExtractor Extracts the label from an upstream key/value pair.
     * @param datasetBuilder Builder of the training dataset.
     * @param <K> Upstream key type.
     * @param <V> Upstream value type.
     * @return Fixed-size list view over the centers of the trained k-means model.
     */
    private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) {
        KMeansTrainer trainer = new KMeansTrainer()
            .withAmountOfClusters(this.k)
            .withMaxIterations(this.maxIterations)
            .withSeed(this.seed)
            .withDistance(this.distance)
            .withEpsilon(this.epsilon);
        KMeansModel kMeansMdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
        return Arrays.asList(kMeansMdl.getCenters());
    }

    /**
     * Computes the label distribution for one centroid.
     *
     * <p>Every label known to {@code centroidStat} gets an entry. If the centroid has a recorded row
     * count, each entry holds {@code labelCount / clusterSize} (0.0 for labels never seen at this
     * centroid); otherwise all entries stay at 0.0.
     *
     * @param centroidIdx Index of the centroid.
     * @param centroidStat Label statistics per centroid index.
     * @return Probable label for the centroid.
     */
    private ProbableLabel fillProbableLabel(int centroidIdx, CentroidStat centroidStat) {
        TreeMap<Double, Double> clsLbls = new TreeMap<>();
        for (Double lb : centroidStat.clsLblsSet) {
            clsLbls.put(lb, 0.0);
        }

        ConcurrentHashMap<Double, Integer> centroidLbDistribution = centroidStat.centroidStat().get(centroidIdx);
        // ConcurrentHashMap never stores null values, so a null result means "no rows for this centroid".
        Integer recordedSize = centroidStat.counts.get(centroidIdx);
        if (recordedSize != null) {
            final int clusterSize = recordedSize;
            clsLbls.replaceAll((label, ignored) -> {
                Integer hits = centroidLbDistribution.get(label);
                return hits != null ? (double)hits / (double)clusterSize : 0.0;
            });
        }
        return new ProbableLabel(clsLbls);
    }

    /**
     * Gathers label statistics for the given centers over the whole dataset.
     *
     * <p>Each partition is summarized independently and the partial results are merged pairwise.
     * The dataset is closed before returning; any exception raised while building, computing or
     * closing is rethrown wrapped in a {@link RuntimeException}.
     *
     * @param datasetBuilder Builder of the training dataset.
     * @param featureExtractor Extracts the feature vector from an upstream key/value pair.
     * @param lbExtractor Extracts the label from an upstream key/value pair.
     * @param centers Centers to assign rows to.
     * @param <K> Upstream key type.
     * @param <V> Upstream value type.
     * @return Whatever the dataset computation produced; callers treat {@code null} as "no statistics".
     */
    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, List<Vector> centers) {
        LabeledDatasetPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lbExtractor);

        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset =
                 datasetBuilder.build((upstream, upstreamSize) -> new EmptyContext(), partDataBuilder)) {
            return dataset.compute(
                data -> this.computePartitionStat(data, centers),
                (a, b) -> mergeNullable(a, b));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Summarizes one partition: assigns each row to its closest center and counts labels per center. */
    private CentroidStat computePartitionStat(LabeledVectorSet<Double, LabeledVector> data, List<Vector> centers) {
        CentroidStat res = new CentroidStat();
        for (int i = 0; i < data.rowSize(); ++i) {
            int centroidIdx = findClosestCentroid(centers, data.getRow(i)).get1();
            double lb = data.label(i);

            res.labels().add(lb);
            res.centroidStat.computeIfAbsent(centroidIdx, idx -> new ConcurrentHashMap<>()).merge(lb, 1, Integer::sum);
            res.counts.merge(centroidIdx, 1, Integer::sum);
        }
        return res;
    }

    /** Merges two partial statistics, treating {@code null} as empty; two nulls give a new empty instance. */
    private static CentroidStat mergeNullable(CentroidStat a, CentroidStat b) {
        if (a == null) {
            return b == null ? new CentroidStat() : b;
        }
        if (b == null) {
            return a;
        }
        return a.merge(b);
    }

    /**
     * Finds the center closest to a point according to {@link #distance}.
     *
     * <p>{@code null} centers are skipped. Ties keep the lowest index because only a strictly smaller
     * distance replaces the current best. If no center qualifies, index 0 with an infinite distance
     * is returned.
     *
     * @param centers Candidate centers.
     * @param pnt Point whose features are compared against the centers.
     * @return Pair of (index of the closest center, distance to it).
     */
    private IgniteBiTuple<Integer, Double> findClosestCentroid(List<Vector> centers, LabeledVector pnt) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int bestInd = 0;
        for (int i = 0; i < centers.size(); ++i) {
            Vector center = centers.get(i);
            if (center == null) {
                continue;
            }
            double dist = this.distance.compute(center, (Vector)pnt.features());
            if (dist < bestDistance) {
                bestDistance = dist;
                bestInd = i;
            }
        }
        return new IgniteBiTuple<>(bestInd, bestDistance);
    }

    /**
     * @return Number of clusters (candidates) to build.
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of clusters (candidates) to build.
     *
     * @param k Number of clusters.
     * @return This trainer, for chaining.
     */
    public ANNClassificationTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * @return Max-iterations setting passed to k-means.
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the max-iterations setting passed to k-means.
     *
     * @param maxIterations Max-iterations value.
     * @return This trainer, for chaining.
     */
    public ANNClassificationTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * @return Epsilon setting passed to k-means.
     */
    public double getEpsilon() {
        return this.epsilon;
    }

    /**
     * Sets the epsilon setting passed to k-means.
     *
     * @param epsilon Epsilon value.
     * @return This trainer, for chaining.
     */
    public ANNClassificationTrainer withEpsilon(double epsilon) {
        this.epsilon = epsilon;
        return this;
    }

    /**
     * @return Distance measure used for clustering and for closest-center assignment.
     */
    public DistanceMeasure getDistance() {
        return this.distance;
    }

    /**
     * Sets the distance measure used for clustering and for closest-center assignment.
     *
     * @param distance Distance measure.
     * @return This trainer, for chaining.
     */
    public ANNClassificationTrainer withDistance(DistanceMeasure distance) {
        this.distance = distance;
        return this;
    }

    /**
     * @return Seed passed to k-means.
     */
    public long getSeed() {
        return this.seed;
    }

    /**
     * Sets the seed passed to k-means.
     *
     * @param seed Seed value.
     * @return This trainer, for chaining.
     */
    public ANNClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Label statistics collected per centroid index.
     */
    public static class CentroidStat implements Serializable {
        private static final long serialVersionUID = 7624883170532045144L;

        /** Centroid index -> (label -> number of rows with that label assigned to the centroid). */
        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();

        /** Centroid index -> total number of rows assigned to the centroid. */
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

        /** All labels seen so far, kept sorted. */
        ConcurrentSkipListSet<Double> clsLblsSet = new ConcurrentSkipListSet<>();

        /**
         * Folds another statistics object into this one by summing counts key-wise and uniting the
         * label sets. This instance's fields are replaced/updated; {@code other}'s fields are only read.
         *
         * @param other Statistics to add.
         * @return This instance.
         */
        CentroidStat merge(CentroidStat other) {
            this.counts = MapUtil.mergeMaps(this.counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
            this.centroidStat = MapUtil.mergeMaps(
                this.centroidStat,
                other.centroidStat,
                (m1, m2) -> MapUtil.mergeMaps(m1, m2, (i1, i2) -> i1 + i2, ConcurrentHashMap::new),
                ConcurrentHashMap::new);
            this.clsLblsSet.addAll(other.clsLblsSet);
            return this;
        }

        /**
         * @return Live set of all labels seen so far.
         */
        public ConcurrentSkipListSet<Double> labels() {
            return this.clsLblsSet;
        }

        /**
         * @return Live map of per-centroid label counts.
         */
        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat() {
            return this.centroidStat;
        }
    }
}

