package org.apache.mahout.classifier.mlp;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.WritableUtils;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/mlp/NeuralNetwork.class */
/**
 * Base class for a feed-forward neural network trained instance-by-instance with gradient descent.
 *
 * <p>Layers are appended with {@link #addLayer(int, boolean, String)}. Every non-final layer gets
 * one extra neuron for the bias term, so {@code layerSizeList} stores sizes <em>including</em> that
 * bias neuron. Weight matrix {@code i} maps the full output of layer {@code i} (bias included) to the
 * non-bias neurons of layer {@code i + 1}. A model can be serialized with {@link #write(DataOutput)} /
 * {@link #readFields(DataInput)}, or stored at {@link #getModelPath()} via
 * {@link #writeModelToFile()} / {@link #readFromModel()}.
 */
public abstract class NeuralNetwork {

    private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Learning rate assigned by the no-argument constructor. */
    public static final double DEFAULT_LEARNING_RATE = 0.5d;
    /** Regularization weight assigned by the no-argument constructor. */
    public static final double DEFAULT_REGULARIZATION_WEIGHT = 0.0d;
    /** Momentum weight assigned by the no-argument constructor. */
    public static final double DEFAULT_MOMENTUM_WEIGHT = 0.1d;

    /** Simple class name of the concrete model; verified when a model is read back. */
    protected String modelType;
    /** Location used by {@link #readFromModel()} and {@link #writeModelToFile()}; may be null. */
    protected String modelPath;
    protected double learningRate;
    protected double regularizationWeight;
    protected double momentumWeight;
    /** Cost function name, resolved through {@code NeuralNetworkFunctions} during training. */
    protected String costFunctionName;
    /** Size of every layer; non-final layers include one extra bias neuron. */
    protected List<Integer> layerSizeList;
    protected TrainingMethod trainingMethod;
    /** Weight matrix {@code i} connects layer {@code i} to layer {@code i + 1}. */
    protected List<Matrix> weightMatrixList;
    /** Weight updates produced by the previous training step; feeds the momentum term. */
    protected List<Matrix> prevWeightUpdatesList;
    /** Name of the squashing function applied to the output of weight matrix {@code i}. */
    protected List<String> squashingFunctionList;
    /** Index of the layer added as the final (output) layer. */
    protected int finalLayerIndex;

    /** Training methods understood by {@link #trainByInstance(Vector)}. */
    public enum TrainingMethod {
        GRADIENT_DESCENT
    }

    /** Creates an empty network (no layers) using the default hyper-parameters. */
    public NeuralNetwork() {
        log.info("Initialize model...");
        this.learningRate = DEFAULT_LEARNING_RATE;
        this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
        this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
        this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
        this.costFunctionName = "Minus_Squared";
        this.modelType = getClass().getSimpleName();
        this.layerSizeList = Lists.newArrayList();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        this.squashingFunctionList = Lists.newArrayList();
    }

