/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableDoubleToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.util.MatrixUtil;
import org.apache.ignite.ml.nn.MLPLayer;
import org.apache.ignite.ml.nn.MLPState;
import org.apache.ignite.ml.nn.ReplicatedVectorMatrix;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.architecture.TransformationLayerArchitecture;
import org.apache.ignite.ml.nn.initializers.MLPInitializer;
import org.apache.ignite.ml.nn.initializers.RandomInitializer;
import org.apache.ignite.ml.optimization.SmoothParametrized;

/**
 * Multilayer perceptron: a chain of layers, each of which multiplies its input by a weight matrix,
 * optionally adds a bias vector and then applies an activation function.
 *
 * <p>Layer indices used by the public accessors start at 1 for the first transformation layer; index 0
 * denotes the input layer, which has neither weights nor biases. The weight matrix of a layer has one row
 * per neuron of that layer and one column per neuron of the previous layer.
 *
 * <p>A perceptron can be stacked on top of another one with {@link #add(MultilayerPerceptron)}; the lower
 * perceptron is then evaluated first during the forward pass.
 *
 * <p>NOTE(review): several members ({@link #weights(int)}, {@link #biases(int)}, {@link #parameters()},
 * {@link #setParameters(Vector)}) look inconsistent for stacked perceptrons; see the notes on those members.
 * Behaviour for the non-stacked case is unaffected.
 */
