/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.Dataset;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer that builds a {@link KMeansModel} by iteratively re-assigning every point of a
 * partitioned dataset to its closest center and moving each center to the mean of the
 * points assigned to it.
 *
 * <p>Configuration is done through the fluent {@code withXxx} methods; defaults are
 * {@code k = 2}, {@code maxIterations = 10}, {@code epsilon = 1.0E-4}, a
 * {@link EuclideanDistance} measure and a seed of {@code 0}.
 */
public class KMeansTrainer
implements SingleLabelDatasetTrainer<KMeansModel> {
    /** Number of clusters to look for. */
    private int k = 2;

    /** Upper bound on the number of re-assignment iterations. */
    private int maxIterations = 10;

    /** Convergence tolerance; see the convergence check in {@code fit}. */
    private double epsilon = 1.0E-4;

    /** Distance measure used both for assignment and for the convergence check. */
    private DistanceMeasure distance = new EuclideanDistance();

    /** Seed for the random selection of the initial centers. */
    private long seed;

    /**
     * Trains a k-means model on the data produced by the given dataset builder.
     *
     * @param datasetBuilder   builder of the partitioned dataset; must not be {@code null}
     * @param featureExtractor extracts the feature array from an upstream key/value pair
     * @param lbExtractor      extracts the label from an upstream key/value pair
     * @param <K>              upstream key type
     * @param <V>              upstream value type
     * @return model holding the final centers and the configured distance measure
     * @throws RuntimeException wrapping any exception raised while building, processing
     *                          or closing the dataset
     */
    @Override
    public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        Vector[] centers;
        assert (datasetBuilder != null);
        LabeledDatasetPartitionDataBuilderOnHeap partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap(featureExtractor, lbExtractor);
        try (org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build((upstream, upstreamSize) -> new EmptyContext(), partDataBuilder);){
            // The column count is taken from the first partition that reports one.
            Integer colCnt = dataset.compute(Dataset::colSize, (a, b) -> a == null ? b : a);
            if (colCnt == null) {
                throw new IllegalStateException("Cannot train KMeans: the dataset reported no column size");
            }
            int cols = colCnt;
            centers = this.initClusterCentersRandomly(dataset, this.k);
            boolean converged = false;
            for (int iteration = 0; iteration < this.maxIterations && !converged; ++iteration) {
                Vector[] newCentroids = new DenseLocalOnHeapVector[this.k];
                TotalCostAndCounts totalRes = this.calcDataForNewCentroids(centers, dataset, cols);
                converged = true;
                for (Integer ind : totalRes.sums.keySet()) {
                    Vector massCenter = totalRes.sums.get(ind).times(1.0 / (double)totalRes.counts.get(ind).intValue());
                    // NOTE(review): the threshold is epsilon squared although the value compared
                    // is whatever the configured DistanceMeasure returns - confirm this is intended.
                    if (converged && this.distance.compute(massCenter, centers[ind]) > this.epsilon * this.epsilon) {
                        converged = false;
                    }
                    newCentroids[ind.intValue()] = massCenter;
                }
                // A cluster that attracted no points has no entry in the sums map; keep its
                // previous center instead of leaving a null slot behind.
                for (int i = 0; i < newCentroids.length; ++i) {
                    if (newCentroids[i] == null) {
                        newCentroids[i] = centers[i];
                    }
                }
                centers = newCentroids;
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return new KMeansModel(centers, this.distance);
    }

    /**
     * Assigns every row to its closest center and accumulates, per center, the sum of the
     * assigned feature vectors and the number of assigned rows. Each row's label is
     * overwritten with the index of the center it was assigned to.
     *
     * @param centers current centers
     * @param dataset partitioned dataset to scan
     * @param cols    length of the zero vector used to start each per-center sum
     * @return per-partition results combined with {@link TotalCostAndCounts#merge}
     */
    private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers, org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int cols) {
        Vector[] finalCenters = centers;
        return dataset.compute(data -> {
            TotalCostAndCounts res = new TotalCostAndCounts();
            for (int i = 0; i < data.rowSize(); ++i) {
                LabeledVector row = (LabeledVector)data.getRow(i);
                IgniteBiTuple<Integer, Double> closestCentroid = this.findClosestCentroid(finalCenters, row);
                int centroidIdx = closestCentroid.get1();
                data.setLabel(i, centroidIdx);
                res.totalCost += closestCentroid.get2();
                Vector features = (Vector)row.features();
                // Start from a zero vector so the stored sum never aliases a data row.
                res.sums.putIfAbsent(centroidIdx, VectorUtils.zeroes(cols));
                res.sums.compute(centroidIdx, (ind, v) -> v.plus(features));
                res.counts.merge(centroidIdx, 1, (i1, i2) -> i1 + i2);
            }
            return res;
        }, (a, b) -> {
            // Tolerate a missing result on either side of the reduction.
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            return a.merge(b);
        });
    }

    /**
     * Finds the center with the smallest distance to the given point. Ties are resolved in
     * favour of the lowest index.
     *
     * @param centers candidate centers
     * @param pnt     point to classify
     * @return tuple of (index of the closest center, distance to it); index {@code 0} and
     *         {@code Double.POSITIVE_INFINITY} when no center is strictly closer than infinity
     */
    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int bestInd = 0;
        for (int i = 0; i < centers.length; ++i) {
            double dist = this.distance.compute(centers[i], (Vector)pnt.features());
            if (dist < bestDistance) {
                bestDistance = dist;
                bestInd = i;
            }
        }
        return new IgniteBiTuple<>(bestInd, bestDistance);
    }

    /**
     * Picks {@code k} initial centers among the rows of the dataset.
     *
     * <p>Every partition contributes up to {@code k} distinct rows (all of its rows when it
     * holds {@code k} or fewer), so that {@code k} centers can be found even when there are
     * fewer partitions than clusters. The final {@code k} rows are then drawn without
     * replacement from the gathered candidates. All draws use generators seeded with the
     * configured seed; each partition seeds its own generator with that same value.
     *
     * @param dataset partitioned dataset to sample from
     * @param k       number of centers to pick
     * @return array of {@code k} feature vectors
     * @throws IllegalArgumentException if fewer than {@code k} rows could be gathered
     */
    private Vector[] initClusterCentersRandomly(org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int k) {
        long rndSeed = this.seed;
        List<LabeledVector> gathered = dataset.compute(data -> {
            List<LabeledVector> picked = new ArrayList<>();
            int rows = data.rowSize();
            if (rows <= k) {
                // Not enough rows to choose from: hand over everything (possibly nothing).
                for (int i = 0; i < rows; ++i) {
                    picked.add((LabeledVector)data.getRow(i));
                }
                return picked;
            }
            Random partRnd = new Random(rndSeed);
            List<Integer> chosen = new ArrayList<>(k);
            // rows > k here, so k distinct indices always exist and the loop terminates.
            while (chosen.size() < k) {
                int idx = partRnd.nextInt(rows);
                if (!chosen.contains(idx)) {
                    chosen.add(idx);
                    picked.add((LabeledVector)data.getRow(idx));
                }
            }
            return picked;
        }, (a, b) -> {
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
        });

        // Work on a private mutable copy; the reduced list may be absent or unmodifiable.
        List<LabeledVector> candidates = new ArrayList<>();
        if (gathered != null) {
            candidates.addAll(gathered);
        }
        if (candidates.size() < k) {
            throw new IllegalArgumentException("KMeans needs at least " + k + " points to initialise " + k
                + " cluster centers, but only " + candidates.size() + " were available");
        }

        Vector[] initCenters = new DenseLocalOnHeapVector[k];
        // One generator for all draws: re-seeding per draw would repeat the same first value.
        Random rnd = new Random(rndSeed);
        for (int i = 0; i < k; ++i) {
            LabeledVector rndPnt = candidates.remove(rnd.nextInt(candidates.size()));
            initCenters[i] = (Vector)rndPnt.features();
        }
        return initCenters;
    }

    /**
     * @return the configured number of clusters
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of clusters.
     *
     * @param k number of clusters
     * @return this trainer, for chaining
     */
    public KMeansTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * @return the configured maximum number of iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIterations iteration limit
     * @return this trainer, for chaining
     */
    public KMeansTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * @return the configured convergence tolerance
     */
    public double getEpsilon() {
        return this.epsilon;
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param epsilon tolerance value
     * @return this trainer, for chaining
     */
    public KMeansTrainer withEpsilon(double epsilon) {
        this.epsilon = epsilon;
        return this;
    }

    /**
     * @return the configured distance measure
     */
    public DistanceMeasure getDistance() {
        return this.distance;
    }

    /**
     * Sets the distance measure.
     *
     * @param distance distance measure to use
     * @return this trainer, for chaining
     */
    public KMeansTrainer withDistance(DistanceMeasure distance) {
        this.distance = distance;
        return this;
    }

    /**
     * @return the configured random seed
     */
    public long getSeed() {
        return this.seed;
    }

    /**
     * Sets the seed used when choosing the initial centers.
     *
     * @param seed random seed
     * @return this trainer, for chaining
     */
    public KMeansTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Accumulator for one assignment pass: the summed distance of points to their centers
     * plus, keyed by center index, the sum of assigned vectors and the number of them.
     */
    private static class TotalCostAndCounts {
        /** Sum of the distances between each point and the center it was assigned to. */
        double totalCost;

        /** Center index to the sum of the feature vectors assigned to it. */
        ConcurrentHashMap<Integer, Vector> sums = new ConcurrentHashMap<>();

        /** Center index to the number of points assigned to it. */
        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();

        private TotalCostAndCounts() {
        }

        /**
         * Folds another accumulator into this one.
         *
         * @param other accumulator to add
         * @return this accumulator, now holding the combined totals
         */
        TotalCostAndCounts merge(TotalCostAndCounts other) {
            // Add the other side's cost (not our own again).
            this.totalCost += other.totalCost;
            this.sums = MapUtil.mergeMaps(this.sums, other.sums, Vector::plus, ConcurrentHashMap::new);
            this.counts = MapUtil.mergeMaps(this.counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
            return this;
        }
    }
}