    /**
     * Creates an empty network with the given hyper-parameters.
     *
     * @param learningRate learning rate, must be larger than 0
     * @param momentumWeight momentum weight, must be in [0, 1.0]
     * @param regularizationWeight regularization weight, must be in [0, 0.1)
     * @throws IllegalArgumentException if any value is out of its range
     */
    public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
        this();
        setLearningRate(learningRate);
        setMomentumWeight(momentumWeight);
        setRegularizationWeight(regularizationWeight);
    }

    /**
     * Creates a network by loading the model stored at {@code modelPath}.
     *
     * <p>NOTE(review): this invokes the overridable {@link #readFields(DataInput)} (through
     * {@link #readFromModel()}) before any subclass constructor body has run.
     *
     * @param modelPath Hadoop path of a model previously written by {@link #writeModelToFile()}
     * @throws IOException if the model cannot be read
     */
    public NeuralNetwork(String modelPath) throws IOException {
        this.modelPath = modelPath;
        readFromModel();
    }

    /** @return the simple class name recorded as this model's type */
    public String getModelType() {
        return modelType;
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate new learning rate, must be larger than 0
     * @return this network, for chaining
     */
    public final NeuralNetwork setLearningRate(double learningRate) {
        Preconditions.checkArgument(learningRate > 0.0d, "Learning rate must be larger than 0.");
        this.learningRate = learningRate;
        return this;
    }

    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Sets the regularization weight.
     *
     * @param regularizationWeight new weight, must be in [0, 0.1)
     * @return this network, for chaining
     */
    public final NeuralNetwork setRegularizationWeight(double regularizationWeight) {
        Preconditions.checkArgument(regularizationWeight >= 0.0d && regularizationWeight < 0.1d,
            "Regularization weight must be in range [0, 0.1)");
        this.regularizationWeight = regularizationWeight;
        return this;
    }

    public double getRegularizationWeight() {
        return regularizationWeight;
    }

    /**
     * Sets the momentum weight.
     *
     * @param momentumWeight new weight, must be in [0, 1.0]
     * @return this network, for chaining
     */
    public final NeuralNetwork setMomentumWeight(double momentumWeight) {
        Preconditions.checkArgument(momentumWeight >= 0.0d && momentumWeight <= 1.0d,
            "Momentum weight must be in range [0, 1.0]");
        this.momentumWeight = momentumWeight;
        return this;
    }

    public double getMomentumWeight() {
        return momentumWeight;
    }

    /**
     * Sets the training method.
     *
     * @return this network, for chaining
     */
    public NeuralNetwork setTrainingMethod(TrainingMethod method) {
        this.trainingMethod = method;
        return this;
    }

    public TrainingMethod getTrainingMethod() {
        return trainingMethod;
    }

    /**
     * Sets the cost function by name. The name is not validated here; it is resolved through
     * {@code NeuralNetworkFunctions} when training.
     *
     * @return this network, for chaining
     */
    public NeuralNetwork setCostFunction(String costFunction) {
        this.costFunctionName = costFunction;
        return this;
    }

    /**
     * Appends a layer to the network.
     *
     * <p>A non-final layer is stored with one extra neuron for the bias. Every layer except the
     * first also creates a weight matrix from the previous layer to this layer's non-bias neurons,
     * initialized with random values in [-0.5, 0.5), plus a zero matrix for the previous update.
     *
     * @param size number of neurons excluding the bias neuron; must be larger than 0
     * @param isFinalLayer whether this is the output layer (no bias neuron is added)
     * @param squashingFunctionName squashing function applied to this layer's weighted input;
     *     ignored for the first layer
     * @return the index of the new layer
     */
    public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
        Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
        log.info("Add layer with size {} and squashing function {}", size, squashingFunctionName);
        int sizeWithBias = isFinalLayer ? size : size + 1;
        layerSizeList.add(sizeWithBias);
        int layerIndex = layerSizeList.size() - 1;
        if (isFinalLayer) {
            finalLayerIndex = layerIndex;
        }
        if (layerIndex > 0) {
            int prevLayerSize = layerSizeList.get(layerIndex - 1);
            // Rows: non-bias neurons of this layer; columns: every neuron (bias included) of the previous one.
            int rows = size;
            DenseMatrix weightMatrix = new DenseMatrix(rows, prevLayerSize);
            final RandomWrapper random = RandomUtils.getRandom();
            weightMatrix.assign(new DoubleFunction() {
                @Override
                public double apply(double ignored) {
                    return random.nextDouble() - 0.5d;
                }
            });
            weightMatrixList.add(weightMatrix);
            prevWeightUpdatesList.add(new DenseMatrix(rows, prevLayerSize));
            squashingFunctionList.add(squashingFunctionName);
        }
        return layerIndex;
    }

    /**
     * Returns the stored size of a layer (including the bias neuron for non-final layers).
     *
     * @param layer layer index in [0, number of layers)
     * @throws IllegalArgumentException if {@code layer} is out of range
     */
    public int getLayerSize(int layer) {
        Preconditions.checkArgument(layer >= 0 && layer < layerSizeList.size(),
            String.format("Input must be in range [0, %d]\n", layerSizeList.size() - 1));
        return layerSizeList.get(layer);
    }

    /** Returns the internal (live, mutable) list of layer sizes. */
    public List<Integer> getLayerSizeList() {
        return layerSizeList;
    }

    /**
     * Returns the weight matrix connecting layer {@code layer} to layer {@code layer + 1}.
     */
    public Matrix getWeightsByLayer(int layer) {
        return weightMatrixList.get(layer);
    }

    /**
     * Adds each given update to the corresponding weight matrix (via {@code Matrix.plus}), replacing
     * the stored matrix with the sum.
     *
     * @param updates one update per weight matrix, in layer order
     */
    public void updateWeightMatrices(Matrix[] updates) {
        for (int i = 0; i < updates.length; i++) {
            weightMatrixList.set(i, weightMatrixList.get(i).plus(updates[i]));
        }
    }

    /**
     * Replaces all weight matrices.
     *
     * <p>NOTE(review): {@code prevWeightUpdatesList} is not adjusted; matrices with different shapes
     * than the current ones will not line up with the stored momentum updates.
     */
    public void setWeightMatrices(Matrix[] matrices) {
        weightMatrixList = Lists.newArrayList();
        Collections.addAll(weightMatrixList, matrices);
    }

    /**
     * Replaces a single weight matrix.
     *
     * @param index matrix index in [0, number of weight matrices)
     * @throws IllegalArgumentException if {@code index} is out of range
     */
    public void setWeightMatrix(int index, Matrix matrix) {
        Preconditions.checkArgument(0 <= index && index < weightMatrixList.size(),
            String.format("index [%s] should be in range [%s, %s).", index, 0, weightMatrixList.size()));
        weightMatrixList.set(index, matrix);
    }

    /** Returns a new array holding the (shared, not copied) weight matrices. */
    public Matrix[] getWeightMatrices() {
        return weightMatrixList.toArray(new Matrix[weightMatrixList.size()]);
    }

    /**
     * Computes the network output for an input instance.
     *
     * @param instance input features, without bias; size must equal the first layer size minus 1
     * @return a view of the final layer's output with the leading bias entry removed
     * @throws IllegalArgumentException if the instance has the wrong dimension
     */
    public Vector getOutput(Vector instance) {
        int inputLayerSize = layerSizeList.get(0);
        Preconditions.checkArgument(inputLayerSize == instance.size() + 1,
            String.format("The dimension of input instance should be %d, but the input has dimension %d.",
                inputLayerSize - 1, instance.size()));
        Vector instanceWithBias = new DenseVector(instance.size() + 1);
        // NOTE(review): the input bias here is 0.99999, while training and forward() use 1.0; kept as-is.
        instanceWithBias.set(0, 0.99999d);
        for (int i = 1; i < instanceWithBias.size(); i++) {
            instanceWithBias.set(i, instance.get(i - 1));
        }
        List<Vector> outputCache = getOutputInternal(instanceWithBias);
        Vector result = outputCache.get(outputCache.size() - 1);
        // Drop the bias entry that forward() prepends to every layer output.
        return result.viewPart(1, result.size() - 1);
    }

    /**
     * Runs the forward pass and keeps the output of every layer.
     *
     * @param instance input vector, bias entry already prepended
     * @return list whose element {@code i} is the output of layer {@code i} (element 0 is the input)
     */
    protected List<Vector> getOutputInternal(Vector instance) {
        List<Vector> outputCache = Lists.newArrayList();
        Vector layerOutput = instance;
        outputCache.add(layerOutput);
        for (int layer = 0; layer < layerSizeList.size() - 1; layer++) {
            layerOutput = forward(layer, layerOutput);
            outputCache.add(layerOutput);
        }
        return outputCache;
    }

    /**
     * Propagates one layer's output through weight matrix {@code fromLayer} and its squashing function.
     *
     * @param fromLayer index of the source layer
     * @param intermediateOutput output of the source layer (bias included)
     * @return the next layer's output, prefixed with a bias entry of 1.0
     */
    protected Vector forward(int fromLayer, Vector intermediateOutput) {
        Vector activation = weightMatrixList.get(fromLayer).times(intermediateOutput)
            .assign(NeuralNetworkFunctions.getDoubleFunction(squashingFunctionList.get(fromLayer)));
        Vector withBias = new DenseVector(activation.size() + 1);
        withBias.set(0, 1.0d);
        for (int i = 0; i < activation.size(); i++) {
            withBias.set(i + 1, activation.get(i));
        }
        return withBias;
    }

    /**
     * Trains on a single instance and immediately applies the resulting weight updates.
     *
     * @param trainingInstance input features followed by the expected outputs
     */
    public void trainOnline(Vector trainingInstance) {
        updateWeightMatrices(trainByInstance(trainingInstance));
    }

    /**
     * Computes (but does not apply) the weight updates for a single training instance. The updates
     * are also remembered as the previous updates for the momentum term.
     *
     * @param trainingInstance input features followed by the expected outputs
     * @return one update matrix per weight matrix
     * @throws IllegalArgumentException if the instance has the wrong dimension or the training
     *     method is not supported
     */
    public Matrix[] trainByInstance(Vector trainingInstance) {
        int inputDimension = layerSizeList.get(0) - 1;
        int outputDimension = layerSizeList.get(layerSizeList.size() - 1);
        Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
            String.format("The dimension of training instance is %d, but requires %d.",
                trainingInstance.size(), inputDimension + outputDimension));
        if (trainingMethod == TrainingMethod.GRADIENT_DESCENT) {
            return trainByInstanceGradientDescent(trainingInstance);
        }
        throw new IllegalArgumentException("Training method is not supported.");
    }

    /** Gradient-descent backpropagation for one instance; see {@link #trainByInstance(Vector)}. */
    private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
        int inputLayerSize = layerSizeList.get(0);
        int inputDimension = inputLayerSize - 1;
        // Input features prefixed with a bias of 1.0.
        Vector input = new DenseVector(inputLayerSize);
        input.set(0, 1.0d);
        for (int i = 0; i < inputDimension; i++) {
            input.set(i + 1, trainingInstance.get(i));
        }
        // The remaining entries of the instance are the expected outputs.
        Vector labels = trainingInstance.viewPart(inputDimension, trainingInstance.size() - inputDimension);

        Matrix[] weightUpdates = new Matrix[weightMatrixList.size()];
        for (int i = 0; i < weightUpdates.length; i++) {
            Matrix weights = weightMatrixList.get(i);
            weightUpdates[i] = new DenseMatrix(weights.rowSize(), weights.columnSize());
        }

        List<Vector> outputCache = getOutputInternal(input);

        // Delta of the output layer.
        int outputLayerSize = layerSizeList.get(layerSizeList.size() - 1);
        Vector delta = new DenseVector(outputLayerSize);
        Vector output = outputCache.get(outputCache.size() - 1);
        DoubleFunction squashingDerivative = NeuralNetworkFunctions.getDerivativeDoubleFunction(
            squashingFunctionList.get(squashingFunctionList.size() - 1));
        DoubleDoubleFunction costDerivative =
            NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName);
        Matrix lastWeights = weightMatrixList.get(weightMatrixList.size() - 1);
        for (int i = 0; i < delta.size(); i++) {
            // output.get(i + 1): entry 0 of every cached layer output is the bias.
            // NOTE(review): the regularization term uses the row sum of the last weight matrix.
            double costGradient = costDerivative.apply(labels.get(i), output.get(i + 1))
                + regularizationWeight * lastWeights.viewRow(i).zSum();
            delta.set(i, costGradient * squashingDerivative.apply(output.get(i + 1)));
        }

        for (int layer = layerSizeList.size() - 2; layer >= 0; layer--) {
            delta = backPropagate(layer, delta, outputCache, weightUpdates[layer]);
        }
        // Keep a mutable list that is independent of the array handed back to the caller, so later
        // addLayer() calls (which append to this list) keep working.
        prevWeightUpdatesList = Lists.newArrayList(weightUpdates);
        return weightUpdates;
    }

    /**
     * Back-propagates the delta of layer {@code curLayerIndex + 1} to layer {@code curLayerIndex} and
     * fills {@code weightUpdateMatrix} with the update for weight matrix {@code curLayerIndex}
     * (learning-rate step plus momentum from the previous update).
     *
     * @return the delta of layer {@code curLayerIndex}, including an entry for its bias neuron
     */
    private Vector backPropagate(int curLayerIndex, Vector nextLayerDelta, List<Vector> outputCache,
                                 Matrix weightUpdateMatrix) {
        // NOTE(review): this uses the squashing function of weight matrix curLayerIndex, i.e. the one
        // producing layer curLayerIndex + 1, not the one producing curLayerOutput; kept as-is.
        final DoubleFunction squashingDerivative =
            NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(curLayerIndex));
        Vector curLayerOutput = outputCache.get(curLayerIndex);
        Matrix weightMatrix = weightMatrixList.get(curLayerIndex);
        Matrix prevWeightUpdates = prevWeightUpdatesList.get(curLayerIndex);

        // Unless the next layer is the output layer, its delta has a leading entry for the bias neuron,
        // which has no incoming weights.
        if (curLayerIndex != layerSizeList.size() - 2) {
            nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
        }

        Vector delta = weightMatrix.transpose().times(nextLayerDelta).assign(curLayerOutput,
            new DoubleDoubleFunction() {
                @Override
                public double apply(double deltaElem, double outputElem) {
                    return deltaElem * squashingDerivative.apply(outputElem);
                }
            });

        for (int row = 0; row < weightUpdateMatrix.rowSize(); row++) {
            for (int col = 0; col < weightUpdateMatrix.columnSize(); col++) {
                weightUpdateMatrix.set(row, col,
                    -learningRate * nextLayerDelta.get(row) * curLayerOutput.get(col)
                        + momentumWeight * prevWeightUpdates.get(row, col));
            }
        }
        return delta;
    }

    /**
     * Loads the model stored at {@link #getModelPath()} via {@link #readFields(DataInput)}.
     *
     * @throws IllegalArgumentException if the model path has not been set
     * @throws IOException if the model cannot be read
     */
    protected void readFromModel() throws IOException {
        log.info("Load model from {}", modelPath);
        Preconditions.checkArgument(modelPath != null, "Model path has not been set.");
        FSDataInputStream is = null;
        try {
            Path path = new Path(modelPath);
            is = path.getFileSystem(new Configuration()).open(path);
            readFields(is);
        } finally {
            // Failures while closing an input stream are not significant; swallow them.
            Closeables.close(is, true);
        }
    }

    /**
     * Writes the model to {@link #getModelPath()}, overwriting any existing file.
     *
     * @throws IllegalArgumentException if the model path has not been set
     * @throws IOException if writing or closing the stream fails
     */
    public void writeModelToFile() throws IOException {
        log.info("Write model to {}.", modelPath);
        Preconditions.checkArgument(modelPath != null, "Model path has not been set.");
        FSDataOutputStream os = null;
        boolean threw = true;
        try {
            Path path = new Path(modelPath);
            os = path.getFileSystem(new Configuration()).create(path, true);
            write(os);
            threw = false;
        } finally {
            // Swallow a close() failure only when writing already failed, so the original exception
            // is not masked; on success a failed close (e.g. final flush) is still reported.
            Closeables.close(os, threw);
        }
    }

    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    public String getModelPath() {
        return modelPath;
    }

    /**
     * Serializes the model: type, hyper-parameters, model path (written as {@code "null"} when
     * unset), cost function, layer sizes, training method, squashing functions and weight matrices.
     */
    public void write(DataOutput output) throws IOException {
        WritableUtils.writeString(output, modelType);
        output.writeDouble(learningRate);
        WritableUtils.writeString(output, modelPath != null ? modelPath : "null");
        output.writeDouble(regularizationWeight);
        output.writeDouble(momentumWeight);
        WritableUtils.writeString(output, costFunctionName);
        output.writeInt(layerSizeList.size());
        for (int layerSize : layerSizeList) {
            output.writeInt(layerSize);
        }
        WritableUtils.writeEnum(output, trainingMethod);
        output.writeInt(squashingFunctionList.size());
        for (String squashingFunction : squashingFunctionList) {
            WritableUtils.writeString(output, squashingFunction);
        }
        output.writeInt(weightMatrixList.size());
        for (Matrix weightMatrix : weightMatrixList) {
            MatrixWritable.writeMatrix(output, weightMatrix);
        }
    }

    /**
     * Reads a model written by {@link #write(DataOutput)}, replacing all state. Previous weight
     * updates are reset to zero matrices.
     *
     * @throws IllegalArgumentException if the stored model type is not this class's simple name
     */
    public void readFields(DataInput input) throws IOException {
        modelType = WritableUtils.readString(input);
        if (!getClass().getSimpleName().equals(modelType)) {
            throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
        }
        learningRate = input.readDouble();
        modelPath = WritableUtils.readString(input);
        // write() stores an unset path as the literal "null".
        if ("null".equals(modelPath)) {
            modelPath = null;
        }
        regularizationWeight = input.readDouble();
        momentumWeight = input.readDouble();
        costFunctionName = WritableUtils.readString(input);

        int numLayers = input.readInt();
        layerSizeList = Lists.newArrayList();
        for (int i = 0; i < numLayers; i++) {
            layerSizeList.add(input.readInt());
        }
        // finalLayerIndex is not serialized; the last stored layer is the output layer.
        finalLayerIndex = Math.max(0, layerSizeList.size() - 1);

        trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);

        int numSquashingFunctions = input.readInt();
        squashingFunctionList = Lists.newArrayList();
        for (int i = 0; i < numSquashingFunctions; i++) {
            squashingFunctionList.add(WritableUtils.readString(input));
        }

        int numWeightMatrices = input.readInt();
        weightMatrixList = Lists.newArrayList();
        prevWeightUpdatesList = Lists.newArrayList();
        for (int i = 0; i < numWeightMatrices; i++) {
            Matrix matrix = MatrixWritable.readMatrix(input);
            weightMatrixList.add(matrix);
            prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
        }
    }
}
