package org.apache.spark.examples.ml;

import org.apache.cassandra.config.CFMetaData;
import org.apache.hadoop.hbase.util.Strings;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.spark-project.guava.collect.Lists;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.class */
/**
 * A small end-to-end text classification example built on the Spark ML pipeline API.
 *
 * <p>The program builds a three-stage pipeline ({@link Tokenizer} -> {@link HashingTF} ->
 * {@link LogisticRegression}), fits it on four labeled documents, applies the fitted model to
 * four unlabeled documents and prints one line per scored document.
 *
 * <p>{@code LabeledDocument} and {@code Document} are not declared in this file; they are
 * presumably simple beans in the same package exposing {@code id}, {@code text} and (for the
 * labeled variant) {@code label} -- confirm against their definitions.
 */
public class JavaSimpleTextClassificationPipeline {
    /**
     * Runs the example: trains on the in-memory labeled documents, scores the test documents
     * and prints {@code (id, text) --> prob=..., prediction=...} for each of them.
     *
     * @param strArr command-line arguments; not used
     */
    public static void main(String[] strArr) {
        SparkConf sparkConf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        try {
            SQLContext sqlContext = new SQLContext(javaSparkContext);

            // Training set. Labels are written as plain literals (1.0 / 0.0) so the example
            // does not depend on the value of constants owned by unrelated libraries.
            DataFrame training = sqlContext.createDataFrame(
                    javaSparkContext.parallelize(Lists.newArrayList(new LabeledDocument[]{
                            new LabeledDocument(0L, "a b c d e spark", 1.0d),
                            new LabeledDocument(1L, "b d", 0.0d),
                            new LabeledDocument(2L, "spark f g h", 1.0d),
                            new LabeledDocument(3L, "hadoop mapreduce", 0.0d)})),
                    LabeledDocument.class);

            // Stage 1: read the "text" column and write the "words" column.
            // Kept as a Tokenizer (not a bare PipelineStage) because getOutputCol() is needed below.
            Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");

            // Stage 2: hash the tokenizer's output column into a "features" column
            // (1000 features requested).
            HashingTF hashingTF = new HashingTF()
                    .setNumFeatures(1000)
                    .setInputCol(tokenizer.getOutputCol())
                    .setOutputCol("features");

            // Stage 3: logistic regression with at most 10 iterations and regParam 0.001.
            LogisticRegression logisticRegression =
                    new LogisticRegression().setMaxIter(10).setRegParam(0.001d);

            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{tokenizer, hashingTF, logisticRegression});

            // Unlabeled documents to score with the fitted pipeline.
            DataFrame test = sqlContext.createDataFrame(
                    javaSparkContext.parallelize(Lists.newArrayList(new Document[]{
                            new Document(4L, "spark i j k"),
                            new Document(5L, "l m n"),
                            new Document(6L, "spark hadoop spark"),
                            new Document(7L, "apache hadoop")})),
                    Document.class);

            // Fit on the training data, then score the test data with the resulting model.
            DataFrame predictions = pipeline.fit(training).transform(test);

            // Selected column order fixes the row indices used below:
            // 0 = id, 1 = text, 2 = probability, 3 = prediction.
            Row[] rows = predictions
                    .select("id", new String[]{"text", "probability", "prediction"})
                    .collect();
            for (Row row : rows) {
                System.out.println("(" + row.get(0) + ", " + row.get(1)
                        + ") --> prob=" + row.get(2) + ", prediction=" + row.get(3));
            }
        } finally {
            // Always stop the context, including when fitting or scoring throws.
            javaSparkContext.stop();
        }
    }
}
