/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.regressions.linear;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer that produces a {@link LinearRegressionModel} by fitting a multilayer perceptron
 * consisting of the input layer plus a single one-neuron layer with a linear activator and
 * bias, using mean squared error as the loss. The learned perceptron parameters are then
 * repackaged as a linear regression model.
 *
 * @param <P> Type of the updater parameters used by the updates strategy.
 */
public class LinearRegressionSGDTrainer<P extends Serializable>
    implements SingleLabelDatasetTrainer<LinearRegressionModel> {
    /** Updates strategy handed over to the underlying {@link MLPTrainer}. */
    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;

    /** Passed to {@link MLPTrainer} as its {@code maxIterations} argument. */
    private final int maxIterations;

    /** Passed to {@link MLPTrainer} as its {@code batchSize} argument. */
    private final int batchSize;

    /** Passed to {@link MLPTrainer} as its {@code locIterations} argument. */
    private final int locIterations;

    /** Passed to {@link MLPTrainer} as its {@code seed} argument. */
    private final long seed;

    /**
     * Creates a trainer; all arguments are stored and later forwarded unchanged to the
     * {@link MLPTrainer} built in {@link #fit}.
     *
     * @param updatesStgy Updates strategy.
     * @param maxIterations Max number of iterations (presumably global SGD steps; see {@link MLPTrainer}).
     * @param batchSize Batch size.
     * @param locIterations Number of local iterations.
     * @param seed Seed (presumably for the random generator used by {@link MLPTrainer}).
     */
    public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) {
        this.updatesStgy = updatesStgy;
        this.maxIterations = maxIterations;
        this.batchSize = batchSize;
        this.locIterations = locIterations;
        this.seed = seed;
    }

    /**
     * Trains a perceptron on the given dataset and converts its parameters into a linear
     * regression model.
     *
     * @param datasetBuilder Dataset builder the perceptron trainer reads the data from.
     * @param featureExtractor Extracts the feature array from a key/value pair.
     * @param lbExtractor Extracts the (scalar) label from a key/value pair.
     * @param <K> Type of the keys in the upstream data.
     * @param <V> Type of the values in the upstream data.
     * @return Model built from the trained perceptron's parameter vector.
     * @throws IllegalStateException If no partition of the dataset reports any features, so the
     *      number of input columns cannot be determined.
     */
    @Override
    public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        // The architecture supplier is passed inline so that its generic type is inferred from
        // the MLPTrainer constructor instead of being declared as a raw IgniteFunction.
        MLPTrainer<P> trainer = new MLPTrainer<P>(
            dataset -> {
                // Number of feature columns: taken from the first partition that has features
                // (flat feature array length divided by the row count); partitions without
                // features contribute null and are skipped by the reducer.
                Integer cols = dataset.compute(data -> {
                    if (data.getFeatures() == null) {
                        return null;
                    }
                    return data.getFeatures().length / data.getRows();
                }, (a, b) -> a == null ? b : a);

                // Fail with an explicit message instead of an anonymous NPE on unboxing.
                if (cols == null) {
                    throw new IllegalStateException("Cannot determine the number of features: dataset contains no feature data");
                }

                // Input layer of 'cols' neurons followed by one linear neuron with bias.
                return new MLPArchitecture(cols).withAddedLayer(1, true, Activators.LINEAR);
            },
            LossFunctions.MSE,
            this.updatesStgy,
            this.maxIterations,
            this.batchSize,
            this.locIterations,
            this.seed
        );

        // The perceptron trainer expects array-valued labels, so wrap the scalar label.
        // A lambda is used rather than an anonymous inner class: an anonymous class declared in
        // this instance method would also hold a reference to the enclosing trainer, which would
        // be dragged along if the extractor is serialized (NOTE(review): confirm that
        // IgniteBiFunction instances are shipped to remote nodes).
        IgniteBiFunction<K, V, double[]> lbE = (k, v) -> new double[] {lbExtractor.apply(k, v)};

        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);

        // Flat parameter array: everything except the last element becomes the model's vector
        // argument and the last element its scalar argument (presumably weights and intercept).
        double[] p = mlp.parameters().getStorage().data();

        return new LinearRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(p, p.length - 1)), p[p.length - 1]);
    }
}

