/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.utils.indices;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.ignite.ml.knn.utils.PointWithDistance;
import org.apache.ignite.ml.knn.utils.PointWithDistanceUtil;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * {@link SpatialIndex} implementation backed by a ball tree.
 *
 * <p>Every node of the tree carries a center and a radius. Inner nodes have two children; leaf nodes
 * hold the labeled vectors themselves. A k-nearest search walks the tree, visiting the child whose
 * center is closer to the query first and skipping the other child when its ball cannot contain
 * anything closer than the current k-th candidate.
 *
 * <p>The pruning rule ({@code distance to center - radius}) is only a valid lower bound when the
 * supplied {@link DistanceMeasure} obeys the triangle inequality; this class does not verify that.
 *
 * @param <L> Label type of the indexed vectors.
 */
public class BallTreeSpatialIndex<L> implements SpatialIndex<L> {
    /** Largest number of points a node may hold without an attempt being made to split it. */
    private static final int MAX_LEAF_SIZE = 42;

    /**
     * Relative position, along the split dimension, of the two child centers: the left one is placed
     * at {@code min + range * margin} and the right one at {@code min + range * (1 - margin)}.
     */
    private static final double SPLIT_BALL_MARGIN = 0.2;

    /** Distance measure used both for building the tree and for searching it. */
    private final DistanceMeasure distanceMeasure;

    /** Root of the tree; a leaf when the index holds at most {@link #MAX_LEAF_SIZE} points. */
    private final TreeNode root;

    /**
     * Builds a ball tree over the given points.
     *
     * <p>The passed list is copied (shallowly) before the tree is built, so it is neither cleared nor
     * retained by the index. The vectors themselves are shared, not cloned.
     *
     * @param data Points to index.
     * @param distanceMeasure Distance measure.
     */
    public BallTreeSpatialIndex(List<LabeledVector<L>> data, DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        // buildTree() clears the lists it splits and stores the others inside leaves; operate on a
        // private copy so the caller's list is left untouched regardless of its size.
        this.root = this.buildTree(new ArrayList<LabeledVector<L>>(data));
    }

    /**
     * Finds up to {@code k} indexed points closest to {@code pnt}.
     *
     * @param k Number of neighbours requested; a non-positive value yields an empty list.
     * @param pnt Query point.
     * @return The neighbours found, as produced by {@code PointWithDistanceUtil.transformToListOrdered}.
     */
    @Override
    public List<LabeledVector<L>> findKClosest(int k, Vector pnt) {
        if (k <= 0) {
            // Nothing can be selected; also avoids peeking into an empty heap during pruning.
            return new ArrayList<>();
        }
        // Reverse ordering: the head of the heap is the greatest element, which the search uses as the
        // current worst candidate (presumably PointWithDistance is ordered by distance - confirm).
        Queue<PointWithDistance<L>> heap = new PriorityQueue<>(Collections.<PointWithDistance<L>>reverseOrder());
        this.root.findKClosest(pnt, heap, k);
        return PointWithDistanceUtil.transformToListOrdered(heap);
    }

    /**
     * Builds the tree for {@code data}, using the mean point as the root center.
     *
     * @param data Points to index; may be cleared or retained by the resulting tree.
     * @return Root node.
     */
    private TreeNode buildTree(List<LabeledVector<L>> data) {
        Vector center = this.calculateCenter(data);
        return this.buildTree(data, center, this.calculateRadius(data, center));
    }

    /**
     * Recursively builds a subtree whose ball is described by {@code center} and {@code radius}.
     *
     * <p>A node becomes a leaf when it holds at most {@link #MAX_LEAF_SIZE} points, or when the points
     * cannot be separated (no dimension with positive spread, or the split leaves one side empty).
     *
     * @param data Points belonging to this subtree; cleared when the node is split.
     * @param center Center of this node's ball.
     * @param radius Radius of this node's ball.
     * @return Subtree root.
     */
    private TreeNode buildTree(List<LabeledVector<L>> data, Vector center, double radius) {
        if (data.size() <= MAX_LEAF_SIZE) {
            return new TreeLeafNode(center, radius, data);
        }

        int bestDimForSplit = this.calculateBestDimForSplit(data);
        if (bestDimForSplit < 0) {
            // No dimension has a positive standard deviation (e.g. all points coincide), so there is
            // no coordinate to split on; keep the points together in an oversized leaf.
            return new TreeLeafNode(center, radius, data);
        }

        // Both child centers start at the mean point and are then pushed apart along the split dimension.
        Vector leftCenter = this.calculateCenter(data);
        Vector rightCenter = leftCenter.copy();
        double min = this.calculateMin(data, bestDimForSplit);
        double max = this.calculateMax(data, bestDimForSplit);
        double range = max - min;
        leftCenter.set(bestDimForSplit, min + range * SPLIT_BALL_MARGIN);
        rightCenter.set(bestDimForSplit, min + range * (1.0 - SPLIT_BALL_MARGIN));

        List<LabeledVector<L>> leftBallPnts = new ArrayList<>();
        List<LabeledVector<L>> rightBallPnts = new ArrayList<>();
        this.splitPoints(data, leftCenter, rightCenter, leftBallPnts, rightBallPnts);

        if (leftBallPnts.isEmpty() || rightBallPnts.isEmpty()) {
            // The split made no progress: recursing on the non-empty side would repeat the very same
            // split forever. Stop here and keep the points in an oversized leaf.
            return new TreeLeafNode(center, radius, data);
        }

        // The points now live in the two child lists; release the references held by this list.
        data.clear();

        TreeNode left = this.buildTree(leftBallPnts, leftCenter, this.calculateRadius(leftBallPnts, leftCenter));
        TreeNode right = this.buildTree(rightBallPnts, rightCenter, this.calculateRadius(rightBallPnts, rightCenter));
        return new TreeInnerNode(center, radius, left, right);
    }

