package org.apache.mahout.classifier.mlp;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.shell.Ls;
import org.apache.hadoop.util.StringUtils;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Arrays;
import org.apache.mahout.math.DenseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.class */
public final class TrainMultilayerPerceptron {
    private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron$Parameters.class */
    /**
     * Mutable holder for the settings read from the command line.
     *
     * <p>An instance is created empty by {@code main}, filled in by
     * {@code parseArgs}, and then read while building and training the model.
     */
    public static class Parameters {
        // --- training hyper-parameters ---
        double learningRate;
        double momemtumWeight;
        double regularizationWeight;

        // --- training data ---
        String inputFilePath;
        /** When true, the first line of the input file is read and discarded. */
        boolean skipHeader;

        // --- model storage ---
        String modelFilePath;
        /** When true and a model already exists at {@link #modelFilePath}, that model is loaded instead of rebuilt. */
        boolean updateModel;

        // --- network topology ---
        String squashingFunctionName;
        /** Label name mapped to its position in the order the labels were supplied. */
        Map<String, Integer> labelsIndex;
        /** Size of every layer, in the order given on the command line. */
        List<Integer> layerSizeList;

        Parameters() {
            this.labelsIndex = Maps.newHashMap();
            this.layerSizeList = Lists.newArrayList();
        }
    }

