/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.regressions.logistic.binomial;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer of a binary logistic regression model.
 *
 * <p>The model is trained as a multilayer perceptron made of a single layer with one sigmoid neuron
 * (with bias), optimized against {@link LossFunctions#L2} by an {@link MLPTrainer}. The trained
 * perceptron parameters are then converted into a {@link LogisticRegressionModel}: every parameter
 * except the last becomes a weight and the last one becomes the intercept.
 *
 * @param <P> Type parameter that only appears in the return types of the fluent setters.
 */
public class LogisticRegressionSGDTrainer<P extends Serializable>
extends SingleLabelDatasetTrainer<LogisticRegressionModel> {
    /**
     * Update strategy handed to the underlying {@link MLPTrainer}. Defaults to a
     * {@link SimpleGDUpdateCalculator} constructed with {@code 0.2} (presumably the learning rate).
     * Kept raw because the default value's update type is unrelated to {@code P}.
     */
    private UpdatesStrategy updatesStgy = new UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg);

    /** Max number of iterations passed to the MLP trainer. */
    private int maxIterations = 100;

    /** Batch size passed to the MLP trainer. */
    private int batchSize = 100;

    /** Number of local iterations passed to the MLP trainer. */
    private int locIterations = 100;

    /** Seed passed to the MLP trainer. */
    private long seed = 1234L;

    /**
     * Trains a new model from scratch.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained logistic regression model.
     */
    @Override
    public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains a model, either from scratch ({@code mdl == null}) or starting from the state of an
     * existing model.
     *
     * @param mdl Model to continue training from, or {@code null} to train a new one.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor; a {@code null} label fails with a
     *     {@link NullPointerException} when it is unboxed.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained logistic regression model.
     */
    @Override
    protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        // The architecture supplier is passed inline so that the diamond gives the lambda its typed
        // parameter; assigning it to a raw IgniteFunction variable would leave `dataset` typed as Object.
        MLPTrainer<?> trainer = new MLPTrainer<>(
            dataset -> {
                int cols = dataset.compute(data -> {
                    // Partitions without features (or without rows) do not report a column count;
                    // the rows check also prevents a division by zero.
                    if (data.getFeatures() == null || data.getRows() == 0) {
                        return null;
                    }
                    return data.getFeatures().length / data.getRows();
                }, (a, b) -> {
                    // Keep any non-null column count; fall back to 0 if nobody reported one.
                    if (a == null) {
                        return b == null ? 0 : b;
                    }
                    if (b == null) {
                        return a;
                    }
                    return b;
                });
                return logisticArchitecture(cols);
            },
            LossFunctions.L2, this.updatesStgy, this.maxIterations, this.batchSize, this.locIterations, this.seed);

        // The MLP trainer works with multi-label targets, so wrap the scalar label into a 1-element array.
        IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)};

        MultilayerPerceptron mlp;
        if (mdl != null) {
            mlp = trainer.update(restoreMLPState(mdl), datasetBuilder, featureExtractor, lbExtractorWrapper);
        } else {
            mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper);
        }
        return toLogisticRegressionModel(mlp);
    }

    /**
     * Builds the perceptron architecture used to represent a logistic regression: {@code inputSize}
     * inputs followed by one layer of a single sigmoid neuron (the {@code true} flag is assumed to
     * enable the bias — confirm against {@link MLPArchitecture#withAddedLayer}).
     *
     * @param inputSize Number of input features.
     * @return Perceptron architecture.
     */
    private static MLPArchitecture logisticArchitecture(int inputSize) {
        return new MLPArchitecture(inputSize).withAddedLayer(1, true, Activators.SIGMOID);
    }

    /**
     * Converts trained perceptron parameters into a logistic regression model: all parameters except
     * the last become the weights and the last one becomes the intercept. This mirrors the layout
     * written by {@link #restoreMLPState(LogisticRegressionModel)}.
     *
     * @param mlp Trained perceptron.
     * @return Logistic regression model.
     */
    private static LogisticRegressionModel toLogisticRegressionModel(MultilayerPerceptron mlp) {
        double[] params = mlp.parameters().getStorage().data();
        int interceptIdx = params.length - 1;
        return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, interceptIdx)), params[interceptIdx]);
    }

    /**
     * Builds a perceptron whose parameters equal the given model's weights followed by its intercept.
     *
     * @param mdl Logistic regression model.
     * @return Perceptron initialized with the model state.
     */
    @NotNull
    private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) {
        Vector weights = mdl.weights();
        MultilayerPerceptron perceptron = new MultilayerPerceptron(logisticArchitecture(weights.size()));

        // One extra slot at the end holds the intercept; only non-zero weights need copying.
        Vector mlpState = weights.like(weights.size() + 1);
        weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get()));
        mlpState.set(mlpState.size() - 1, mdl.intercept());

        perceptron.setParameters(mlpState);
        return perceptron;
    }

    /**
     * Accepts any existing model as a starting point for an update.
     *
     * @param mdl Model.
     * @return Always {@code true}.
     */
    @Override
    protected boolean checkState(LogisticRegressionModel mdl) {
        return true;
    }

    /**
     * Sets the max number of iterations.
     *
     * @param maxIterations Max number of iterations.
     * @return This trainer.
     */
    public LogisticRegressionSGDTrainer<P> withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * Sets the batch size.
     *
     * @param batchSize Batch size.
     * @return This trainer.
     */
    public LogisticRegressionSGDTrainer<P> withBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * Sets the number of local iterations.
     *
     * @param amountOfLocIterations Number of local iterations.
     * @return This trainer.
     */
    public LogisticRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) {
        this.locIterations = amountOfLocIterations;
        return this;
    }

    /**
     * Sets the seed.
     *
     * @param seed Seed.
     * @return This trainer.
     */
    public LogisticRegressionSGDTrainer<P> withSeed(long seed) {
        this.seed = seed;
        return this;
    }

    /**
     * Sets the update strategy.
     *
     * @param updatesStgy Update strategy.
     * @return This trainer.
     */
    public LogisticRegressionSGDTrainer<P> withUpdatesStgy(UpdatesStrategy updatesStgy) {
        this.updatesStgy = updatesStgy;
        return this;
    }

    /**
     * Returns the update strategy.
     *
     * @return Update strategy.
     */
    public UpdatesStrategy getUpdatesStgy() {
        return this.updatesStgy;
    }

    /**
     * Returns the max number of iterations.
     *
     * @return Max number of iterations.
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Returns the batch size.
     *
     * @return Batch size.
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Returns the number of local iterations.
     *
     * @return Number of local iterations.
     */
    public int getLocIterations() {
        return this.locIterations;
    }

    /**
     * Returns the seed.
     *
     * @return Seed.
     */
    public long getSeed() {
        return this.seed;
    }
}

