/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.svm;

import java.util.Random;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.AbstractVector;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
import org.apache.ignite.ml.structures.Dataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.svm.Deltas;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer for a linear binary SVM classifier.
 *
 * <p>Training runs {@link #getAmountOfIterations()} outer iterations. In each one, every dataset
 * partition performs {@link #getAmountOfLocIterations()} local updates. Each update picks a random
 * row, adjusts that row's dual variable (kept within {@code [0, 1]}) and moves a local copy of the
 * weights accordingly. The per-partition weight deltas are then summed and added to the global
 * weights.
 *
 * <p>Internally the weights live in a single "state" vector: element {@code 0} is the intercept and
 * elements {@code 1..n} are the feature weights. Feature rows are extended with a leading
 * {@code 1.0} to match this layout.
 */
public class SVMLinearBinaryClassificationTrainer
extends SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> {
    /**
     * Storage mode passed to {@link SparseVector} when rebuilding a sparse state vector.
     * NOTE(review): presumably the compiler-inlined value of StorageConstants.RANDOM_ACCESS_MODE — confirm.
     */
    private static final int SPARSE_STORAGE_MODE = 1002;

    /** Number of outer (global) iterations; defaults to 200. */
    private int amountOfIterations = 200;

    /** Number of local updates each partition performs per outer iteration; defaults to 100. */
    private int amountOfLocIterations = 100;

    /** Regularization parameter; must be positive. Defaults to 0.4. */
    private double lambda = 0.4;

    /** Seed for the row-sampling random generator; defaults to 1234. */
    private long seed = 1234L;

    /**
     * Trains a new model from scratch.
     *
     * @param datasetBuilder Builder of the training dataset.
     * @param featureExtractor Extracts the feature vector from an upstream entry.
     * @param lbExtractor Extracts the label from an upstream entry.
     * @return Trained model.
     */
    @Override
    public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains a model, optionally continuing from an existing one.
     *
     * @param mdl Model to continue from, or {@code null} to start from zero weights.
     * @param datasetBuilder Builder of the training dataset.
     * @param featureExtractor Extracts the feature vector from an upstream entry.
     * @param lbExtractor Extracts the label from an upstream entry; a label of {@code 0.0} is
     *     treated as {@code -1.0}.
     * @return Trained model.
     * @throws RuntimeException Wrapping any checked exception raised while building, using or
     *     closing the dataset. Unchecked exceptions propagate unchanged.
     */
    @Override
    protected <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        assert datasetBuilder != null;

        // The update rules multiply by the label (see calcGradient/calcDeltas), so labels are expected
        // in {-1, +1}: a 0.0 label is mapped to -1.0 and every other label is passed through unchanged.
        IgniteBiFunction<K, V, Double> patchedLbExtractor = (k, v) -> {
            Double lb = lbExtractor.apply(k, v);
            return lb == 0.0 ? -1.0 : lb;
        };

        LabeledDatasetPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, patchedLbExtractor);

        try (org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset =
                 datasetBuilder.build((upstream, upstreamSize) -> new EmptyContext(), partDataBuilder)) {
            // One extra slot (index 0) for the intercept.
            Vector weights = mdl == null
                ? initializeWeightsWithZeros(countColumns(dataset) + 1)
                : getStateVector(mdl);

            for (int i = 0; i < getAmountOfIterations(); i++) {
                Vector deltaWeights = calculateUpdates(weights, dataset);

                // No partition produced an update: there is nothing to train on.
                if (deltaWeights == null) {
                    return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
                }

                weights = weights.plus(deltaWeights);
            }

            return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
        }
        catch (RuntimeException e) {
            // Unchecked exceptions are propagated as-is instead of being wrapped a second time.
            throw e;
        }
        catch (Exception e) {
            // Checked exceptions (e.g. from closing the dataset) are wrapped.
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the number of feature columns reported by the dataset partitions.
     *
     * @param dataset Dataset to inspect.
     * @return Column count; {@code 0} when no partition reports a value.
     */
    private static int countColumns(org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        Integer cols = dataset.compute(Dataset::colSize, (a, b) -> {
            if (a == null) {
                return b == null ? 0 : b;
            }
            if (b == null) {
                return a;
            }
            return b;
        });

        // compute(...) yields null when the dataset has no partitions at all.
        return cols == null ? 0 : cols;
    }

    /** {@inheritDoc} */
    @Override
    protected boolean checkState(SVMLinearBinaryClassificationModel mdl) {
        return true;
    }

    /**
     * Converts a trained model back into the internal state vector layout.
     *
     * @param mdl Model to convert.
     * @return Vector with the intercept at index {@code 0} and weight {@code i} at index {@code i + 1}.
     */
    private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) {
        double intercept = mdl.intercept();
        Vector weights = mdl.weights();

        int stateVectorSize = weights.size() + 1;
        Vector res = weights.isDense()
            ? new DenseVector(stateVectorSize)
            : new SparseVector(stateVectorSize, SPARSE_STORAGE_MODE);

        res.set(0, intercept);
        // Shift by one: slot 0 is the intercept, matching viewPart(1, ...) / get(0) in updateModel.
        weights.nonZeroes().forEach(ith -> res.set(ith.index() + 1, ith.get()));
        return res;
    }

    /**
     * Creates a dense all-zero vector.
     *
     * @param vectorSize Vector size.
     * @return Zero vector of the given size.
     */
    @NotNull
    private Vector initializeWeightsWithZeros(int vectorSize) {
        return new DenseVector(vectorSize);
    }

    /**
     * Runs one outer iteration of local updates on every partition and sums the resulting weight deltas.
     *
     * @param weights Current global state vector (intercept at index 0).
     * @param dataset Training dataset.
     * @return Summed weight deltas, or {@code null} if no partition contributed any rows.
     */
    private Vector calculateUpdates(Vector weights, org.apache.ignite.ml.dataset.Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        return dataset.compute(data -> {
            int amountOfObservation = data.rowSize();

            // An empty partition has no rows to sample; Random#nextInt(0) would throw.
            if (amountOfObservation == 0) {
                return null;
            }

            Vector copiedWeights = weights.copy();
            Vector deltaWeights = initializeWeightsWithZeros(weights.size());
            Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);

            // NOTE: re-seeded with the same seed on every call, so within a partition each outer
            // iteration draws the same sequence of row indices.
            Random random = new Random(seed);

            for (int i = 0; i < getAmountOfLocIterations(); i++) {
                int randomIdx = random.nextInt(amountOfObservation);

                Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx);

                copiedWeights = copiedWeights.plus(deltas.deltaWeights);
                deltaWeights = deltaWeights.plus(deltas.deltaWeights);
                tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha);
            }

            return deltaWeights;
        }, (a, b) -> {
            // Partitions that contributed nothing (null) are skipped; all-null stays null.
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            return a.plus(b);
        });
    }

    /**
     * Computes the update for a single sampled row.
     *
     * @param data Partition data.
     * @param copiedWeights Local copy of the state vector.
     * @param amountOfObservation Number of rows in the partition.
     * @param tmpAlphas Current dual variables of the partition's rows.
     * @param randomIdx Index of the sampled row.
     * @return Change of the row's dual variable and the corresponding weight change.
     */
    private Deltas getDeltas(LabeledVectorSet<Double, LabeledVector> data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas, int randomIdx) {
        LabeledVector row = data.getRow(randomIdx);
        double lb = (Double)row.label();
        Vector v = makeVectorWithInterceptElement(row);
        double alpha = tmpAlphas.get(randomIdx);

        return maximize(lb, v, alpha, copiedWeights, amountOfObservation);
    }

    /**
     * Prepends a constant {@code 1.0} to a row's features so the intercept is handled like a regular weight.
     *
     * @param row Labeled row.
     * @return Vector of size {@code features + 1} with {@code 1.0} at index {@code 0}.
     */
    private Vector makeVectorWithInterceptElement(LabeledVector row) {
        Vector features = row.features();
        Vector vec = features.like(features.size() + 1);

        vec.set(0, 1.0);
        for (int j = 0; j < features.size(); j++) {
            vec.set(j + 1, features.get(j));
        }

        return vec;
    }

    /**
     * Computes the gradient for one row, projects it onto the {@code [0, 1]} box of the dual variable
     * and derives the resulting deltas.
     *
     * @param lb Row label (expected to be {@code -1} or {@code +1}).
     * @param v Row features with the leading intercept element.
     * @param alpha Current dual variable of the row.
     * @param weights Current local state vector.
     * @param amountOfObservation Number of rows in the partition.
     * @return Deltas for the row.
     */
    private Deltas maximize(double lb, Vector v, double alpha, Vector weights, int amountOfObservation) {
        double gradient = calcGradient(lb, v, weights, amountOfObservation);
        double prjGrad = calculateProjectionGradient(alpha, gradient);

        return calcDeltas(lb, v, alpha, prjGrad, weights.size(), amountOfObservation);
    }

    /**
     * Turns a projected gradient into a dual-variable change and the matching weight change.
     *
     * @param lb Row label.
     * @param v Row features with the leading intercept element.
     * @param alpha Current dual variable of the row.
     * @param gradient Projected gradient; {@code 0} means no update.
     * @param vectorSize Size of the state vector (used for the zero update).
     * @param amountOfObservation Number of rows in the partition.
     * @return Deltas; both parts are zero when the projected gradient is zero.
     */
    private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize, int amountOfObservation) {
        if (gradient != 0.0) {
            double qii = v.dot(v);
            double newAlpha = calcNewAlpha(alpha, gradient, qii);

            Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (getLambda() * (double)amountOfObservation));

            return new Deltas(newAlpha - alpha, deltaWeights);
        }

        return new Deltas(0.0, initializeWeightsWithZeros(vectorSize));
    }

    /**
     * Takes a Newton-like step on the dual variable and clips it to {@code [0, 1]}.
     *
     * @param alpha Current dual variable.
     * @param gradient Projected gradient.
     * @param qii Squared norm of the row vector.
     * @return New dual variable; {@code 1.0} when {@code qii} is zero.
     */
    private double calcNewAlpha(double alpha, double gradient, double qii) {
        if (qii != 0.0) {
            return Math.min(Math.max(alpha - gradient / qii, 0.0), 1.0);
        }

        return 1.0;
    }

    /**
     * Computes {@code (lb * <v, weights> - 1) * lambda * amountOfObservation}.
     *
     * @param lb Row label.
     * @param v Row features with the leading intercept element.
     * @param weights Current local state vector.
     * @param amountOfObservation Number of rows in the partition.
     * @return Gradient with respect to the row's dual variable.
     */
    private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) {
        double dotProduct = v.dot(weights);

        return (lb * dotProduct - 1.0) * (getLambda() * (double)amountOfObservation);
    }

    /**
     * Projects a gradient onto the feasible box {@code [0, 1]} of the dual variable. At the lower bound
     * only non-positive gradients are kept; at the upper bound only non-negative ones.
     *
     * @param alpha Current dual variable.
     * @param gradient Unprojected gradient.
     * @return Projected gradient.
     */
    private double calculateProjectionGradient(double alpha, double gradient) {
        if (alpha <= 0.0) {
            return Math.min(gradient, 0.0);
        }
        if (alpha >= 1.0) {
            return Math.max(gradient, 0.0);
        }

        return gradient;
    }

    /**
     * Sets the regularization parameter.
     *
     * @param lambda Regularization parameter; must be strictly positive.
     * @return This trainer.
     * @throws IllegalArgumentException If {@code lambda} is not positive (including NaN).
     */
    public SVMLinearBinaryClassificationTrainer withLambda(double lambda) {
        // Written as !(lambda > 0) so NaN is rejected as well.
        if (!(lambda > 0.0)) {
            throw new IllegalArgumentException("Lambda must be positive, got: " + lambda);
        }

        this.lambda = lambda;
        return this;
    }

    /**
     * Gets the regularization parameter.
     *
     * @return Lambda.
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Gets the number of outer iterations.
     *
     * @return Amount of outer iterations.
     */
    public int getAmountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Sets the number of outer iterations.
     *
     * @param amountOfIterations Amount of outer iterations.
     * @return This trainer.
     */
    public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * Gets the number of local updates per partition per outer iteration.
     *
     * @return Amount of local iterations.
     */
    public int getAmountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Sets the number of local updates per partition per outer iteration.
     *
     * @param amountOfLocIterations Amount of local iterations.
     * @return This trainer.
     */
    public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }

    /**
     * Gets the seed used for row sampling.
     *
     * @return Seed.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Sets the seed used for row sampling.
     *
     * @param seed Seed.
     * @return This trainer.
     */
    public SVMLinearBinaryClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}

