/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.regression;

import java.util.List;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;

/**
 * k-nearest-neighbors regression model.
 *
 * <p>Reuses the neighbor search inherited from {@link KNNClassificationModel}. Instead of voting
 * on a class, it predicts a numeric value from the labels of the nearest neighbors. The inherited
 * {@code stgy} field selects how labels are combined:
 * <ul>
 *   <li>{@code SIMPLE}: arithmetic mean of the neighbors' labels;</li>
 *   <li>{@code WEIGHTED}: inverse-distance weighted mean, so closer neighbors contribute more.</li>
 * </ul>
 */
public class KNNRegressionModel extends KNNClassificationModel {
    private static final long serialVersionUID = -721836321291120543L;

    /**
     * Creates a regression model backed by the given dataset.
     *
     * @param dataset Dataset of labeled vectors whose {@code Double} labels are the regression targets.
     */
    public KNNRegressionModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        super(dataset);
    }

    /**
     * Predicts the target value for a feature vector.
     *
     * @param v Feature vector to predict for.
     * @return Predicted value, or {@link Double#NaN} when no neighbors were found.
     * @throws UnsupportedOperationException If the configured strategy is not SIMPLE or WEIGHTED.
     */
    @Override
    public Double apply(Vector v) {
        List<LabeledVector> neighbors = this.findKNearestNeighbors(v);
        return this.predictYBasedOn(neighbors, v);
    }

    /** Dispatches to the combination rule selected by {@code stgy}. */
    private double predictYBasedOn(List<LabeledVector> neighbors, Vector v) {
        switch (this.stgy) {
            case SIMPLE: {
                return this.simpleRegression(neighbors);
            }
            case WEIGHTED: {
                return this.weightedRegression(neighbors, v);
            }
        }
        throw new UnsupportedOperationException("Strategy " + this.stgy.name() + " is not supported");
    }

    /**
     * Inverse-distance weighted mean of the neighbors' labels.
     *
     * <p>A neighbor at distance zero would get an infinite weight. If any neighbor coincides with
     * {@code v}, the result is therefore the plain mean of the labels of those coinciding neighbors.
     *
     * @return Weighted mean, or {@link Double#NaN} if {@code neighbors} is empty.
     */
    private double weightedRegression(List<LabeledVector> neighbors, Vector v) {
        if (neighbors.isEmpty()) {
            return Double.NaN;
        }
        double weightedSum = 0.0;
        double weightTotal = 0.0;
        double exactSum = 0.0;
        int exactCnt = 0;
        for (LabeledVector neighbor : neighbors) {
            double label = (Double)neighbor.label();
            double distance = this.distanceMeasure.compute(v, (Vector)neighbor.features());
            if (distance == 0.0) {
                // Exact match: handled separately to avoid Inf/Inf = NaN below.
                exactSum += label;
                exactCnt++;
            } else {
                // Closer neighbors must weigh more, hence 1/distance rather than distance.
                double weight = 1.0 / distance;
                weightedSum += label * weight;
                weightTotal += weight;
            }
        }
        if (exactCnt > 0) {
            return exactSum / exactCnt;
        }
        return weightedSum / weightTotal;
    }

    /**
     * Arithmetic mean of the neighbors' labels.
     *
     * <p>Divides by the number of neighbors actually found, not by {@code k}. The search may return
     * fewer than {@code k} neighbors when the dataset is small.
     *
     * @return Mean label, or {@link Double#NaN} if {@code neighbors} is empty.
     */
    private double simpleRegression(List<LabeledVector> neighbors) {
        if (neighbors.isEmpty()) {
            return Double.NaN;
        }
        double sum = 0.0;
        for (LabeledVector neighbor : neighbors) {
            sum += ((Double)neighbor.label()).doubleValue();
        }
        return sum / neighbors.size();
    }

    @Override
    public String toString() {
        return this.toString(false);
    }

    @Override
    public String toString(boolean pretty) {
        return ModelTrace.builder("KNNRegressionModel", pretty)
            .addField("k", String.valueOf(this.k))
            .addField("measure", this.distanceMeasure.getClass().getSimpleName())
            .addField("strategy", this.stgy.name())
            .toString();
    }
}

