/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.pca.PCAMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;

public class PredictCsv {
    // Path of the CSV file to score; set from the --input argument.
    private String inputCSVFileName;
    // Path of the CSV file the predictions are written to; set from the --output argument.
    private String outputCSVFileName;
    // When true (--decimal), doubles are printed with Double.toString, otherwise with Double.toHexString.
    private boolean useDecimalOutput = false;
    // Separator handed to the CSVReader for the input file; 44 is ',' and --separator overrides it.
    public char separator = (char)44;
    // Value passed to Config.setConvertInvalidNumbersToNa when a model is loaded (--setConvertInvalidNum).
    public boolean setInvNumNA = false;
    // When true (--leafNodeAssignment), leaf assignment is enabled on the wrapper and tree paths are
    // written instead of predictions for Binomial, Multinomial and Regression models.
    public boolean getTreePath = false;
    // When true (--glrmReconstruct), DimReduction output is the reconstructed data rather than the dimensions.
    boolean returnGLRMReconstruct = false;
    // Value of --glrmIterNumber; only applied by loadMojo, and only when it is greater than zero.
    public int glrmIterNumber = -1;
    // Prediction wrapper around the loaded model; created by loadPojo or loadMojo.
    private EasyPredictModelWrapper model;

    /**
     * Command-line entry point: parses the arguments (which also loads the model), scores the
     * input file, and terminates the JVM.
     *
     * <p>The process exits with status 0 when {@code run()} returns normally and with status 2
     * when it throws; in the latter case the stack trace is printed first.
     *
     * @param args command-line arguments, see {@code usage()}
     */
    public static void main(String[] args) {
        PredictCsv predictor = new PredictCsv();
        predictor.parseArgs(args);
        int exitStatus = 0;
        try {
            predictor.run();
        }
        catch (Exception failure) {
            failure.printStackTrace();
            exitStatus = 2;
        }
        System.exit(exitStatus);
    }

