/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.svm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer for a multi-class linear SVM composed of one binary SVM per class (one-vs-rest).
 * <p>
 * For every distinct label found in the dataset, a {@link SVMLinearBinaryClassificationTrainer} is trained on a
 * relabelled view of the data in which that label maps to {@code 1.0} and every other label maps to {@code 0.0}.
 * The resulting binary models are collected into a {@link SVMLinearMultiClassClassificationModel}, keyed by label.
 */
public class SVMLinearMultiClassClassificationTrainer
extends SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
    /** Value forwarded to {@code withAmountOfIterations} of every binary trainer (default {@code 20}). */
    private int amountOfIterations = 20;

    /** Value forwarded to {@code withAmountOfLocIterations} of every binary trainer (default {@code 50}). */
    private int amountOfLocIterations = 50;

    /** Value forwarded to {@code withLambda} of every binary trainer (default {@code 0.2}). */
    private double lambda = 0.2;

    /** Value forwarded to {@code withSeed} of every binary trainer (default {@code 1234}). */
    private long seed = 1234L;

    /**
     * Trains a new multi-class model from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Trained multi-class model.
     */
    @Override
    public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel((SVMLinearMultiClassClassificationModel)null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains (when {@code mdl} is {@code null}) or updates a multi-class model.
     * <p>
     * One binary model is produced for each distinct label in the dataset. When {@code mdl} is non-null, the binary
     * model it holds for a label is used as the starting point; labels it has no model for are trained from scratch.
     * The returned model is a fresh instance, so it contains only the labels present in the current dataset.
     *
     * @param mdl Previously trained model, or {@code null} to train from scratch.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Trained or updated multi-class model; for an empty dataset, whatever
     *     {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} yields.
     */
    @Override
    public <K, V> SVMLinearMultiClassClassificationModel updateModel(SVMLinearMultiClassClassificationModel mdl,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

        if (classes.isEmpty()) {
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();

        for (Double clsLb : classes) {
            SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
                .withAmountOfIterations(getAmountOfIterations())
                .withAmountOfLocIterations(getAmountOfLocIterations())
                .withLambda(getLambda())
                .withSeed(getSeed());

            // One-vs-rest relabelling: the current class becomes 1.0, every other class 0.0.
            IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
                Double lb = lbExtractor.apply(k, v);
                return lb.equals(clsLb) ? 1.0 : 0.0;
            };

            SVMLinearBinaryClassificationModel updatedMdl = mdl == null
                ? learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer)
                : updateBinaryModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer);

            multiClsMdl.add(clsLb, updatedMdl);
        }

        return multiClsMdl;
    }

    /**
     * Accepts any previously trained model as a starting point for an update.
     *
     * @param mdl Model to check.
     * @return Always {@code true}.
     */
    @Override
    protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) {
        return true;
    }

    /**
     * Trains a binary model from scratch with the given trainer.
     *
     * @param svmTrainer Configured binary trainer.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Binary (one-vs-rest) label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Trained binary model.
     */
    private <K, V> SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Updates the binary model that {@code multiClsMdl} holds for {@code clsLb}, or trains a new one if it has none.
     *
     * @param multiClsMdl Previously trained multi-class model.
     * @param clsLb Class label whose binary model is updated.
     * @param svmTrainer Configured binary trainer.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Binary (one-vs-rest) label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Updated or newly trained binary model.
     */
    private <K, V> SVMLinearBinaryClassificationModel updateBinaryModel(
        SVMLinearMultiClassClassificationModel multiClsMdl, Double clsLb,
        SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return multiClsMdl.getModelForClass(clsLb)
            .map(learnedMdl -> svmTrainer.update(learnedMdl, datasetBuilder, featureExtractor, lbExtractor))
            .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor));
    }

    /**
     * Collects the distinct label values present across all partitions of the dataset.
     *
     * @param datasetBuilder Dataset builder (must not be {@code null}; checked by assertion only).
     * @param lbExtractor Label extractor.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Distinct labels in no particular order; empty if none were found.
     * @throws RuntimeException Wrapping any exception raised while building, computing on, or closing the dataset.
     */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        assert datasetBuilder != null;

        LabelPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabelPartitionDataBuilderOnHeap<>(lbExtractor);

        List<Double> res = new ArrayList<>();

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder
        )) {
            Set<Double> clsLabels = dataset.compute(
                data -> {
                    Set<Double> locClsLabels = new HashSet<>();
                    for (double lb : data.getY()) {
                        locClsLabels.add(lb);
                    }
                    return locClsLabels;
                },
                SVMLinearMultiClassClassificationTrainer::mergeLabelSets
            );

            if (clsLabels != null) {
                res.addAll(clsLabels);
            }
        }
        catch (Exception e) {
            // Dataset.close() declares a checked exception; surface everything as unchecked.
            throw new RuntimeException(e);
        }

        return res;
    }

    /**
     * Reduces two partial label sets into one. A {@code null} operand is treated as "no result".
     *
     * @param a First partial result, possibly {@code null}.
     * @param b Second partial result, possibly {@code null}.
     * @return Empty set if both are {@code null}; the non-null operand if only one is; otherwise their union.
     */
    private static Set<Double> mergeLabelSets(Set<Double> a, Set<Double> b) {
        if (a == null) {
            return b == null ? new HashSet<Double>() : b;
        }
        if (b == null) {
            return a;
        }
        return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
    }

    /**
     * Sets the lambda value forwarded to every binary trainer.
     *
     * @param lambda Lambda; must be positive (checked by assertion only, i.e. when run with {@code -ea}).
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
        assert lambda > 0.0;
        this.lambda = lambda;
        return this;
    }

    /**
     * @return Lambda value forwarded to every binary trainer.
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * @return Amount of iterations forwarded to every binary trainer.
     */
    public int getAmountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Sets the amount of iterations forwarded to every binary trainer.
     *
     * @param amountOfIterations Amount of iterations.
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * @return Amount of local iterations forwarded to every binary trainer.
     */
    public int getAmountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Sets the amount of local iterations forwarded to every binary trainer.
     *
     * @param amountOfLocIterations Amount of local iterations.
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }

    /**
     * @return Seed forwarded to every binary trainer.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Sets the seed forwarded to every binary trainer.
     *
     * @param seed Seed value.
     * @return This trainer, for chaining.
     */
    public SVMLinearMultiClassClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}

