package org.apache.spark.examples.ml;

import org.apache.cassandra.config.CFMetaData;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

/* loaded from: input_file:org/apache/spark/examples/ml/JavaTrainValidationSplitExample.class */
/**
 * Example that tunes a {@link LinearRegression} with {@link TrainValidationSplit}.
 *
 * <p>The program loads the LIBSVM file {@code data/mllib/sample_libsvm_data.txt},
 * splits it 90/10 (seed {@code 12345}) into a training and a test set, fits a
 * {@code TrainValidationSplit} over a grid of {@code regParam},
 * {@code fitIntercept} and {@code elasticNetParam} settings on the training set,
 * and finally prints the {@code features}, {@code label} and {@code prediction}
 * columns obtained on the test set.
 */
public class JavaTrainValidationSplitExample {
    /**
     * Runs the example end to end.
     *
     * @param args command-line arguments; not read by this example
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            SQLContext sqlContext = new SQLContext(jsc);
            DataFrame data = sqlContext.createDataFrame(
                MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
                LabeledPoint.class);

            // Fixed seed keeps the 90% / 10% split reproducible between runs.
            DataFrame[] splits = data.randomSplit(new double[]{0.9d, 0.1d}, 12345L);
            DataFrame training = splits[0];
            DataFrame test = splits[1];

            LinearRegression lr = new LinearRegression();

            // The elastic-net values are spelled out as plain literals: the grid must
            // not depend on a constant borrowed from an unrelated library, whose value
            // could differ between versions of that library.
            ParamGridBuilder gridBuilder = new ParamGridBuilder()
                .addGrid(lr.regParam(), new double[]{0.1d, 0.01d})
                .addGrid(lr.fitIntercept())
                .addGrid(lr.elasticNetParam(), new double[]{0.0d, 0.5d, 1.0d});

            TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
                .setEstimator(lr)
                .setEvaluator(new RegressionEvaluator())
                .setEstimatorParamMaps(gridBuilder.build());
            trainValidationSplit.setTrainRatio(0.8d);

            trainValidationSplit.fit(training)
                .transform(test)
                .select("features", "label", "prediction")
                .show();
        } finally {
            // Stop the context even when loading, fitting or transforming fails.
            jsc.stop();
        }
    }
}