    /**
     * Distributes points between two balls: a point goes left only when it is strictly closer to the
     * left center, otherwise (including ties) it goes right.
     *
     * @param dataPnts Points to distribute.
     * @param leftCenter Center of the left ball.
     * @param rightCenter Center of the right ball.
     * @param leftBallPnts Receives the points assigned to the left ball.
     * @param rightBallPnts Receives the points assigned to the right ball.
     */
    private void splitPoints(List<LabeledVector<L>> dataPnts, Vector leftCenter, Vector rightCenter, List<LabeledVector<L>> leftBallPnts, List<LabeledVector<L>> rightBallPnts) {
        for (LabeledVector<L> dataPnt : dataPnts) {
            double distToLeftCenter = this.distanceMeasure.compute(leftCenter, (Vector)dataPnt.features());
            double distToRightCenter = this.distanceMeasure.compute(rightCenter, (Vector)dataPnt.features());
            if (distToLeftCenter < distToRightCenter) {
                leftBallPnts.add(dataPnt);
            } else {
                rightBallPnts.add(dataPnt);
            }
        }
    }

    /**
     * Computes the largest distance from {@code center} to any of the points.
     *
     * @param data Points.
     * @param center Ball center.
     * @return Maximum distance, or {@code 0.0} when {@code data} is empty.
     */
    private double calculateRadius(List<LabeledVector<L>> data, Vector center) {
        double radius = 0.0;
        for (LabeledVector<L> dataPnt : data) {
            double distance = this.distanceMeasure.compute(center, (Vector)dataPnt.features());
            radius = Math.max(radius, distance);
        }
        return radius;
    }

    /**
     * Computes the per-dimension mean of the points.
     *
     * @param data Points.
     * @return Mean point, or {@code null} when {@code data} is empty.
     */
    private Vector calculateCenter(List<LabeledVector<L>> data) {
        if (data.isEmpty()) {
            return null;
        }
        // Dimensionality is taken from the first point.
        double[] center = new double[data.get(0).size()];
        for (int dim = 0; dim < center.length; ++dim) {
            center[dim] = this.calculateMean(data, dim);
        }
        return VectorUtils.of(center);
    }

    /**
     * Picks the dimension with the largest standard deviation.
     *
     * @param data Points.
     * @return Index of the dimension, or {@code -1} when {@code data} is empty or no dimension has a
     *      standard deviation greater than zero.
     */
    private int calculateBestDimForSplit(List<LabeledVector<L>> data) {
        if (data.isEmpty()) {
            return -1;
        }
        double bestStd = 0.0;
        int bestDim = -1;
        int dims = data.get(0).size();
        for (int dim = 0; dim < dims; ++dim) {
            double std = this.calculateStd(data, dim);
            // Strict comparison: the first dimension wins ties, and NaN never wins.
            if (std > bestStd) {
                bestStd = std;
                bestDim = dim;
            }
        }
        return bestDim;
    }

    /**
     * @param data Points.
     * @param dim Dimension index.
     * @return Largest value along {@code dim}, or negative infinity when {@code data} is empty.
     */
    private double calculateMax(List<LabeledVector<L>> data, int dim) {
        double max = Double.NEGATIVE_INFINITY;
        for (LabeledVector<L> dataPnt : data) {
            max = Math.max(max, dataPnt.get(dim));
        }
        return max;
    }

    /**
     * @param data Points.
     * @param dim Dimension index.
     * @return Smallest value along {@code dim}, or positive infinity when {@code data} is empty.
     */
    private double calculateMin(List<LabeledVector<L>> data, int dim) {
        double min = Double.POSITIVE_INFINITY;
        for (LabeledVector<L> dataPnt : data) {
            min = Math.min(min, dataPnt.get(dim));
        }
        return min;
    }

