package org.apache.spark.examples.ml;

import org.apache.hadoop.hive.serde2.avro.AvroSerDe;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.SeqLike;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;

/* compiled from: DecisionTreeClassificationExample.scala */
/* loaded from: input_file:org/apache/spark/examples/ml/DecisionTreeClassificationExample$.class */
/**
 * Java form of the Scala singleton {@code DecisionTreeClassificationExample}.
 *
 * <p>Loads {@code data/mllib/sample_libsvm_data.txt} in LIBSVM format. It then fits a
 * {@link Pipeline} of four stages on a 70% training split:
 * <ol>
 *   <li>a label {@link StringIndexer};</li>
 *   <li>a feature {@link VectorIndexer};</li>
 *   <li>a {@link DecisionTreeClassifier};</li>
 *   <li>an {@link IndexToString} label converter.</li>
 * </ol>
 * Finally it prints sample predictions on the remaining split, the test error, and the learned
 * tree.
 */
public final class DecisionTreeClassificationExample$ {
    /** Singleton instance, following Scala's {@code object} encoding. */
    public static final DecisionTreeClassificationExample$ MODULE$ = new DecisionTreeClassificationExample$();

    /**
     * Metric name handed to the evaluator. The example reports {@code 1.0 - metric} as the test
     * error. (The decompiled code pulled this string from an unrelated Hive constant.)
     */
    private static final String EVALUATION_METRIC = "precision";

    /**
     * Runs the example end to end. The {@link SparkContext} is always stopped before returning,
     * even if a step fails.
     *
     * @param args command-line arguments (not used)
     * @throws MatchError if {@code randomSplit} does not produce exactly two DataFrames, matching
     *     the original {@code Array(...)} extractor pattern
     */
    public void main(String[] args) {
        SparkContext sc = new SparkContext(new SparkConf().setAppName("DecisionTreeClassificationExample"));
        try {
            SQLContext sqlContext = new SQLContext(sc);
            DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

            // Fit the label indexer on the full dataset (before splitting). Its labels() are reused
            // below to map predictions back to the original label values.
            StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data);

            // Index the feature vectors; maxCategories = 4 (see VectorIndexer docs for semantics).
            PipelineStage featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data);

            // 70/30 split. Anything other than exactly two parts is a MatchError, as in the
            // original extractor match.
            DataFrame[] splits = data.randomSplit(new double[] {0.7d, 0.3d});
            if (splits == null || splits.length != 2) {
                throw new MatchError(splits);
            }
            DataFrame trainingData = splits[0];
            DataFrame testData = splits[1];

            DecisionTreeClassifier treeClassifier = (DecisionTreeClassifier) new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures");

            IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());

            PipelineModel model = new Pipeline()
                .setStages(new PipelineStage[] {labelIndexer, featureIndexer, treeClassifier, labelConverter})
                .fit(trainingData);

            DataFrame predictions = model.transform(testData);
            predictions.select("predictedLabel", "label", "features").show(5);

            double score = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName(EVALUATION_METRIC)
                .evaluate(predictions);
            System.out.println("Test Error = " + (1.0d - score));

            // Index 2 is the fitted counterpart of treeClassifier in the stage array above. The
            // cast is required because stages() is typed as the generic transformer array.
            DecisionTreeClassificationModel treeModel = (DecisionTreeClassificationModel) model.stages()[2];
            System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
        } finally {
            sc.stop();
        }
    }

    /** Only {@link #MODULE$} may be constructed. */
    private DecisionTreeClassificationExample$() {
    }
}
