/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.regressions.logistic.multiclass;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer of a multi-class logistic regression model built with the one-vs-rest scheme.
 *
 * <p>The distinct class labels are first collected from the dataset. For every label, a binary
 * {@link LogisticRegressionSGDTrainer} is trained on labels rewritten to {@code 1.0} (this class)
 * or {@code 0.0} (any other class). The resulting binary models are stored per label in a
 * {@link LogRegressionMultiClassModel}.
 *
 * @param <P> Type parameter kept for API compatibility; it is not used by this class.
 */
public class LogRegressionMultiClassTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel> {
    /** Update strategy forwarded to every per-class binary trainer. */
    private UpdatesStrategy updatesStgy = new UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate::sumLocal,
        SimpleGDParameterUpdate::avg
    );

    /** Value passed to {@code withMaxIterations} of every per-class binary trainer. */
    private int amountOfIterations = 100;

    /** Value passed to {@code withBatchSize} of every per-class binary trainer. */
    private int batchSize = 100;

    /** Value passed to {@code withLocIterations} of every per-class binary trainer. */
    private int amountOfLocIterations = 100;

    /** Value passed to {@code withSeed} of every per-class binary trainer. */
    private long seed = 1234L;

    /**
     * Trains a new multi-class model from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Trained multi-class model.
     */
    @Override
    public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        // updateModel() extracts the class labels itself, so no separate pass over the data is needed here.
        return updateModel((LogRegressionMultiClassModel)null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains one binary model per class label found in the dataset.
     *
     * <p>If {@code prevMdl} already contains a binary model for a label, that model is updated.
     * Otherwise a fresh binary model is fitted for the label.
     *
     * @param prevMdl Previously trained multi-class model, or {@code null} to train from scratch.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return New multi-class model. If the dataset yields no class labels, the result of
     *     {@code getLastTrainedModelOrThrowEmptyDatasetException(prevMdl)} is returned instead.
     */
    @Override
    public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel prevMdl,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);

        if (classes.isEmpty()) {
            return getLastTrainedModelOrThrowEmptyDatasetException(prevMdl);
        }

        LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();

        for (Double clsLb : classes) {
            // Previously the configured updates strategy was silently dropped here; forward it.
            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
                .withUpdatesStgy(updatesStgy)
                .withBatchSize(batchSize)
                .withLocIterations(amountOfLocIterations)
                .withMaxIterations(amountOfIterations)
                .withSeed(seed);

            // One-vs-rest relabelling: 1.0 for the current class, 0.0 for every other label.
            // equals() is called on clsLb, which is boxed from a primitive double and so is never null.
            // This lambda deliberately captures only clsLb and lbExtractor, not `this`.
            IgniteBiFunction<K, V, Double> lbTransformer =
                (k, v) -> clsLb.equals(lbExtractor.apply(k, v)) ? 1.0 : 0.0;

            LogisticRegressionModel binMdl = Optional.ofNullable(prevMdl)
                .flatMap(prev -> prev.getModel(clsLb))
                .map(prevBinMdl -> trainer.update(prevBinMdl, datasetBuilder, featureExtractor, lbTransformer))
                .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer));

            multiClsMdl.add(clsLb, binMdl);
        }

        return multiClsMdl;
    }

    /**
     * Always reports the given model as valid; no consistency check is performed.
     *
     * @param mdl Model to check.
     * @return {@code true}.
     */
    @Override
    protected boolean checkState(LogRegressionMultiClassModel mdl) {
        return true;
    }

    /**
     * Collects the distinct label values present in the dataset.
     *
     * @param datasetBuilder Dataset builder (must not be {@code null}).
     * @param lbExtractor Label extractor.
     * @return Distinct labels in no particular order; empty if none were found.
     * @throws RuntimeException Wrapping any exception raised while building, computing on, or
     *     closing the dataset.
     */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        assert datasetBuilder != null;

        List<Double> res = new ArrayList<>();

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new EmptyContext(),
            new LabelPartitionDataBuilderOnHeap<>(lbExtractor))) {

            Set<Double> clsLabels = dataset.compute(data -> {
                // Per-partition set of distinct labels.
                Set<Double> locClsLabels = new HashSet<>();

                for (double lb : data.getY()) {
                    locClsLabels.add(lb);
                }

                return locClsLabels;
            }, (a, b) -> {
                // Union of partial results; either side may be null.
                if (a == null) {
                    return b == null ? new HashSet<>() : b;
                }

                if (b == null) {
                    return a;
                }

                return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
            });

            if (clsLabels != null) {
                res.addAll(clsLabels);
            }
        }
        catch (Exception e) {
            // The try-with-resources over the dataset forces handling of checked exceptions; rethrow unchecked.
            throw new RuntimeException(e);
        }

        return res;
    }

    /**
     * Sets the batch size passed to every per-class binary trainer.
     *
     * @param batchSize Batch size.
     * @return This trainer.
     */
    public LogRegressionMultiClassTrainer<P> withBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * Gets the batch size.
     *
     * <p>NOTE: returns {@code double} although the value is stored as {@code int}. The return
     * type is kept for backward compatibility.
     *
     * @return Batch size.
     */
    public double getBatchSize() {
        return batchSize;
    }

    /**
     * Gets the maximum number of iterations passed to every per-class binary trainer.
     *
     * @return Amount of iterations.
     */
    public int getAmountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Sets the maximum number of iterations passed to every per-class binary trainer.
     *
     * @param amountOfIterations Amount of iterations.
     * @return This trainer.
     */
    public LogRegressionMultiClassTrainer<P> withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * Gets the number of local iterations passed to every per-class binary trainer.
     *
     * @return Amount of local iterations.
     */
    public int getAmountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Sets the number of local iterations passed to every per-class binary trainer.
     *
     * @param amountOfLocIterations Amount of local iterations.
     * @return This trainer.
     */
    public LogRegressionMultiClassTrainer<P> withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }

    /**
     * Sets the seed passed to every per-class binary trainer.
     *
     * @param seed Seed.
     * @return This trainer.
     */
    public LogRegressionMultiClassTrainer<P> withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Gets the seed passed to every per-class binary trainer.
     *
     * @return Seed.
     */
    public long seed() {
        return seed;
    }

    /**
     * Sets the updates strategy forwarded to every per-class binary trainer.
     *
     * @param updatesStgy Updates strategy.
     * @return This trainer.
     */
    public LogRegressionMultiClassTrainer<P> withUpdatesStgy(UpdatesStrategy updatesStgy) {
        this.updatesStgy = updatesStgy;
        return this;
    }

    /**
     * Gets the updates strategy forwarded to every per-class binary trainer.
     *
     * @return Updates strategy.
     */
    public UpdatesStrategy getUpdatesStgy() {
        return updatesStgy;
    }
}