public class MultilayerPerceptron
implements Model<Matrix, Matrix>,
SmoothParametrized<MultilayerPerceptron>,
Serializable {
    /** Architecture of the layers owned by this instance (layers of {@link #below} are not included). */
    protected MLPArchitecture architecture;

    /** Parameters of the transformation layers owned by this instance; element {@code i} is layer {@code i + 1}. */
    protected List<MLPLayer> layers;

    /** Perceptron evaluated before this one, or {@code null} when this instance is not stacked. */
    protected MultilayerPerceptron below;

    /**
     * Creates a perceptron with the given architecture and initializes its parameters.
     *
     * @param arch Architecture of the perceptron.
     * @param initializer Initializer for weights and biases; when {@code null}, a {@link RandomInitializer}
     *      backed by a fresh {@link Random} is used.
     */
    public MultilayerPerceptron(MLPArchitecture arch, MLPInitializer initializer) {
        this.layers = new ArrayList<>(arch.layersCount() + 1);
        this.architecture = arch;
        this.below = null;
        this.initLayers(initializer != null ? initializer : new RandomInitializer(new Random()));
    }

    /**
     * Creates a perceptron with the given architecture and randomly initialized parameters.
     *
     * @param arch Architecture of the perceptron.
     */
    public MultilayerPerceptron(MLPArchitecture arch) {
        this(arch, null);
    }

    /**
     * Allocates and initializes weights (and biases, where the layer has them) for every transformation layer.
     *
     * @param initializer Initializer applied to each freshly allocated weight matrix and bias vector.
     */
    private void initLayers(MLPInitializer initializer) {
        int prevSize = this.architecture.inputSize();
        // Index 0 is the input layer; transformation layers start at 1.
        for (int i = 1; i < this.architecture.layersCount(); ++i) {
            TransformationLayerArchitecture layerCfg = this.architecture.transformationLayerArchitecture(i);
            int neuronsCnt = layerCfg.neuronsCount();
            // Rows correspond to neurons of this layer, columns to neurons of the previous layer.
            DenseMatrix weights = new DenseMatrix(neuronsCnt, prevSize);
            initializer.initWeights(weights);
            DenseVector biases = null;
            if (layerCfg.hasBias()) {
                biases = new DenseVector(neuronsCnt);
                initializer.initBiases(biases);
            }
            this.layers.add(new MLPLayer(weights, biases));
            prevSize = neuronsCnt;
        }
    }

    /**
     * Creates a stacked perceptron that shares the layers and architecture of {@code above} and evaluates
     * {@code below} first.
     *
     * @param above Perceptron whose layers (by reference, not copied) and architecture are reused.
     * @param below Perceptron evaluated before the layers of {@code above}.
     */
    protected MultilayerPerceptron(MultilayerPerceptron above, MultilayerPerceptron below) {
        this.layers = above.layers;
        this.architecture = above.architecture;
        this.below = below;
    }

    /**
     * Runs a forward pass with state recording enabled and returns the recorded state.
     *
     * @param val Input matrix.
     * @return State holding linear and activation outputs of every layer.
     */
    public MLPState computeState(Matrix val) {
        MLPState res = new MLPState(val);
        this.forwardPass(val, res, true);
        return res;
    }

    /**
     * Propagates the input through the lower perceptron (if any) and then through the layers of this one,
     * appending the linear output and the activation output of every layer to {@code state}.
     *
     * @param val Input matrix.
     * @param state State object the per-layer outputs are appended to.
     * @param writeState When {@code true}, the linear output is copied before the activation is applied,
     *      so the matrix stored in {@code state.linearOutput} is not the one the activation is mapped over.
     * @return Activation output of the last layer, or the (possibly lower-perceptron-transformed) input
     *      when this instance has no transformation layers.
     */
    public Matrix forwardPass(Matrix val, MLPState state, boolean writeState) {
        Matrix res = val;
        if (this.below != null) {
            res = this.below.forwardPass(val, state, writeState);
        }
        for (int i = 1; i < this.architecture.layersCount(); ++i) {
            MLPLayer curLayer = this.layers.get(i - 1);
            res = curLayer.weights.times(res);
            TransformationLayerArchitecture layerCfg = this.architecture.transformationLayerArchitecture(i);
            if (layerCfg.hasBias()) {
                ReplicatedVectorMatrix biasesMatrix = new ReplicatedVectorMatrix(this.biases(i), res.columnSize(), true);
                res = res.plus(biasesMatrix);
            }
            state.linearOutput.add(res);
            if (writeState) {
                // Keep the stored linear output separate from the matrix passed to map().
                res = res.copy();
            }
            res = res.map(layerCfg.activationFunction());
            state.activatorsOutput.add(res);
        }
        return res;
    }

    /**
     * Computes the model output for the given input. The input is transposed before the forward pass and
     * the result is transposed back.
     *
     * @param val Input matrix.
     * @return Transposed activation output of the last layer.
     */
    @Override
    public Matrix apply(Matrix val) {
        MLPState state = new MLPState(null);
        this.forwardPass(val.transpose(), state, false);
        return state.activatorsOutput.get(state.activatorsOutput.size() - 1).transpose();
    }

    /**
     * Stacks the given perceptron on top of this one.
     *
     * @param above Perceptron to put on top.
     * @return New perceptron that evaluates this instance first and then the layers of {@code above}.
     */
    public MultilayerPerceptron add(MultilayerPerceptron above) {
        return new MultilayerPerceptron(above, this);
    }

    /**
     * Returns the weight matrix of the given layer.
     *
     * <p>NOTE(review): when delegating to {@link #below} the index is reduced by this instance's own layer
     * count even though it is already known to be smaller than the lower perceptron's layer count; this
     * looks wrong for stacked perceptrons - confirm the intended indexing convention.
     *
     * @param layerIdx Layer index, at least 1.
     * @return Weight matrix of the layer (the live object, not a copy).
     */
    public Matrix weights(int layerIdx) {
        assert (layerIdx >= 1);
        assert (layerIdx < this.architecture.layersCount() || this.below != null);
        if (layerIdx < this.belowLayersCount()) {
            return this.below.weights(layerIdx - this.architecture.layersCount());
        }
        return this.layers.get(layerIdx - this.belowLayersCount() - 1).weights;
    }

    /**
     * Returns the bias vector of the given layer.
     *
     * <p>NOTE(review): the delegation to {@link #below} uses the same index arithmetic as
     * {@link #weights(int)} and carries the same doubt for stacked perceptrons.
     *
     * @param layerIdx Layer index.
     * @return Bias vector of the layer (the live object, not a copy), or {@code null} if the layer was
     *      created without biases.
     */
    public Vector biases(int layerIdx) {
        assert (layerIdx >= 0);
        assert (layerIdx < this.architecture.layersCount() || this.below != null);
        if (layerIdx < this.belowLayersCount()) {
            return this.below.biases(layerIdx - this.architecture.layersCount());
        }
        return this.layers.get(layerIdx - this.belowLayersCount() - 1).biases;
    }

    /**
     * Tells whether the given layer has a bias vector.
     *
     * @param layerIdx Layer index; 0 (the input layer) always yields {@code false}.
     * @return {@code true} if the layer index is non-zero and the layer's bias vector is not {@code null}.
     */
    public boolean hasBiases(int layerIdx) {
        return layerIdx != 0 && this.biases(layerIdx) != null;
    }

    /**
     * Overwrites the bias vector of the given layer with the values of {@code bias}.
     *
     * @param layerIdx Layer index; the layer must have biases, otherwise the underlying vector is
     *      {@code null} and this call fails.
     * @param bias New bias values.
     * @return This perceptron.
     */
    public MultilayerPerceptron setBiases(int layerIdx, Vector bias) {
        this.biases(layerIdx).assign(bias);
        return this;
    }

    /**
     * Sets the bias of a single neuron.
     *
     * @param layerIdx Layer index, greater than 0; the layer must have biases.
     * @param neuronIdx Index of the neuron within the layer.
     * @param val New bias value.
     * @return This perceptron.
     */
    public MultilayerPerceptron setBias(int layerIdx, int neuronIdx, double val) {
        assert (layerIdx > 0);
        assert (this.architecture.transformationLayerArchitecture(layerIdx).hasBias());
        this.biases(layerIdx).setX(neuronIdx, val);
        return this;
    }

    /**
     * Returns the bias of a single neuron.
     *
     * @param layerIdx Layer index, greater than 0; the layer must have biases.
     * @param neuronIdx Index of the neuron within the layer.
     * @return Bias value.
     */
    public double bias(int layerIdx, int neuronIdx) {
        assert (layerIdx > 0);
        assert (this.architecture.transformationLayerArchitecture(layerIdx).hasBias());
        return this.biases(layerIdx).getX(neuronIdx);
    }

    /**
     * Overwrites the weight matrix of the given layer with the values of {@code weights}.
     *
     * @param layerIdx Layer index, at least 1.
     * @param weights New weight values.
     * @return This perceptron.
     */
    public MultilayerPerceptron setWeights(int layerIdx, Matrix weights) {
        this.weights(layerIdx).assign(weights);
        return this;
    }

    /**
     * Sets the weight of the link from a neuron of the previous layer to a neuron of the given layer.
     *
     * @param layerIdx Layer index, greater than 0.
     * @param fromNeuron Index of the source neuron in the previous layer (matrix column).
     * @param toNeuron Index of the target neuron in the given layer (matrix row).
     * @param val New weight value.
     * @return This perceptron.
     */
    public MultilayerPerceptron setWeight(int layerIdx, int fromNeuron, int toNeuron, double val) {
        assert (layerIdx > 0);
        this.weights(layerIdx).setX(toNeuron, fromNeuron, val);
        return this;
    }

    /**
     * Returns the weight of the link from a neuron of the previous layer to a neuron of the given layer.
     *
     * @param layerIdx Layer index, greater than 0.
     * @param fromNeuron Index of the source neuron in the previous layer (matrix column).
     * @param toNeuron Index of the target neuron in the given layer (matrix row).
     * @return Weight value.
     */
    public double weight(int layerIdx, int fromNeuron, int toNeuron) {
        assert (layerIdx > 0);
        // Rows index the target neuron and columns the source neuron, exactly as in setWeight();
        // reading with swapped indices would return a different element (or fail on non-square layers).
        return this.weights(layerIdx).getX(toNeuron, fromNeuron);
    }

    /**
     * Returns the number of layers of this perceptron plus those of all perceptrons below it.
     *
     * @return Total layer count.
     */
    public int layersCount() {
        return this.architecture.layersCount() + (this.below != null ? this.below.layersCount() : 0);
    }

    /**
     * Returns the total number of layers of the perceptrons below this one.
     *
     * @return Layer count of {@link #below}, or 0 when this instance is not stacked.
     */
    protected int belowLayersCount() {
        return this.below != null ? this.below.layersCount() : 0;
    }

    /**
     * Returns the architecture of the whole stack.
     *
     * @return This instance's architecture, appended to the lower perceptron's architecture when stacked.
     */
    public MLPArchitecture architecture() {
        if (this.below != null) {
            return this.below.architecture().add(this.architecture);
        }
        return this.architecture;
    }

    /**
     * Computes the gradient of the loss with respect to all parameters by backpropagation, averaged over
     * the batch.
     *
     * @param loss Function that, given a ground-truth vector, yields a differentiable loss of the prediction.
     * @param inputsBatch Batch of inputs; the batch size is taken as its number of columns.
     * @param truthBatch Batch of ground-truth values, one column per sample.
     * @return Gradient flattened in the same layout as produced by {@link #paramsAsVector(List)}.
     */
    @Override
    public Vector differentiateByParameters(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, Matrix inputsBatch, Matrix truthBatch) {
        int batchSize = inputsBatch.columnSize();
        double invBatchSize = 1.0 / (double)batchSize;
        int lastLayer = this.layersCount() - 1;
        MLPState mlpState = this.computeState(inputsBatch);
        // The combined architecture does not change during the loop, so build it once.
        MLPArchitecture arch = this.architecture();
        Matrix dz = null;
        LinkedList<MLPLayer> layersParameters = new LinkedList<>();
        // Walk from the output layer back to the first transformation layer.
        for (int layer = lastLayer; layer > 0; --layer) {
            Matrix z = mlpState.linearOutput(layer).copy();
            Matrix dSigmaDz = this.differentiateNonlinearity(z, arch.transformationLayerArchitecture(layer).activationFunction());
            if (layer == lastLayer) {
                // Output layer: start from the derivative of the loss by the network output.
                Matrix sigma = mlpState.activatorsOutput(lastLayer).copy();
                Matrix dLossDSigma = this.differentiateLoss(truthBatch, sigma, loss);
                dz = MatrixUtil.elementWiseTimes(dLossDSigma, dSigmaDz);
            } else {
                // Hidden layer: pull the error back through the weights of the next layer.
                dz = this.weights(layer + 1).transpose().times(dz);
                dz = MatrixUtil.elementWiseTimes(dz, dSigmaDz);
            }
            Matrix a = mlpState.activatorsOutput(layer - 1);
            Matrix dw = dz.times(a.transpose()).times(invBatchSize);
            Vector db = null;
            if (this.hasBiases(layer)) {
                db = dz.foldRows(Vector::sum).times(invBatchSize);
            }
            // Layers are visited last-to-first, so prepend to keep first-to-last order.
            layersParameters.addFirst(new MLPLayer(dw, db));
        }
        return this.paramsAsVector(layersParameters);
    }

    /**
     * Returns the parameters of this instance's own layers as a flat vector.
     *
     * <p>NOTE(review): only {@link #layers} of this instance are written, while the vector is sized by the
     * combined architecture; for stacked perceptrons the parameters of {@link #below} appear to be left out -
     * confirm whether that is intended.
     *
     * @return Flattened parameters.
     */
    @Override
    public Vector parameters() {
        return this.paramsAsVector(this.layers);
    }

    /**
     * Flattens the given per-layer parameters into a single vector: for each layer the weights in row-major
     * order, followed by the biases when present.
     *
     * @param layersParams Per-layer weights and biases, first layer first.
     * @return Vector of length {@code architecture().parametersCount()} holding the flattened parameters.
     */
    protected Vector paramsAsVector(List<MLPLayer> layersParams) {
        int off = 0;
        Vector res = new DenseVector(this.architecture().parametersCount());
        for (MLPLayer layerParams : layersParams) {
            off = this.writeToVector(res, layerParams.weights, off);
            if (layerParams.biases != null) {
                off = this.writeToVector(res, layerParams.biases, off);
            }
        }
        return res;
    }

    /**
     * Replaces the weights and biases of the layers with values read from a flat vector laid out as in
     * {@link #paramsAsVector(List)}. New matrix and vector objects are assigned to the layers.
     *
     * <p>NOTE(review): the loop bound is the total layer count of the stack, but layers are looked up in
     * this instance's own list only; for stacked perceptrons this looks like it can run past the end of
     * {@link #layers} - confirm.
     *
     * @param vector Flattened parameters.
     * @return This perceptron.
     */
    @Override
    public MultilayerPerceptron setParameters(Vector vector) {
        int off = 0;
        for (int l = 1; l < this.layersCount(); ++l) {
            MLPLayer layer = this.layers.get(l - 1);
            IgniteBiTuple<Integer, Matrix> weightsRead = this.readFromVector(vector, layer.weights.rowSize(), layer.weights.columnSize(), off);
            off = weightsRead.get1();
            layer.weights = weightsRead.get2();
            if (this.hasBiases(l)) {
                IgniteBiTuple<Integer, Vector> biasesRead = this.readFromVector(vector, layer.biases.size(), off);
                off = biasesRead.get1();
                layer.biases = biasesRead.get2();
            }
        }
        return this;
    }

    /**
     * Returns the number of parameters of the whole stack.
     *
     * @return Parameter count reported by the combined architecture.
     */
    @Override
    public int parametersCount() {
        return this.architecture().parametersCount();
    }

    /**
     * Reads a {@code rows x cols} matrix stored in row-major order from a vector.
     *
     * @param v Source vector.
     * @param rows Number of rows to read.
     * @param cols Number of columns to read.
     * @param off Position in {@code v} of the first element.
     * @return Pair of the offset just past the matrix data and the matrix read.
     */
    private IgniteBiTuple<Integer, Matrix> readFromVector(Vector v, int rows, int cols, int off) {
        Matrix mtx = new DenseMatrix(rows, cols);
        int size = rows * cols;
        for (int i = 0; i < size; ++i) {
            mtx.setX(i / cols, i % cols, v.getX(off + i));
        }
        return new IgniteBiTuple<>(off + size, mtx);
    }

    /**
     * Reads a sub-vector of the given size from a vector.
     *
     * @param v Source vector.
     * @param size Number of elements to read.
     * @param off Position in {@code v} of the first element.
     * @return Pair of the offset just past the data and the vector read.
     */
    private IgniteBiTuple<Integer, Vector> readFromVector(Vector v, int size, int off) {
        Vector vec = new DenseVector(size);
        for (int i = 0; i < size; ++i) {
            vec.setX(i, v.getX(off + i));
        }
        return new IgniteBiTuple<>(off + size, vec);
    }

    /**
     * Writes a matrix into a vector in row-major order.
     *
     * @param vec Destination vector.
     * @param mtx Matrix to write.
     * @param off Position in {@code vec} of the first element written.
     * @return Offset just past the written data.
     */
    private int writeToVector(Vector vec, Matrix mtx, int off) {
        int rows = mtx.rowSize();
        int cols = mtx.columnSize();
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                vec.setX(off, mtx.getX(r, c));
                ++off;
            }
        }
        return off;
    }

    /**
     * Writes a vector into another vector.
     *
     * @param vec Destination vector.
     * @param v Vector to write.
     * @param off Position in {@code vec} of the first element written.
     * @return Offset just past the written data.
     */
    private int writeToVector(Vector vec, Vector v, int off) {
        for (int i = 0; i < v.size(); ++i) {
            vec.setX(off, v.getX(i));
            ++off;
        }
        return off;
    }

    /**
     * Computes, column by column, the differential of the loss with respect to the network output.
     *
     * @param groundTruth Ground-truth values, one column per sample.
     * @param lastLayerOutput Network output, one column per sample.
     * @param loss Function that, given a ground-truth vector, yields a differentiable loss of the prediction.
     * @return Matrix of the same dimensions as {@code groundTruth} holding the loss differentials.
     */
    private Matrix differentiateLoss(Matrix groundTruth, Matrix lastLayerOutput, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        Matrix diff = groundTruth.like(groundTruth.rowSize(), groundTruth.columnSize());
        for (int col = 0; col < groundTruth.columnSize(); ++col) {
            Vector gtCol = groundTruth.getCol(col);
            Vector predCol = lastLayerOutput.getCol(col);
            diff.assignColumn(col, loss.apply(gtCol).differential(predCol));
        }
        return diff;
    }

    /**
     * Applies the derivative of the activation function to every element of a copy of the linear output.
     *
     * @param linearOut Linear output of a layer; not modified.
     * @param nonlinearity Activation function of the layer.
     * @return Copy of {@code linearOut} mapped through the derivative of {@code nonlinearity}.
     */
    private Matrix differentiateNonlinearity(Matrix linearOut, IgniteDifferentiableDoubleToDoubleFunction nonlinearity) {
        Matrix diff = linearOut.copy();
        // The return value of map() is discarded: this relies on map() updating the matrix in place.
        diff.map(nonlinearity::differential);
        return diff;
    }

    /**
     * Returns the non-pretty textual representation of this perceptron.
     *
     * @return Same as {@code toString(false)}.
     */
    @Override
    public String toString() {
        return this.toString(false);
    }

    /**
     * Builds a textual representation listing the lower perceptron (if any) and, for each own layer, its
     * biases (when present) and weights formatted with four decimal places.
     *
     * @param pretty When {@code true}, a line break is inserted after the {@code "layers = ["} header.
     * @return Textual representation of this perceptron.
     */
    @Override
    public String toString(boolean pretty) {
        StringBuilder builder = new StringBuilder("MultilayerPerceptron [\n");
        if (this.below != null) {
            builder.append("below = \n").append(this.below.toString(pretty)).append("\n\n");
        }
        builder.append("layers = [").append(pretty ? "\n" : "");
        for (int i = 0; i < this.layers.size(); ++i) {
            MLPLayer layer = this.layers.get(i);
            builder.append("\tlayer").append(i).append(" = [\n");
            if (layer.biases != null) {
                builder.append("\t\tbias = ").append(Tracer.asAscii(layer.biases, "%.4f", false)).append("\n");
            }
            // Indent continuation lines of the matrix dump to line up under "weights = [".
            String matrix = Tracer.asAscii(layer.weights, "%.4f").replace("\n", "\n\t\t\t");
            builder.append("\t\tweights = [\n\t\t\t").append(matrix).append("\n\t\t]");
            builder.append("\n\t]\n");
        }
        builder.append("]");
        return builder.toString();
    }
}