    /**
     * Command-line entry point: parses the arguments, builds (or reloads) a
     * multilayer perceptron, trains it online on a comma-separated dataset and
     * writes the trained model to the configured model path.
     *
     * <p>Every data row holds the numeric features followed by the label name in
     * the last column. The label is expanded into a one-hot block appended to the
     * features before the row is handed to the model.
     *
     * @param strArr raw command-line arguments
     * @throws IllegalArgumentException if the training dataset does not exist or a
     *         row carries a label that was not declared on the command line
     * @throws Exception if argument parsing, file-system access or model I/O fails
     */
    public static void main(String[] strArr) throws Exception {
        Parameters parameters = new Parameters();
        if (!parseArgs(strArr, parameters)) {
            return;
        }

        // Check the training data first: the model setup below may delete an
        // existing model, which must not happen when there is nothing to train on.
        Path trainingDataPath = new Path(parameters.inputFilePath);
        FileSystem trainingDataFs = trainingDataPath.getFileSystem(new Configuration());
        Preconditions.checkArgument(trainingDataFs.exists(trainingDataPath), "Training dataset %s cannot be found!", parameters.inputFilePath);

        log.info("Validate model...");
        Path modelPath = new Path(parameters.modelFilePath);
        FileSystem modelFs = modelPath.getFileSystem(new Configuration());
        boolean modelExists = modelFs.exists(modelPath);
        MultilayerPerceptron mlp;
        if (modelExists && parameters.updateModel) {
            log.info("Build model from existing model...");
            mlp = new MultilayerPerceptron(parameters.modelFilePath);
        } else {
            if (modelExists) {
                // Not updating: the stale model is replaced by a fresh one.
                modelFs.delete(modelPath, true);
            }
            log.info("Build model from scratch...");
            mlp = new MultilayerPerceptron();
            int lastLayer = parameters.layerSizeList.size() - 1;
            for (int layer = 0; layer <= lastLayer; layer++) {
                // Only the last layer is flagged as the final (output) layer.
                mlp.addLayer(parameters.layerSizeList.get(layer).intValue(), layer == lastLayer, parameters.squashingFunctionName);
            }
            mlp.setCostFunction("Minus_Squared");
            mlp.setModelPath(parameters.modelFilePath);
        }
        // Applied to both fresh and reloaded models so the command-line values always win.
        mlp.setLearningRate(parameters.learningRate).setMomentumWeight(parameters.momemtumWeight).setRegularizationWeight(parameters.regularizationWeight);

        log.info("Read data and train model...");
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(trainingDataFs.open(trainingDataPath), StandardCharsets.UTF_8));
            int lineNumber = 0;
            if (parameters.skipHeader) {
                reader.readLine();
                lineNumber++;
            }
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    // Blank lines (e.g. a trailing newline) carry no instance.
                    continue;
                }
                double[] instance = encodeInstance(line, parameters.labelsIndex, lineNumber);
                mlp.trainOnline(new DenseVector(instance).viewPart(0, instance.length));
            }
            log.info("Write trained model to {}", parameters.modelFilePath);
            mlp.writeModelToFile();
            mlp.close();
        } finally {
            // Best-effort close; an IOException from close() is swallowed.
            Closeables.close(reader, true);
        }
    }

    /**
     * Turns one comma-separated data row into the numeric vector fed to the model:
     * the parsed feature columns followed by a one-hot encoding of the label.
     *
     * @param line        the raw row; the last column is the label name
     * @param labelsIndex label name mapped to its position in the one-hot block
     * @param lineNumber  1-based line number, used only for error reporting
     * @return array of length {@code featureCount + labelsIndex.size()}
     * @throws IllegalArgumentException if the label is not present in {@code labelsIndex}
     * @throws NumberFormatException if a feature column is not a valid double
     */
    private static double[] encodeInstance(String line, Map<String, Integer> labelsIndex, int lineNumber) {
        String[] tokens = line.split(",");
        int featureCount = tokens.length - 1;
        String label = tokens[featureCount];
        Integer labelPosition = labelsIndex.get(label);
        Preconditions.checkArgument(labelPosition != null, "Unknown label '%s' at line %s of the training dataset", label, lineNumber);

        // New arrays are zero-filled, so only the features and the hot slot need writing.
        double[] instance = new double[featureCount + labelsIndex.size()];
        for (int i = 0; i < featureCount; i++) {
            instance[i] = Double.parseDouble(tokens[i]);
        }
        instance[featureCount + labelPosition.intValue()] = 1.0d;
        return instance;
    }

    /**
     * Declares the command-line options, parses {@code strArr} and copies the
     * resulting values into {@code parameters}.
     *
     * <p>Option names are plain literals here; they previously came from
     * same-valued constants of unrelated classes (e.g. {@code WikipediaTokenizer}),
     * which only added compile-time coupling.
     *
     * @param strArr     raw command-line arguments
     * @param parameters holder that receives the parsed values
     * @return {@code false} when the parser produced no command line (in which
     *         case {@code parameters} is left untouched), {@code true} otherwise
     * @throws IllegalArgumentException if the same label is listed more than once
     * @throws Exception declared for callers; the signature is kept as-is
     */
    private static boolean parseArgs(String[] strArr, Parameters parameters) throws Exception {
        log.info("Validate and parse arguments...");
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        DefaultOption skipHeaderOption = optionBuilder
                .withLongName("skipHeader")
                .withShortName("sh")
                .create();
        DefaultOption inputOption = optionBuilder
                .withLongName("input")
                .withShortName("i")
                .withRequired(true)
                .withChildren(groupBuilder.withOption(skipHeaderOption).create())
                .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1).create())
                .withDescription("the file path of training dataset")
                .create();
        DefaultOption labelsOption = optionBuilder
                .withLongName("labels")
                .withShortName("labels")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
                .withDescription("label names")
                .create();
        DefaultOption updateOption = optionBuilder
                .withLongName("update")
                .withShortName("u")
                .withDescription("whether to incrementally update model if the model exists")
                .create();
        DefaultOption modelOption = optionBuilder
                .withLongName("model")
                .withShortName("mo")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
                .withDescription("the path to store the trained model")
                .withChildren(groupBuilder.withOption(updateOption).create())
                .create();
        DefaultOption layerSizeOption = optionBuilder
                .withLongName("layerSize")
                .withShortName("ls")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
                .withDescription("the size of each layer")
                .create();
        DefaultOption squashingFunctionOption = optionBuilder
                .withLongName("squashingFunction")
                .withShortName("sf")
                .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1).withDefault("Sigmoid").create())
                .withDescription("the name of squashing function (currently only supports Sigmoid)")
                .create();
        DefaultOption learningRateOption = optionBuilder
                .withLongName("learningRate")
                .withShortName("l")
                .withArgument(argumentBuilder.withName("learning rate").withMaximum(1).withMinimum(1).withDefault(Double.valueOf(0.5d)).create())
                .withDescription("learning rate")
                .create();
        DefaultOption momentumOption = optionBuilder
                .withLongName("momemtumWeight")
                .withShortName("m")
                .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1).withMinimum(1).withDefault(Double.valueOf(0.1d)).create())
                .withDescription("momemtum weight")
                .create();
        DefaultOption regularizationOption = optionBuilder
                .withLongName("regularizationWeight")
                .withShortName("r")
                .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1).withMinimum(1).withDefault(Double.valueOf(0.0d)).create())
                .withDescription("regularization weight")
                .create();

        Parser parser = new Parser();
        parser.setGroup(groupBuilder
                .withOption(inputOption)
                .withOption(skipHeaderOption)
                .withOption(updateOption)
                .withOption(labelsOption)
                .withOption(modelOption)
                .withOption(layerSizeOption)
                .withOption(squashingFunctionOption)
                .withOption(learningRateOption)
                .withOption(momentumOption)
                .withOption(regularizationOption)
                .create());
        CommandLine cmdLine = parser.parseAndHelp(strArr);
        if (cmdLine == null) {
            return false;
        }

        parameters.learningRate = getDouble(cmdLine, learningRateOption).doubleValue();
        parameters.momemtumWeight = getDouble(cmdLine, momentumOption).doubleValue();
        parameters.regularizationWeight = getDouble(cmdLine, regularizationOption).doubleValue();
        parameters.inputFilePath = getString(cmdLine, inputOption);
        parameters.skipHeader = cmdLine.hasOption(skipHeaderOption);

        // Each label gets the position at which it was listed. A repeated label
        // would leave a gap in the positions and make the map smaller than the
        // largest position, so it is rejected here rather than failing mid-training.
        int nextPosition = 0;
        for (String label : getStringList(cmdLine, labelsOption)) {
            Preconditions.checkArgument(!parameters.labelsIndex.containsKey(label), "Duplicate label: %s", label);
            parameters.labelsIndex.put(label, Integer.valueOf(nextPosition));
            nextPosition++;
        }

        parameters.modelFilePath = getString(cmdLine, modelOption);
        parameters.updateModel = cmdLine.hasOption(updateOption);
        parameters.layerSizeList = getIntegerList(cmdLine, layerSizeOption);
        parameters.squashingFunctionName = getString(cmdLine, squashingFunctionOption);

        System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f, Momemtum weight: %f, Regularization Weight: %f\n", parameters.inputFilePath, parameters.modelFilePath, Boolean.valueOf(parameters.updateModel), Arrays.toString(parameters.layerSizeList.toArray()), parameters.squashingFunctionName, Double.valueOf(parameters.learningRate), Double.valueOf(parameters.momemtumWeight), Double.valueOf(parameters.regularizationWeight));
        return true;
    }

    /**
     * Reads the value of {@code option} and parses its string form as a double.
     *
     * @param commandLine the parsed command line
     * @param option      the option whose value is wanted
     * @return the parsed value, or {@code null} when the command line holds no value for the option
     * @throws NumberFormatException if the value's string form is not a valid double
     */
    static Double getDouble(CommandLine commandLine, Option option) {
        Object raw = commandLine.getValue(option);
        if (raw == null) {
            return null;
        }
        String text = raw.toString();
        return Double.valueOf(Double.parseDouble(text));
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Reads the value of {@code option} as a string.
     *
     * @param commandLine the parsed command line
     * @param option      the option whose value is wanted
     * @return the value's {@code toString()} form, or {@code null} when the
     *         command line holds no value for the option
     */
    public static String getString(CommandLine commandLine, Option option) {
        Object raw = commandLine.getValue(option);
        return raw == null ? null : raw.toString();
    }

    /**
     * Reads all values of {@code option} and parses each one as an integer,
     * preserving their order.
     *
     * @param commandLine the parsed command line
     * @param option      the option whose values are wanted
     * @return a new list with one parsed integer per value
     * @throws NumberFormatException if a value is not a valid integer
     * @throws ClassCastException if a value is not a {@code String}
     */
    static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
        List<Integer> parsed = Lists.newArrayList();
        for (Object raw : commandLine.getValues(option)) {
            String text = (String) raw;
            parsed.add(Integer.valueOf(Integer.parseInt(text)));
        }
        return parsed;
    }

    /**
     * Returns all values of {@code option} exactly as the command line holds them.
     *
     * <p>The underlying list is untyped; it is handed back as a list of strings
     * without copying or checking the elements.
     *
     * @param commandLine the parsed command line
     * @param option      the option whose values are wanted
     * @return the list returned by {@code commandLine.getValues(option)}
     */
    @SuppressWarnings("unchecked")
    static List<String> getStringList(CommandLine commandLine, Option option) {
        List<String> values = commandLine.getValues(option);
        return values;
    }
}