    /**
     * Converts one parsed CSV line into a {@link RowData} keyed by column name.
     *
     * <p>Only the cells that have a matching column name are considered (the shorter of the two
     * arrays bounds the loop). Cells equal to {@code ""}, {@code "NA"}, {@code "N/A"} or
     * {@code "-"} are left out of the row.
     *
     * @param splitLine        the cells of the current line
     * @param inputColumnNames the column names read from the first line of the input file
     * @return the row holding every non-skipped cell under its column name
     */
    private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
        RowData row = new RowData();
        int cellCount = Math.min(inputColumnNames.length, splitLine.length);
        for (int col = 0; col < cellCount; ++col) {
            String cell = splitLine[col];
            boolean skipped = cell.equals("") || cell.equals("NA") || cell.equals("N/A") || cell.equals("-");
            if (!skipped) {
                row.put(inputColumnNames[col], cell);
            }
        }
        return row;
    }

    /**
     * Formats a double for the output file.
     *
     * @param d the value to format
     * @return {@code "NA"} for NaN; otherwise the decimal form when decimal output is enabled,
     *         or the hexadecimal floating-point form when it is not
     */
    private String myDoubleToString(double d) {
        if (Double.isNaN(d)) {
            return "NA";
        }
        if (this.useDecimalOutput) {
            return Double.toString(d);
        }
        return Double.toHexString(d);
    }

    /**
     * Writes the decision path names of the loaded tree model as a comma-separated header
     * (no trailing comma, no line terminator).
     *
     * <p>The wrapped model is cast to {@link SharedTreeMojoModel}, so this is only valid for
     * models of that type.
     *
     * @param output the writer receiving the header cells
     * @throws Exception if writing fails
     */
    private void writeTreePathNames(BufferedWriter output) throws Exception {
        SharedTreeMojoModel treeModel = (SharedTreeMojoModel)this.model.m;
        String[] pathNames = treeModel.getDecisionPathNames();
        int last = pathNames.length - 1;
        int position = 0;
        // Every name but the last is followed by a comma.
        while (position < last) {
            output.write(pathNames[position]);
            output.write(",");
            ++position;
        }
        output.write(pathNames[last]);
    }

    /**
     * Scores the input CSV file and writes one prediction line per data row to the output CSV file.
     *
     * <p>A header line matching the model category is written first. The first line of the input
     * file is taken as the column names and checked against the model's columns; every following
     * line is converted to a {@link RowData}, scored and written out.
     *
     * <p>Both files are closed on every path that returns or throws from this method. A failure
     * while reading or scoring rows does not propagate: it is reported on standard output and the
     * JVM exits with status 1 (see {@code scoreRows}).
     *
     * @throws Exception if the model category is not supported, or if opening the files, writing
     *                   the header or closing the files fails
     */
    private void run() throws Exception {
        ModelCategory category = this.model.getModelCategory();
        CSVReader reader = new CSVReader((Reader)new FileReader(this.inputCSVFileName), this.separator);
        try {
            BufferedWriter output = new BufferedWriter(new FileWriter(this.outputCSVFileName));
            try {
                int lastCommaAutoEn = this.writeHeader(category, output);
                output.write("\n");
                this.scoreRows(category, reader, output, lastCommaAutoEn);
            }
            finally {
                output.close();
            }
        }
        finally {
            reader.close();
        }
    }

    /**
     * Writes the header cells for the given model category (without the line terminator).
     *
     * @return for AutoEncoder models the zero-based index of the last header column, which the
     *         row writer uses to decide where commas go; {@code -1} for every other category
     * @throws Exception if the category is not supported or writing fails
     */
    private int writeHeader(ModelCategory category, BufferedWriter output) throws Exception {
        switch (category) {
            case AutoEncoder: {
                return this.writeAutoEncoderHeader(output);
            }
            case Binomial:
            case Multinomial: {
                if (this.getTreePath) {
                    this.writeTreePathNames(output);
                } else {
                    this.writeClassProbabilityHeader(output);
                }
                break;
            }
            case Ordinal: {
                this.writeClassProbabilityHeader(output);
                break;
            }
            case Clustering: {
                output.write("cluster");
                break;
            }
            case Regression: {
                if (this.getTreePath) {
                    this.writeTreePathNames(output);
                } else {
                    output.write("predict");
                }
                break;
            }
            case DimReduction: {
                this.writeDimReductionHeader(output);
                break;
            }
            default: {
                throw new Exception("Unknown model category " + category);
            }
        }
        return -1;
    }

    /**
     * Writes the AutoEncoder header: for each categorical column one cell per domain value plus a
     * ".missing(NA)" cell, followed by one cell per remaining column name.
     *
     * @return the zero-based index of the last column written
     */
    private int writeAutoEncoderHeader(BufferedWriter output) throws IOException {
        String[] cnames = this.model.m.getNames();
        int numCats = this.model.domainMap.size();
        int numNums = this.model.m.nfeatures() - numCats;
        String[][] domainValues = this.model.m.getDomainValues();
        int lastCatIdx = numCats - 1;
        int lastColumn = -1;
        for (int cat = 0; cat <= lastCatIdx; ++cat) {
            for (String level : domainValues[cat]) {
                ++lastColumn;
                output.write("reconstr_" + level);
                output.write(',');
            }
            ++lastColumn;
            output.write("reconstr_" + cnames[cat] + ".missing(NA)");
            // No comma after the very last cell: that is the last categorical column when there
            // are no numeric columns to follow.
            if (numNums > 0 || cat < lastCatIdx) {
                output.write(',');
            }
        }
        int lastName = cnames.length - 1;
        for (int num = numCats; num < cnames.length; ++num) {
            ++lastColumn;
            output.write("reconstr_" + cnames[num]);
            if (num < lastName) {
                output.write(',');
            }
        }
        return lastColumn;
    }

    /** Writes "predict" followed by one cell per response domain value. */
    private void writeClassProbabilityHeader(BufferedWriter output) throws IOException {
        output.write("predict");
        for (String domainValue : this.model.getResponseDomainValues()) {
            output.write(",");
            output.write(domainValue);
        }
    }

    /**
     * Writes the DimReduction header: "reconstr_" + column name when GLRM reconstruction is
     * requested, otherwise "Arch1..n" for GLRM models and "PC1..n" for PCA models.
     */
    private void writeDimReductionHeader(BufferedWriter output) throws IOException {
        String[] colnames = this.model.m.getNames();
        String head;
        int datawidth;
        if (this.returnGLRMReconstruct) {
            datawidth = ((GlrmMojoModel)this.model.m)._permutation.length;
            head = "reconstr_";
        } else if (this.model.m instanceof GlrmMojoModel) {
            datawidth = ((GlrmMojoModel)this.model.m)._ncolX;
            head = "Arch";
        } else {
            datawidth = ((PCAMojoModel)this.model.m)._k;
            head = "PC";
        }
        int lastData = datawidth - 1;
        for (int index = 0; index < datawidth; ++index) {
            output.write(this.returnGLRMReconstruct ? head + colnames[index] : head + (index + 1));
            if (index < lastData) {
                output.write(',');
            }
        }
    }

    /**
     * Reads the column-name line, then scores and writes every remaining input line.
     *
     * <p>Any exception is reported on standard output together with a line counter (1 plus the
     * number of rows written so far) and terminates the JVM with status 1.
     */
    private void scoreRows(ModelCategory category, CSVReader reader, BufferedWriter output, int lastCommaAutoEn) {
        int lineNum = 1;
        try {
            String[] inputColumnNames = reader.readNext();
            if (inputColumnNames == null) {
                throw new Exception("Input dataset file is empty!");
            }
            this.checkMissingColumns(inputColumnNames);
            String[] splitLine;
            while ((splitLine = reader.readNext()) != null) {
                RowData row = PredictCsv.formatDataRow(splitLine, inputColumnNames);
                this.writePrediction(category, row, output, lastCommaAutoEn);
                output.write("\n");
                ++lineNum;
            }
        }
        catch (Exception e) {
            System.out.println("Caught exception on line " + lineNum);
            System.out.println("");
            e.printStackTrace();
            // System.exit never returns, so no finally block runs after it. Flush explicitly so
            // the rows already scored are not lost in the writer's buffer.
            try {
                output.flush();
            }
            catch (IOException ignored) {
                // Already exiting with a failure status; nothing more useful can be done here.
            }
            System.exit(1);
        }
    }

    /**
     * Scores one row and writes its prediction cells (without the line terminator).
     *
     * @param lastCommaAutoEn index of the last AutoEncoder column as returned by {@code writeHeader};
     *                        only used for AutoEncoder models
     * @throws Exception if scoring or writing fails, or the category is not supported
     */
    private void writePrediction(ModelCategory category, RowData row, BufferedWriter output, int lastCommaAutoEn) throws Exception {
        switch (category) {
            case AutoEncoder: {
                double[] reconstructed = this.model.predictAutoEncoder(row).reconstructed;
                for (int i = 0; i < reconstructed.length; ++i) {
                    output.write(this.myDoubleToString(reconstructed[i]));
                    if (i < lastCommaAutoEn) {
                        output.write(',');
                    }
                }
                break;
            }
            case Binomial: {
                BinomialModelPrediction p = this.model.predictBinomial(row);
                if (this.getTreePath) {
                    this.writeTreePaths(p.leafNodeAssignments, output);
                } else {
                    this.writeLabelAndProbabilities(p.label, p.classProbabilities, output);
                }
                break;
            }
            case Multinomial: {
                MultinomialModelPrediction p = this.model.predictMultinomial(row);
                if (this.getTreePath) {
                    this.writeTreePaths(p.leafNodeAssignments, output);
                } else {
                    this.writeLabelAndProbabilities(p.label, p.classProbabilities, output);
                }
                break;
            }
            case Ordinal: {
                OrdinalModelPrediction p = this.model.predictOrdinal(row);
                this.writeLabelAndProbabilities(p.label, p.classProbabilities, output);
                break;
            }
            case Clustering: {
                ClusteringModelPrediction p = this.model.predictClustering(row);
                output.write(this.myDoubleToString(p.cluster));
                break;
            }
            case Regression: {
                RegressionModelPrediction p = this.model.predictRegression(row);
                if (this.getTreePath) {
                    this.writeTreePaths(p.leafNodeAssignments, output);
                } else {
                    output.write(this.myDoubleToString(p.value));
                }
                break;
            }
            case DimReduction: {
                DimReductionModelPrediction p = this.model.predictDimReduction(row);
                double[] out = this.returnGLRMReconstruct ? p.reconstructed : p.dimensions;
                int lastOne = out.length - 1;
                for (int i = 0; i < out.length; ++i) {
                    output.write(this.myDoubleToString(out[i]));
                    if (i < lastOne) {
                        output.write(',');
                    }
                }
                break;
            }
            default: {
                throw new Exception("Unknown model category " + category);
            }
        }
    }

    /** Writes the predicted label, a comma, and then the comma-separated class probabilities. */
    private void writeLabelAndProbabilities(String label, double[] classProbabilities, BufferedWriter output) throws IOException {
        output.write(label);
        output.write(",");
        for (int i = 0; i < classProbabilities.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(classProbabilities[i]));
        }
    }

    /**
     * Writes the given tree paths as comma-separated cells (no trailing comma, no line terminator).
     *
     * @param treePaths the cells to write, in order
     * @param output    the writer receiving the cells
     * @throws Exception if writing fails
     */
    private void writeTreePaths(String[] treePaths, BufferedWriter output) throws Exception {
        int last = treePaths.length - 1;
        int position = 0;
        // Every path but the last is followed by a comma.
        while (position < last) {
            output.write(treePaths[position]);
            output.write(",");
            ++position;
        }
        output.write(treePaths[last]);
    }

    /**
     * Loads the model named on the command line without knowing its kind: the name is first
     * tried as a MOJO and, if that fails with an {@link IOException}, as a POJO class name.
     *
     * @param modelName MOJO file name or POJO class name
     * @throws Exception if the POJO fallback fails as well; the original MOJO failure is attached
     *                   to that exception as a suppressed exception so it is not lost
     */
    private void loadModel(String modelName) throws Exception {
        try {
            this.loadMojo(modelName);
        }
        catch (IOException mojoFailure) {
            try {
                this.loadPojo(modelName);
            }
            catch (Exception pojoFailure) {
                // Keep the reason the MOJO attempt failed visible in the reported stack trace.
                pojoFailure.addSuppressed(mojoFailure);
                throw pojoFailure;
            }
        }
    }

    /**
     * Instantiates the POJO model class with the given name and wraps it in an
     * {@link EasyPredictModelWrapper}, which is stored in {@code model}.
     *
     * <p>Unknown categorical levels are always converted to NA; invalid-number conversion, leaf
     * assignment and GLRM reconstruction follow the corresponding command-line flags.
     *
     * @param className name of the generated model class, which must be loadable via {@link Class#forName(String)}
     * @throws Exception if the class cannot be found or instantiated
     */
    private void loadPojo(String className) throws Exception {
        Class<?> pojoClass = Class.forName(className);
        GenModel pojo = (GenModel)pojoClass.newInstance();
        EasyPredictModelWrapper.Config wrapperConfig = new EasyPredictModelWrapper.Config()
                .setModel(pojo)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            wrapperConfig.setEnableLeafAssignment(true);
        }
        if (this.returnGLRMReconstruct) {
            wrapperConfig.setEnableGLRMReconstrut(true);
        }
        this.model = new EasyPredictModelWrapper(wrapperConfig);
    }

    /**
     * Loads a MOJO via {@link MojoModel#load} and wraps it in an {@link EasyPredictModelWrapper},
     * which is stored in {@code model}.
     *
     * <p>Unknown categorical levels are always converted to NA; invalid-number conversion, leaf
     * assignment and GLRM reconstruction follow the corresponding command-line flags, and the GLRM
     * iteration number is only configured when it is greater than zero.
     *
     * @param modelName name passed to {@code MojoModel.load}
     * @throws IOException if the MOJO cannot be loaded
     */
    private void loadMojo(String modelName) throws IOException {
        MojoModel mojo = MojoModel.load(modelName);
        EasyPredictModelWrapper.Config wrapperConfig = new EasyPredictModelWrapper.Config()
                .setModel(mojo)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            wrapperConfig.setEnableLeafAssignment(true);
        }
        if (this.returnGLRMReconstruct) {
            wrapperConfig.setEnableGLRMReconstrut(true);
        }
        if (this.glrmIterNumber > 0) {
            wrapperConfig.setGLRMIterNumber(this.glrmIterNumber);
        }
        this.model = new EasyPredictModelWrapper(wrapperConfig);
    }

    /**
     * Prints the command-line help to standard output and terminates the JVM with status 1.
     * This method never returns.
     */
    private static void usage() {
        String[] helpLines = {
            "",
            "Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName",
            "             --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum",
            "",
            "     --mojo    Name of the zip file containing model's MOJO.",
            "     --pojo    Name of the java class containing the model's POJO. Either this ",
            "               parameter or --model must be specified.",
            "     --model   Name of a MOJO zip file or of a POJO java class; it is tried as a MOJO first and",
            "               as a POJO if the MOJO cannot be read.",
            "     --input   text file containing the test data set to score.",
            "     --output  Name of the output CSV file with computed predictions.",
            "     --separator Separator to be used in input file containing test data set.",
            "     --decimal Use decimal numbers in the output (default is to use hexadecimal).",
            "     --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models.",
            "     --leafNodeAssignment will show the leaf node assignment for GBM and DRF instead of the prediction results",
            "     --glrmReconstruct will return the reconstructed dataset for GLRM mojo instead of X factor derived from the dataset.",
            "     --glrmIterNumber integer indicating number of iterations to go through when constructing X factor derived from the dataset.",
            "",
        };
        for (String line : helpLines) {
            System.out.println(line);
        }
        System.exit(1);
    }

    /**
     * Compares the column names of the input file with the model's column names and reports the
     * differences on standard output.
     *
     * <p>A model column that is absent from the input is reported as missing unless it is the
     * response column. Input columns that match no model column are reported as unused. Nothing is
     * printed for a category that has no entries, and no exception is raised for either.
     *
     * @param parsedColumnNamesArr the column names read from the first line of the input file
     */
    private void checkMissingColumns(String[] parsedColumnNamesArr) {
        String[] modelColumnNames = this.model.m._names;
        // Starts as every input column; model columns are removed as they are matched, so what
        // remains at the end is the set of unused input columns.
        HashSet<String> unusedColumns = new HashSet<String>(parsedColumnNamesArr.length);
        for (String parsedName : parsedColumnNamesArr) {
            unusedColumns.add(parsedName);
        }
        ArrayList<String> missingColumns = new ArrayList<String>();
        for (String modelColumn : modelColumnNames) {
            if (unusedColumns.contains(modelColumn) || modelColumn.equals(this.model.m._responseColumn)) {
                unusedColumns.remove(modelColumn);
            } else {
                missingColumns.add(modelColumn);
            }
        }
        if (!missingColumns.isEmpty()) {
            StringBuilder message = new StringBuilder("There were ");
            message.append(missingColumns.size());
            message.append(" missing columns found in the input data set: {");
            String delimiter = "";
            for (String missing : missingColumns) {
                message.append(delimiter).append(missing);
                delimiter = ",";
            }
            message.append('}');
            System.out.println(message);
        }
        if (!unusedColumns.isEmpty()) {
            StringBuilder message = new StringBuilder("Detected ");
            message.append(unusedColumns.size());
            message.append(" unused columns in the input data set: {");
            String delimiter = "";
            for (String unused : unusedColumns) {
                message.append(delimiter).append(unused);
                delimiter = ",";
            }
            message.append('}');
            System.out.println(message);
        }
    }

    /**
     * Parses the command line into this object's settings and then loads the model.
     *
     * <p>Flags without a value ({@code --header}, {@code --decimal}, {@code --glrmReconstruct},
     * {@code --setConvertInvalidNum}, {@code --leafNodeAssignment}) are handled first; every other
     * option consumes the following argument as its value. An unknown option, a missing value, or
     * any exception (including a model loading failure) ends in {@code usage()}, which terminates
     * the JVM.
     *
     * @param args the command-line arguments
     */
    private void parseArgs(String[] args) {
        try {
            String modelName = "";
            // 0 = load as POJO (also the default), 1 = load as MOJO, 2 = try MOJO then POJO.
            int loadType = 0;
            int pos = 0;
            while (pos < args.length) {
                String option = args[pos++];
                switch (option) {
                    case "--header": {
                        // Accepted but has no effect.
                        continue;
                    }
                    case "--decimal": {
                        this.useDecimalOutput = true;
                        continue;
                    }
                    case "--glrmReconstruct": {
                        this.returnGLRMReconstruct = true;
                        continue;
                    }
                    case "--setConvertInvalidNum": {
                        this.setInvNumNA = true;
                        continue;
                    }
                    case "--leafNodeAssignment": {
                        this.getTreePath = true;
                        continue;
                    }
                    default: {
                        break;
                    }
                }
                // Everything else needs a value in the next argument.
                if (pos >= args.length) {
                    PredictCsv.usage();
                }
                String value = args[pos++];
                switch (option) {
                    case "--model": {
                        modelName = value;
                        loadType = 2;
                        break;
                    }
                    case "--mojo": {
                        modelName = value;
                        loadType = 1;
                        break;
                    }
                    case "--pojo": {
                        modelName = value;
                        loadType = 0;
                        break;
                    }
                    case "--input": {
                        this.inputCSVFileName = value;
                        break;
                    }
                    case "--output": {
                        this.outputCSVFileName = value;
                        break;
                    }
                    case "--separator": {
                        // Only the last character of the argument is used.
                        this.separator = value.charAt(value.length() - 1);
                        break;
                    }
                    case "--glrmIterNumber": {
                        this.glrmIterNumber = Integer.parseInt(value);
                        break;
                    }
                    default: {
                        System.out.println("ERROR: Unknown command line argument: " + option);
                        PredictCsv.usage();
                    }
                }
            }
            if (loadType == 0) {
                this.loadPojo(modelName);
            } else if (loadType == 1) {
                this.loadMojo(modelName);
            } else if (loadType == 2) {
                this.loadModel(modelName);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            PredictCsv.usage();
        }
    }
}