    /**
     * @param data Points.
     * @param dim Dimension index.
     * @return Population standard deviation (divides by the number of points) along {@code dim}.
     */
    private double calculateStd(List<LabeledVector<L>> data, int dim) {
        double res = 0.0;
        double mean = this.calculateMean(data, dim);
        for (LabeledVector<L> dataPnt : data) {
            res += Math.pow(dataPnt.get(dim) - mean, 2.0);
        }
        return Math.sqrt(res / (double)data.size());
    }

    /**
     * @param data Points.
     * @param dim Dimension index.
     * @return Arithmetic mean along {@code dim}.
     */
    private double calculateMean(List<LabeledVector<L>> data, int dim) {
        double res = 0.0;
        for (LabeledVector<L> dataPnt : data) {
            res += dataPnt.get(dim);
        }
        return res / (double)data.size();
    }

    /** Leaf node: stores points and scans all of them during a search. */
    private final class TreeLeafNode
    extends TreeNode {
        /** Points held by this leaf. */
        private final List<LabeledVector<L>> points;

        /**
         * @param center Ball center.
         * @param radius Ball radius.
         * @param points Points held by this leaf; the list is stored as is, not copied.
         */
        TreeLeafNode(Vector center, double radius, List<LabeledVector<L>> points) {
            super(center, radius);
            this.points = points;
        }

        /** Offers every stored point, with its distance to {@code pnt}, to the heap. */
        @Override
        void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k) {
            for (LabeledVector<L> dataPnt : this.points) {
                double distance = BallTreeSpatialIndex.this.distanceMeasure.compute(pnt, (Vector)dataPnt.features());
                PointWithDistanceUtil.tryToAddIntoHeap(heap, k, dataPnt, distance);
            }
        }
    }

    /** Inner node: delegates the search to its two children, pruning the farther one when possible. */
    private final class TreeInnerNode
    extends TreeNode {
        /** Left child. */
        private final TreeNode left;

        /** Right child. */
        private final TreeNode right;

        /**
         * @param center Ball center.
         * @param radius Ball radius.
         * @param left Left child.
         * @param right Right child.
         */
        TreeInnerNode(Vector center, double radius, TreeNode left, TreeNode right) {
            super(center, radius);
            this.left = left;
            this.right = right;
        }

        /**
         * Searches the child whose center is closer to {@code pnt} first (the left one on ties), then
         * the other child unless its ball lies entirely farther away than the current head of a full heap.
         */
        @Override
        void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k) {
            double distToLeftCenter = this.computeDistToCenter(pnt, this.left);
            double distToRightCenter = this.computeDistToCenter(pnt, this.right);

            boolean rightIsCloser = distToLeftCenter > distToRightCenter;
            TreeNode primaryBranch = rightIsCloser ? this.right : this.left;
            TreeNode secondaryBranch = rightIsCloser ? this.left : this.right;
            // Already computed above; no need to measure the distance a second time.
            double distToSecondaryCenter = rightIsCloser ? distToLeftCenter : distToRightCenter;

            if (primaryBranch != null) {
                primaryBranch.findKClosest(pnt, heap, k);
            }
            if (secondaryBranch != null) {
                // Smallest distance any point of the secondary ball could have to the query.
                double distToSecondaryBall = distToSecondaryCenter - secondaryBranch.getRadius();
                if (heap.size() < k || distToSecondaryBall < heap.peek().getDistance()) {
                    secondaryBranch.findKClosest(pnt, heap, k);
                }
            }
        }

        /**
         * @param pnt Query point.
         * @param node Node, possibly {@code null}.
         * @return Distance from {@code pnt} to the node's center, or {@link Double#MAX_VALUE} when
         *      {@code node} is {@code null}.
         */
        private double computeDistToCenter(Vector pnt, TreeNode node) {
            if (node == null) {
                return Double.MAX_VALUE;
            }
            return BallTreeSpatialIndex.this.distanceMeasure.compute(pnt, node.getCenter());
        }
    }

    /** Base class of tree nodes: a ball described by its center and radius. */
    private abstract class TreeNode {
        /** Ball center. */
        private final Vector center;

        /** Ball radius. */
        private final double radius;

        /**
         * @param center Ball center.
         * @param radius Ball radius.
         */
        TreeNode(Vector center, double radius) {
            this.center = center;
            this.radius = radius;
        }

        /**
         * Collects into {@code heap} the candidates from this subtree that are closest to {@code pnt}.
         *
         * @param pnt Query point.
         * @param heap Heap of current candidates.
         * @param k Number of neighbours requested.
         */
        abstract void findKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k);

        /** @return Ball center. */
        public Vector getCenter() {
            return this.center;
        }

        /** @return Ball radius. */
        public double getRadius() {
            return this.radius;
        }
    }
}

