/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.nn.architecture;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableDoubleToDoubleFunction;
import org.apache.ignite.ml.nn.architecture.LayerArchitecture;
import org.apache.ignite.ml.nn.architecture.TransformationLayerArchitecture;

/**
 * Describes the layer structure of a multilayer perceptron.
 *
 * <p>Layer {@code 0} is the input layer (a plain {@link LayerArchitecture}); every following layer is a
 * {@link TransformationLayerArchitecture} appended via
 * {@link #withAddedLayer(int, boolean, IgniteDifferentiableDoubleToDoubleFunction)} or {@link #add(MLPArchitecture)}.
 * Both of those methods return a new architecture and leave the receiver (and their arguments) unchanged.
 */
public class MLPArchitecture implements Serializable {
    /** Layers of the network; index {@code 0} is the input layer. */
    private final List<LayerArchitecture> layers;

    /**
     * Creates an architecture consisting of a single input layer.
     *
     * @param inputSize Number of neurons in the input layer.
     */
    public MLPArchitecture(int inputSize) {
        layers = new ArrayList<>();
        layers.add(new LayerArchitecture(inputSize));
    }

    /**
     * Creates an architecture backed by the given list. The list is stored as is (not copied), so callers must pass a
     * list that is not shared with any other instance.
     *
     * @param layers Layers of the network, starting with the input layer.
     */
    private MLPArchitecture(List<LayerArchitecture> layers) {
        this.layers = layers;
    }

    /**
     * @return Total number of layers, including the input layer.
     */
    public int layersCount() {
        return layers.size();
    }

    /**
     * @return Number of neurons in the input layer.
     */
    public int inputSize() {
        return layers.get(0).neuronsCount();
    }

    /**
     * @return Number of neurons in the last layer.
     */
    public int outputSize() {
        return layers.get(layersCount() - 1).neuronsCount();
    }

    /**
     * Returns a new architecture equal to this one with one more transformation layer appended at the end.
     * This architecture is left unchanged.
     *
     * @param neuronsCnt Number of neurons in the new layer.
     * @param hasBias Whether the new layer has bias (adds {@code neuronsCnt} to {@link #parametersCount()}).
     * @param f Function handed to the new {@link TransformationLayerArchitecture} (presumably its activation
     *      function; see that class).
     * @return New architecture with the added layer.
     */
    public MLPArchitecture withAddedLayer(int neuronsCnt, boolean hasBias,
        IgniteDifferentiableDoubleToDoubleFunction f) {
        List<LayerArchitecture> newLayers = new ArrayList<>(layers);
        newLayers.add(new TransformationLayerArchitecture(neuronsCnt, hasBias, f));

        return new MLPArchitecture(newLayers);
    }

    /**
     * @param layer Index of the layer; {@code 0} is the input layer.
     * @return Architecture of the given layer.
     */
    public LayerArchitecture layerArchitecture(int layer) {
        return layers.get(layer);
    }

    /**
     * @param layer Index of a non-input layer (i.e. {@code layer >= 1}).
     * @return Architecture of the given layer as a transformation layer.
     * @throws ClassCastException If the layer is not a transformation layer (e.g. the input layer at index 0).
     */
    public TransformationLayerArchitecture transformationLayerArchitecture(int layer) {
        return (TransformationLayerArchitecture)layers.get(layer);
    }

    /**
     * Creates an architecture in which the layers of {@code second} follow the layers of this architecture.
     * The input layer of {@code second} coincides with the output layer of this architecture (they must have the
     * same size), so the result has {@code layersCount() + second.layersCount() - 1} layers.
     * Neither this architecture nor {@code second} is modified.
     *
     * @param second Architecture to append after this one.
     * @return Combined architecture.
     * @throws IllegalArgumentException If the input size of {@code second} differs from the output size of this
     *      architecture.
     */
    public MLPArchitecture add(MLPArchitecture second) {
        // Explicit check rather than 'assert': assertions are disabled by default, and a size mismatch would
        // otherwise produce an inconsistent architecture silently.
        if (second.inputSize() != outputSize()) {
            throw new IllegalArgumentException("Input size of the appended architecture (" + second.inputSize()
                + ") must match the output size of this architecture (" + outputSize() + ")");
        }

        List<LayerArchitecture> newLayers = new ArrayList<>(layersCount() + second.layersCount() - 1);
        newLayers.addAll(layers);

        // Skip the input layer of 'second': it describes the same neurons as our output layer. Keeping it (or
        // adding a second copy of our own input layer) would put plain, non-transformation layers in the middle
        // of the network and break transformationLayerArchitecture()/parametersCount().
        newLayers.addAll(second.layers.subList(1, second.layersCount()));

        return new MLPArchitecture(newLayers);
    }

    /**
     * Counts the parameters of the network: for every layer after the input one, the product of its neuron count and
     * the previous layer's neuron count, plus its neuron count once more when the layer has bias.
     *
     * @return Total number of parameters.
     */
    public int parametersCount() {
        int res = 0;

        for (int i = 1; i < layersCount(); i++) {
            TransformationLayerArchitecture la = transformationLayerArchitecture(i);

            res += layerArchitecture(i - 1).neuronsCount() * la.neuronsCount();

            if (la.hasBias())
                res += la.neuronsCount();
        }

        return res;
    }
}

