/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.svm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer for a multi-class linear SVM using the one-vs-rest scheme.
 *
 * <p>For every distinct label found in the dataset, a fresh {@link SVMLinearBinaryClassificationTrainer}
 * is fit on a relabelled view of the data where that label maps to {@code 1.0} and every other label
 * maps to {@code -1.0}. The resulting binary models are registered, keyed by their label, in a
 * {@link SVMLinearMultiClassClassificationModel}.
 */
public class SVMLinearMultiClassClassificationTrainer
implements SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
    /** Forwarded to {@code withAmountOfIterations} of every binary trainer (default 20). */
    private int amountOfIterations = 20;

    /** Forwarded to {@code withAmountOfLocIterations} of every binary trainer (default 50). */
    private int amountOfLocIterations = 50;

    /** Forwarded to {@code withLambda} of every binary trainer (default 0.2). */
    private double lambda = 0.2;

    /**
     * Trains one binary SVM per class label and combines them into a multi-class model.
     *
     * @param datasetBuilder builder of the training data; used to collect the class labels and then
     *     handed to every per-class binary trainer
     * @param featureExtractor maps an upstream entry to its feature vector
     * @param lbExtractor maps an upstream entry to its class label
     * @param <K> upstream key type
     * @param <V> upstream value type
     * @return a multi-class model holding one binary model per distinct label
     */
    @Override
    public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

        SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();

        for (Double clsLb : classes) {
            // A fresh, identically configured binary trainer per class.
            SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
                .withAmountOfIterations(amountOfIterations())
                .withAmountOfLocIterations(amountOfLocIterations())
                .withLambda(lambda());

            // One-vs-rest relabelling: the current class becomes +1, everything else -1.
            // clsLb was boxed from a double[] and is therefore never null, so calling equals on it
            // also tolerates a null label returned by the extractor.
            IgniteBiFunction<K, V, Double> lbTransformer =
                (k, v) -> clsLb.equals(lbExtractor.apply(k, v)) ? 1.0 : -1.0;

            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
        }

        return multiClsMdl;
    }

    /**
     * Scans the dataset and collects every distinct class label.
     *
     * @param datasetBuilder builder of the dataset to scan; must not be {@code null}
     * @param lbExtractor maps an upstream entry to its class label
     * @param <K> upstream key type
     * @param <V> upstream value type
     * @return the distinct labels (in {@link HashSet} iteration order); empty if the dataset
     *     produced no result
     * @throws RuntimeException wrapping any checked exception raised while building, computing over,
     *     or closing the dataset; runtime exceptions propagate unchanged
     */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
        assert (datasetBuilder != null);

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            new LabelPartitionDataBuilderOnHeap<>(lbExtractor))) {

            // Map: distinct labels of one partition. Reduce: union of the partial sets.
            Set<Double> clsLabels = dataset.compute(data -> {
                Set<Double> locClsLabels = new HashSet<>();
                for (double lb : data.getY())
                    locClsLabels.add(lb);
                return locClsLabels;
            }, SVMLinearMultiClassClassificationTrainer::mergeLabels);

            List<Double> res = new ArrayList<>();
            // The reduction may yield null when there was nothing to reduce.
            if (clsLabels != null)
                res.addAll(clsLabels);
            return res;
        }
        catch (RuntimeException e) {
            // Do not bury unchecked failures under an extra RuntimeException layer.
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reducer used by {@link #extractClassLabels}: unions two partial label sets.
     *
     * <p>{@code null} on either side is treated as "no partial result"; presumably the reduction
     * starts from a {@code null} identity (the original code already guarded the left side).
     *
     * @param a partial result accumulated so far, possibly {@code null}
     * @param b partial result to merge in, possibly {@code null}
     * @return the union of both sets, or whichever side is non-null
     */
    private static Set<Double> mergeLabels(Set<Double> a, Set<Double> b) {
        if (a == null)
            return b;
        if (b == null)
            return a;
        return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
    }

    /**
     * Sets the regularization parameter forwarded to each binary trainer.
     *
     * @param lambda regularization value; asserted (only when assertions are enabled) to be positive
     * @return this trainer, for call chaining
     */
    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
        assert (lambda > 0.0);
        this.lambda = lambda;
        return this;
    }

    /** @return the regularization parameter forwarded to each binary trainer */
    public double lambda() {
        return this.lambda;
    }

    /** @return the amount of iterations forwarded to each binary trainer */
    public int amountOfIterations() {
        return this.amountOfIterations;
    }

    /**
     * Sets the amount of iterations forwarded to each binary trainer (not validated here).
     *
     * @param amountOfIterations value to forward
     * @return this trainer, for call chaining
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /** @return the amount of local iterations forwarded to each binary trainer */
    public int amountOfLocIterations() {
        return this.amountOfLocIterations;
    }

    /**
     * Sets the amount of local iterations forwarded to each binary trainer (not validated here).
     *
     * @param amountOfLocIterations value to forward
     * @return this trainer, for call chaining
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }
}

