/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.Dataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

public class KMeansTrainer
extends SingleLabelDatasetTrainer<KMeansModel> {
    private int k = 2;
    private int maxIterations = 10;
    private double epsilon = 1.0E-4;
    private DistanceMeasure distance = new EuclideanDistance();
    private long seed;

    @Override
    public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return this.updateModel((KMeansModel)null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    @Override
    protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        assert (datasetBuilder != null);
        LabeledDatasetPartitionDataBuilderOnHeap partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap(featureExtractor, lbExtractor);
        try (org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build((upstream, upstreamSize) -> new EmptyContext(), partDataBuilder);){
            Integer cols = (Integer)dataset.compute(Dataset::colSize, (a, b) -> {
                if (a == null) {
                    return b == null ? 0 : b;
                }
                if (b == null) {
                    return a;
                }
                return b;
            });
            if (cols == null) {
                KMeansModel kMeansModel = this.getLastTrainedModelOrThrowEmptyDatasetException(mdl);
                return kMeansModel;
            }
            Vector[] centers = Optional.ofNullable(mdl).map(KMeansModel::getCenters).orElseGet(() -> this.initClusterCentersRandomly(dataset, this.k));
            boolean converged = false;
            int iteration = 0;
            block14: while (true) {
                if (iteration >= this.maxIterations) return new KMeansModel(centers, this.distance);
                if (converged) return new KMeansModel(centers, this.distance);
                DenseVector[] newCentroids = new DenseVector[this.k];
                TotalCostAndCounts totalRes = this.calcDataForNewCentroids(centers, dataset, cols);
                converged = true;
                for (Integer ind : totalRes.sums.keySet()) {
                    Vector massCenter = totalRes.sums.get(ind).times(1.0 / (double)totalRes.counts.get(ind).intValue());
                    if (converged && this.distance.compute(massCenter, centers[ind]) > this.epsilon * this.epsilon) {
                        converged = false;
                    }
                    newCentroids[ind.intValue()] = massCenter;
                }
                ++iteration;
                int i = 0;
                while (true) {
                    if (i >= centers.length) continue block14;
                    if (newCentroids[i] != null) {
                        centers[i] = newCentroids[i];
                    }
                    ++i;
                }
                break;
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    protected boolean checkState(KMeansModel mdl) {
        return mdl.getCenters().length == this.k && mdl.distanceMeasure().equals(this.distance);
    }

    private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers, org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
        Vector[] finalCenters = centers;
        return (TotalCostAndCounts)dataset.compute(data -> {
            TotalCostAndCounts res = new TotalCostAndCounts();
            int i = 0;
            while (i < data.rowSize()) {
                IgniteBiTuple<Integer, Double> closestCentroid = this.findClosestCentroid(finalCenters, (LabeledVector)data.getRow(i));
                int centroidIdx = (Integer)closestCentroid.get1();
                data.setLabel(i, centroidIdx);
                res.totalCost += ((Double)closestCentroid.get2()).doubleValue();
                res.sums.putIfAbsent(centroidIdx, VectorUtils.zeroes(cols));
                int finalI = i++;
                res.sums.compute(centroidIdx, (ind, v) -> {
                    Object features = ((LabeledVector)data.getRow(finalI)).features();
                    return v == null ? features : v.plus((Vector)features);
                });
                res.counts.merge(centroidIdx, 1, (i1, i2) -> i1 + i2);
            }
            return res;
        }, (a, b) -> {
            if (a == null) {
                return b == null ? new TotalCostAndCounts() : b;
            }
            if (b == null) {
                return a;
            }
            return a.merge((TotalCostAndCounts)b);
        });
    }

    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int bestInd = 0;
        for (int i = 0; i < centers.length; ++i) {
            double dist = this.distance.compute(centers[i], (Vector)pnt.features());
            if (!(dist < bestDistance)) continue;
            bestDistance = dist;
            bestInd = i;
        }
        return new IgniteBiTuple((Object)bestInd, (Object)bestDistance);
    }

    private Vector[] initClusterCentersRandomly(org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int k) {
        Vector[] initCenters = new DenseVector[k];
        List rndPnts = (List)dataset.compute(data -> {
            ArrayList rndPnt;
            block5: {
                rndPnt = new ArrayList();
                if (data.rowSize() == 0) break block5;
                if (data.rowSize() > k) {
                    Random random = new Random(this.seed);
                    for (int i = 0; i < k; ++i) {
                        HashSet<Integer> uniqueIndices = new HashSet<Integer>();
                        int nextIdx = random.nextInt(data.rowSize());
                        int maxRandomSearch = k;
                        for (int cntr = 0; uniqueIndices.contains(nextIdx) && cntr < maxRandomSearch; ++cntr) {
                            nextIdx = random.nextInt(data.rowSize());
                        }
                        uniqueIndices.add(nextIdx);
                        rndPnt.add(data.getRow(nextIdx));
                    }
                } else {
                    for (int i = 0; i < data.rowSize(); ++i) {
                        rndPnt.add(data.getRow(i));
                    }
                }
            }
            return rndPnt;
        }, (a, b) -> {
            if (a == null) {
                return b == null ? new ArrayList() : b;
            }
            if (b == null) {
                return a;
            }
            return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
        });
        Collections.shuffle(rndPnts);
        if (rndPnts.size() >= k) {
            for (int i = 0; i < k; ++i) {
                LabeledVector rndPnt = (LabeledVector)rndPnts.get(new Random(this.seed).nextInt(rndPnts.size()));
                rndPnts.remove(rndPnt);
                initCenters[i] = rndPnt.features();
            }
        } else {
            throw new RuntimeException("The KMeans Trainer required more than " + k + " vectors to find " + k + " clusters");
        }
        return initCenters;
    }

    public int getAmountOfClusters() {
        return this.k;
    }

    public KMeansTrainer withAmountOfClusters(int k) {
        this.k = k;
        return this;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public KMeansTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public KMeansTrainer withEpsilon(double epsilon) {
        this.epsilon = epsilon;
        return this;
    }

    public DistanceMeasure getDistance() {
        return this.distance;
    }

    public KMeansTrainer withDistance(DistanceMeasure distance) {
        this.distance = distance;
        return this;
    }

    public long getSeed() {
        return this.seed;
    }

    public KMeansTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    public static class TotalCostAndCounts {
        double totalCost;
        ConcurrentHashMap<Integer, Vector> sums = new ConcurrentHashMap();
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap();
        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap();

        TotalCostAndCounts merge(TotalCostAndCounts other) {
            this.totalCost += this.totalCost;
            this.sums = MapUtil.mergeMaps(this.sums, other.sums, Vector::plus, ConcurrentHashMap::new);
            this.counts = MapUtil.mergeMaps(this.counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
            this.centroidStat = MapUtil.mergeMaps(this.centroidStat, other.centroidStat, (m1, m2) -> MapUtil.mergeMaps(m1, m2, (i1, i2) -> i1 + i2, ConcurrentHashMap::new), ConcurrentHashMap::new);
            return this;
        }

        public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> getCentroidStat() {
            return this.centroidStat;
        }
    }
}

