/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.initializers.RandomInitializer;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.util.Utils;

/**
 * Trainer of {@link MultilayerPerceptron} models over a partitioned {@link Dataset}.
 *
 * <p>Training runs in {@code ceil(maxIterations / locIterations)} rounds. Each round proceeds as follows:
 * <ul>
 *     <li>Every partition that holds data copies the current global model.</li>
 *     <li>It performs {@code locIterations} mini-batch update steps on that copy.</li>
 *     <li>It folds its per-step updates with {@link UpdatesStrategy#locStepUpdatesReducer()}.</li>
 *     <li>The per-partition results are concatenated into one list and folded with
 *     {@link UpdatesStrategy#allUpdatesReducer()}.</li>
 *     <li>The folded update is applied to the global model.</li>
 * </ul>
 *
 * @param <P> Type of the update object produced and consumed by the strategy's {@link ParameterUpdateCalculator}.
 */
public class MLPTrainer<P extends Serializable>
implements MultiLabelDatasetTrainer<MultilayerPerceptron> {
    /** Builds the network architecture, given the training dataset. */
    private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;

    /** Loss factory handed to the update calculator's {@code init}. */
    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Supplies the update calculator and the reducers for local-step and cross-partition updates. */
    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;

    /** Local-step budget: rounds are started while the number of steps scheduled so far is below this value. */
    private final int maxIterations;

    /** Upper bound on the number of rows sampled from a partition for one mini-batch. */
    private final int batchSize;

    /** Number of local update steps each partition performs per round. */
    private final int locIterations;

    /** Seed for weight initialization and mini-batch sampling. */
    private final long seed;

    /**
     * Creates a trainer with a fixed architecture.
     *
     * @param arch Network architecture.
     * @param loss Loss factory passed to the update calculator.
     * @param updatesStgy Update calculator and reducers.
     * @param maxIterations Local-step budget.
     * @param batchSize Maximal mini-batch size; must be positive.
     * @param locIterations Local steps per round; must be positive.
     * @param seed Random seed.
     * @throws IllegalArgumentException If {@code batchSize} or {@code locIterations} is not positive.
     */
    public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) {
        this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed);
    }

    /**
     * Creates a trainer whose architecture is computed from the training dataset.
     *
     * @param archSupplier Function building the architecture from the dataset.
     * @param loss Loss factory passed to the update calculator.
     * @param updatesStgy Update calculator and reducers.
     * @param maxIterations Local-step budget.
     * @param batchSize Maximal mini-batch size; must be positive.
     * @param locIterations Local steps per round; must be positive.
     * @param seed Random seed.
     * @throws IllegalArgumentException If {@code batchSize} or {@code locIterations} is not positive.
     */
    public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) {
        // A non-positive locIterations would make the round loop in fit() never terminate.
        if (locIterations <= 0) {
            throw new IllegalArgumentException("locIterations must be positive: " + locIterations);
        }
        // A non-positive batchSize would yield empty mini-batches that cannot form a matrix.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        this.archSupplier = archSupplier;
        this.loss = loss;
        this.updatesStgy = updatesStgy;
        this.maxIterations = maxIterations;
        this.batchSize = batchSize;
        this.locIterations = locIterations;
        this.seed = seed;
    }

    /**
     * Trains a multilayer perceptron on the data exposed by {@code datasetBuilder}.
     *
     * @param datasetBuilder Builder of the partitioned dataset.
     * @param featureExtractor Extracts the feature row from an upstream entry.
     * @param lbExtractor Extracts the label row from an upstream entry.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Trained model.
     * @throws RuntimeException Wrapping any exception raised while training or closing the dataset.
     */
    @Override
    public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
        try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
            new EmptyContextBuilder<>(),
            new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor))) {
            MLPArchitecture arch = archSupplier.apply(dataset);
            MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
            ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();

            // Each round schedules locIterations local steps, so i is the global index of the round's first step.
            for (int i = 0; i < maxIterations; i += locIterations) {
                // NOTE(review): totUp is null when no partition returned an update (e.g. every partition is
                // empty); confirm allUpdatesReducer() tolerates that.
                List<P> totUp = computeRoundUpdates(dataset, updater, mdl, i);
                P update = updatesStgy.allUpdatesReducer().apply(totUp);
                mdl = updater.update(mdl, update);
            }

            return mdl;
        }
        catch (Exception e) {
            // Dataset.close() is declared to throw a checked Exception, hence the broad catch.
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs one training round on every partition and collects the per-partition reduced updates.
     *
     * @param dataset Training dataset.
     * @param updater Update calculator.
     * @param mdl Current global model; each partition trains its own copy.
     * @param roundStart Global index of this round's first local step.
     * @return One reduced update per partition that held data, or {@code null} if none did.
     */
    private List<P> computeRoundUpdates(Dataset<EmptyContext, SimpleLabeledDatasetData> dataset,
        ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater, MultilayerPerceptron mdl, int roundStart) {
        // Copy everything the distributed closure needs into locals so it does not capture the trainer itself.
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> lossFn = loss;
        IgniteFunction<List<P>, P> locReducer = updatesStgy.locStepUpdatesReducer();
        int locIters = locIterations;
        int maxBatch = batchSize;
        long baseSeed = seed;

        return dataset.compute(data -> {
            int partRows = data.getRows();
            // Partitions without data contribute nothing to this round.
            if (data.getFeatures() == null || partRows == 0) {
                return null;
            }

            P update = updater.init(mdl, lossFn);
            MultilayerPerceptron mlp = Utils.copy(mdl);
            List<P> updates = new ArrayList<>(locIters);

            for (int locStep = 0; locStep < locIters; locStep++) {
                // roundStart advances by locIters, so roundStart + locStep is unique per step and each step
                // gets its own seed (the previous roundStart * locStep collapsed to 0 for many steps).
                Random rnd = new Random(baseSeed ^ ((long)roundStart + locStep));
                int[] rows = Utils.selectKDistinct(partRows, Math.min(maxBatch, partRows), rnd);

                double[] inputsBatch = batch(data.getFeatures(), rows, partRows);
                double[] groundTruthBatch = batch(data.getLabels(), rows, partRows);

                // NOTE(review): batch() emits column-major data; confirm the literal 0 mode argument makes
                // DenseLocalOnHeapMatrix read it that way.
                DenseLocalOnHeapMatrix inputs = new DenseLocalOnHeapMatrix(inputsBatch, rows.length, 0);
                DenseLocalOnHeapMatrix groundTruth = new DenseLocalOnHeapMatrix(groundTruthBatch, rows.length, 0);

                update = updater.calculateNewUpdate(mlp, update, locStep, inputs.transpose(), groundTruth.transpose());
                mlp = updater.update(mlp, update);
                updates.add(update);
            }

            List<P> res = new ArrayList<>();
            res.add(locReducer.apply(updates));
            return res;
        }, (a, b) -> {
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            a.addAll(b);
            return a;
        });
    }

    /**
     * Extracts the selected rows from a column-major array.
     *
     * <p>Element {@code (row, col)} of {@code data} is read from index {@code col * totalRows + row}. The result is
     * column-major as well, with {@code rows.length} rows: element {@code (i, col)} is stored at
     * {@code col * rows.length + i} and equals source row {@code rows[i]}.
     *
     * @param data Column-major source; its length is assumed to be a multiple of {@code totalRows}.
     * @param rows Indices of the rows to extract, in output order.
     * @param totalRows Number of rows in {@code data}; must be positive.
     * @return Column-major array with {@code rows.length} rows and the same number of columns as {@code data}.
     */
    static double[] batch(double[] data, int[] rows, int totalRows) {
        int cols = data.length / totalRows;
        int outRows = rows.length;
        double[] res = new double[cols * outRows];
        // Column-outer order writes res sequentially; each column of the source starts at col * totalRows.
        for (int col = 0; col < cols; ++col) {
            int srcOff = col * totalRows;
            int dstOff = col * outRows;
            for (int i = 0; i < outRows; ++i) {
                res[dstOff + i] = data[srcOff + rows[i]];
            }
        }
        return res;
    }
}

