package org.apache.spark.examples.ml;

import org.apache.hadoop.hive.serde2.avro.AvroSerDe;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaRandomForestClassifierExample.class */
/**
 * Example: train and evaluate a random forest classifier with a Spark ML {@link Pipeline}.
 *
 * <p>Reads {@code data/mllib/sample_libsvm_data.txt} in LIBSVM format and indexes the label and
 * feature columns. It then trains a pipeline on a random 70% split and reports the test error
 * and the learned forest on the remaining 30%.
 */
public class JavaRandomForestClassifierExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaRandomForestClassifierExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Always release the Spark context, even if loading/training/evaluation throws.
        try {
            SQLContext sqlContext = new SQLContext(jsc);

            // Load the data stored in LIBSVM format as a DataFrame.
            DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

            // Both indexers are fit on the full dataset (before the split), so their index
            // mappings cover every value present in the data.
            StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data);
            // maxCategories = 4: per the VectorIndexer docs, features with more than 4 distinct
            // values are treated as continuous; the rest are indexed as categorical.
            VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data);

            // Hold out 30% of the data for testing (unseeded, so the split varies per run).
            DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
            DataFrame trainingData = splits[0];
            DataFrame testData = splits[1];

            RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures");

            // Map numeric predictions back to the original label strings.
            IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

            Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});

            PipelineModel model = pipeline.fit(trainingData);
            DataFrame predictions = model.transform(testData);

            predictions.select("predictedLabel", "label", "features").show(5);

            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("precision");
            double precision = evaluator.evaluate(predictions);
            System.out.println("Test Error = " + (1.0 - precision));

            // Stage index 2 is the fitted random forest (see the setStages order above).
            RandomForestClassificationModel rfModel = (RandomForestClassificationModel) model.stages()[2];
            System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
        } finally {
            jsc.stop();
        }
    }
}
