package org.apache.spark.mllib.classification;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.LogisticGradient;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;

/* compiled from: LogisticRegression.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u001da\u0001B\u0001\u0003\u00015\u00111\u0004T8hSN$\u0018n\u0019*fOJ,7o]5p]^KG\u000f\u001b'C\r\u001e\u001b&BA\u0002\u0005\u00039\u0019G.Y:tS\u001aL7-\u0019;j_:T!!\u0002\u0004\u0002\u000b5dG.\u001b2\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001aE\u0002\u0001\u001da\u00012a\u0004\n\u0015\u001b\u0005\u0001\"BA\t\u0005\u0003)\u0011Xm\u001a:fgNLwN\\\u0005\u0003'A\u0011!dR3oKJ\fG.\u001b>fI2Kg.Z1s\u00032<wN]5uQ6\u0004\"!\u0006\f\u000e\u0003\tI!a\u0006\u0002\u0003/1{w-[:uS\u000e\u0014Vm\u001a:fgNLwN\\'pI\u0016d\u0007CA\r\u001d\u001b\u0005Q\"\"A\u000e\u0002\u000bM\u001c\u0017\r\\1\n\u0005uQ\"\u0001D*fe&\fG.\u001b>bE2,\u0007\"B\u0010\u0001\t\u0003\u0001\u0013A\u0002\u001fj]&$h\bF\u0001\"!\t)\u0002\u0001C\u0004$\u0001\t\u0007I\u0011\t\u0013\u0002\u0013=\u0004H/[7ju\u0016\u0014X#A\u0013\u0011\u0005\u0019JS\"A\u0014\u000b\u0005!\"\u0011\u0001D8qi&l\u0017N_1uS>t\u0017B\u0001\u0016(\u0005\u0015a%IR$TQ\r\u0011CF\r\t\u0003[Aj\u0011A\f\u0006\u0003_\u0019\t!\"\u00198o_R\fG/[8o\u0013\t\tdFA\u0003TS:\u001cW-I\u00014\u0003\u0015\td&\r\u00181\u0011\u0019)\u0004\u0001)A\u0005K\u0005Qq\u000e\u001d;j[&TXM\u001d\u0011)\u0007Qb#\u0007C\u00049\u0001\t\u0007I\u0011K\u001d\u0002\u0015Y\fG.\u001b3bi>\u00148/F\u0001;!\rY\u0004IQ\u0007\u0002y)\u0011QHP\u0001\nS6lW\u000f^1cY\u0016T!a\u0010\u000e\u0002\u0015\r|G\u000e\\3di&|g.\u0003\u0002By\t!A*[:u!\u0011I2)\u0012(\n\u0005\u0011S\"!\u0003$v]\u000e$\u0018n\u001c82!\r1\u0015jS\u0007\u0002\u000f*\u0011\u0001JB\u0001\u0004e\u0012$\u0017B\u0001&H\u0005\r\u0011F\t\u0012\t\u0003\u001f1K!!\u0014\t\u0003\u00191\u000b'-\u001a7fIB{\u0017N\u001c;\u0011\u0005ey\u0015B\u0001)\u001b\u0005\u001d\u0011un\u001c7fC:DaA\u0015\u0001!\u0002\u0013Q\u0014a\u0003<bY&$\u0017\r^8sg\u0002BQ\u0001\u0016\u0001\u0005\nU\u000b1#\\;mi&d\u0015MY3m-\u0006d\u0017\u000eZ1u_J,\u0012A\u0011\u0005\u0006/\u0002!\t\u0001W\u0001\u000eg\u0016$h*^7DY\u0006\u001c8/Z:\u0015\u0005eSV\"\u0001\u000
1\t\u000bm3\u0006\u0019\u0001/\u0002\u00159,Xn\u00117bgN,7\u000f\u0005\u0002\u001a;&\u0011aL\u0007\u0002\u0004\u0013:$\bf\u0001,-A\u0006\n\u0011-A\u00032]Mr\u0003\u0007C\u0003d\u0001\u0011EC-A\u0006de\u0016\fG/Z'pI\u0016dGc\u0001\u000bf[\")aM\u0019a\u0001O\u00069q/Z5hQR\u001c\bC\u00015l\u001b\u0005I'B\u00016\u0005\u0003\u0019a\u0017N\\1mO&\u0011A.\u001b\u0002\u0007-\u0016\u001cGo\u001c:\t\u000b9\u0014\u0007\u0019A8\u0002\u0013%tG/\u001a:dKB$\bCA\rq\u0013\t\t(D\u0001\u0004E_V\u0014G.\u001a\u0005\u0006g\u0002!\t\u0005^\u0001\u0004eVtGC\u0001\u000bv\u0011\u00151(\u000f1\u0001F\u0003\u0015Ig\u000e];u\u0011\u0015\u0019\b\u0001\"\u0011y)\r!\u0012P\u001f\u0005\u0006m^\u0004\r!\u0012\u0005\u0006w^\u0004\raZ\u0001\u000fS:LG/[1m/\u0016Lw\r\u001b;t\u0011\u0015\u0019\b\u0001\"\u0003~)\u0015!bp`A\u0001\u0011\u00151H\u00101\u0001F\u0011\u0015YH\u00101\u0001h\u0011\u0019\t\u0019\u0001 a\u0001\u001d\u0006\u0019Ro]3s'V\u0004\b\u000f\\5fI^+\u0017n\u001a5ug\"\u001a\u0001\u0001\f\u001a")
/* loaded from: input_file:org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.class */
/*
 * Decompiled (JADX) view of the Scala class
 * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS.
 *
 * Trains an mllib LogisticRegressionModel using an LBFGS optimizer. The constructor:
 *   - enables feature scaling,
 *   - configures LBFGS with a LogisticGradient and a SquaredL2Updater,
 *   - installs a single input validator (multiLabelValidator).
 *
 * Training path:
 *   - Binary problem (numOfLinearPredictor() == 1) with a SquaredL2Updater or L1Updater:
 *     training is delegated to spark.ml's LogisticRegression with elasticNetParam 0.0
 *     (L2) or 1.0 (L1), and the result is converted back into an mllib model.
 *   - Every other case (multiclass, or any other updater): falls back to
 *     GeneralizedLinearAlgorithm.run.
 *
 * NOTE(review): this is decompiler output, not hand-written source. Several names
 * (mo706validators, mo1007asML, runWithMlLogisticRegression$1, the $$anonfun classes)
 * are compiler/decompiler artifacts that other class files refer to. Do not rename them.
 */
public class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm<LogisticRegressionModel> {
    // LBFGS optimizer created in the constructor (LogisticGradient + SquaredL2Updater).
    // setNumClasses may later replace its gradient.
    private final LBFGS optimizer;
    // Validators applied to the training input. The constructor sets this to a
    // one-element list containing multiLabelValidator().
    private final List<Function1<RDD<LabeledPoint>, Object>> validators;

    /**
     * Returns the LBFGS optimizer used for training.
     *
     * @return the optimizer created in the constructor
     */
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public LBFGS optimizer() {
        return this.optimizer;
    }

    /**
     * Accessor for the Scala {@code validators} member, renamed by the decompiler.
     *
     * @return the list of input validators built in the constructor
     */
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    /* renamed from: validators, reason: merged with bridge method [inline-methods] */
    public List<Function1<RDD<LabeledPoint>, Object>> mo706validators() {
        return this.validators;
    }

    /**
     * Creates the label validator for training input.
     *
     * The returned function is a compiled Scala closure. Its body lives in
     * LogisticRegressionWithLBFGS$$anonfun$multiLabelValidator$1, which is not shown here.
     * NOTE(review): presumably it checks labels against the configured number of
     * classes; verify against that class.
     *
     * @return a function from the training RDD to a boxed Boolean
     */
    private Function1<RDD<LabeledPoint>, Object> multiLabelValidator() {
        return new LogisticRegressionWithLBFGS$$anonfun$multiLabelValidator$1(this);
    }

    /**
     * Sets the number of classes for the classification problem.
     *
     * Sets numOfLinearPredictor to {@code i - 1}. When {@code i > 2}, the optimizer's
     * gradient is also replaced with a {@code LogisticGradient(i)}; for {@code i == 2}
     * the gradient is left unchanged.
     *
     * @param i number of classes; must be greater than 1 (checked with Scala's
     *          {@code Predef.require}, which throws IllegalArgumentException otherwise)
     * @return this instance, for call chaining
     */
    public LogisticRegressionWithLBFGS setNumClasses(int i) {
        Predef$.MODULE$.require(i > 1);
        numOfLinearPredictor_$eq(i - 1);
        if (i > 2) {
            optimizer().setGradient(new LogisticGradient(i));
        } else {
            // Decompiler artifact: the Scala `if` without `else` evaluates to Unit.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        return this;
    }

    /**
     * Builds the mllib model from trained weights and intercept.
     *
     * @param vector trained weights
     * @param d      trained intercept
     * @return a binary model when numOfLinearPredictor() == 1; otherwise a multiclass
     *         model with numFeatures() features and numOfLinearPredictor() + 1 classes
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public LogisticRegressionModel createModel(Vector vector, double d) {
        return numOfLinearPredictor() == 1 ? new LogisticRegressionModel(vector, d) : new LogisticRegressionModel(vector, d, numFeatures(), numOfLinearPredictor() + 1);
    }

    /**
     * Trains on {@code rdd}, starting from the weights produced by generateInitialWeights.
     *
     * The starting weights are treated as not user-supplied, so no initial model is
     * passed to the spark.ml path.
     *
     * @param rdd training data
     * @return the trained model
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public LogisticRegressionModel run(RDD<LabeledPoint> rdd) {
        return run(rdd, generateInitialWeights(rdd), false);
    }

    /**
     * Trains on {@code rdd}, starting from caller-supplied initial weights.
     *
     * @param rdd    training data
     * @param vector initial weights supplied by the caller
     * @return the trained model
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public LogisticRegressionModel run(RDD<LabeledPoint> rdd, Vector vector) {
        return run(rdd, vector, true);
    }

    /**
     * Shared training entry point; chooses between spark.ml and the generic GLM path.
     *
     * @param rdd    training data
     * @param vector initial weights
     * @param z      whether {@code vector} was supplied by the caller (Scala
     *               {@code userSuppliedWeights}); controls whether an initial model is
     *               passed to spark.ml
     * @return the trained model
     */
    private LogisticRegressionModel run(RDD<LabeledPoint> rdd, Vector vector, boolean z) {
        // Multiclass (more than one linear predictor) always uses the generic GLM path.
        if (numOfLinearPredictor() != 1) {
            return (LogisticRegressionModel) super.run(rdd, vector);
        }
        Updater updater = optimizer().getUpdater();
        // Binary case, dispatched on the updater type:
        //   - L2 updater -> spark.ml with elasticNetParam 0.0
        //   - L1 updater -> spark.ml with elasticNetParam 1.0
        //   - anything else -> generic GLM path
        return updater instanceof SquaredL2Updater ? runWithMlLogisticRegression$1(0.0d, rdd, vector, z) : updater instanceof L1Updater ? runWithMlLogisticRegression$1(1.0d, rdd, vector, z) : (LogisticRegressionModel) super.run(rdd, vector);
    }

    // Synthetic raw-type bridge methods emitted by scalac/javac. In bytecode they
    // presumably differ from the typed overloads by their erased return type. As
    // rendered here they would clash in javac (same erasure). Decompiler artifact.
    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public /* bridge */ /* synthetic */ LogisticRegressionModel run(RDD rdd, Vector vector) {
        return run((RDD<LabeledPoint>) rdd, vector);
    }

    @Override // org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm
    public /* bridge */ /* synthetic */ LogisticRegressionModel run(RDD rdd) {
        return run((RDD<LabeledPoint>) rdd);
    }

    /**
     * Lambda-lifted Scala local function (runWithMlLogisticRegression) that trains a
     * binary model with spark.ml's LogisticRegression.
     *
     * It copies this instance's settings onto the spark.ml estimator: regParam,
     * standardization (from useFeatureScaling()), fitIntercept (from addIntercept()),
     * maxIter and tol. It then converts the RDD to a DataFrame, trains, and converts the
     * result back via createModel.
     *
     * @param d      elasticNetParam to use (0.0 for L2, 1.0 for L1, per the caller)
     * @param rdd    training data
     * @param vector initial weights
     * @param z      whether {@code vector} was supplied by the caller
     * @return the trained mllib model
     */
    private final LogisticRegressionModel runWithMlLogisticRegression$1(double d, RDD rdd, Vector vector, boolean z) {
        LogisticRegression logisticRegression = new LogisticRegression();
        logisticRegression.setRegParam(optimizer().getRegParam());
        logisticRegression.setElasticNetParam(d);
        logisticRegression.setStandardization(useFeatureScaling());
        if (z) {
            // Caller supplied weights: seed spark.ml with an initial model whose
            // coefficients form a 1 x vector.size() DenseMatrix built from vector.
            // NOTE(review): the initial intercept is hard-coded to dense(1.0) rather
            // than derived from the input.
            // NOTE(review): the trailing `2, false` appear to be numClasses and a
            // multinomial flag; confirm against the ml model constructor.
            logisticRegression.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(Identifiable$.MODULE$.randomUID("logreg-static"), new DenseMatrix(1, vector.size(), vector.toArray()), Vectors$.MODULE$.dense(1.0d, (Seq<Object>) Predef$.MODULE$.wrapDoubleArray(new double[0])).mo1007asML(), 2, false));
        } else {
            // Decompiler artifact: Unit result of a Scala `if` without `else`.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        logisticRegression.setFitIntercept(addIntercept());
        logisticRegression.setMaxIter(optimizer().getNumIterations());
        logisticRegression.setTol(optimizer().getConvergenceTol());
        // Build a DataFrame, using a SparkSession bound to the RDD's SparkContext:
        //   - RDD elements are mapped to ml.feature.LabeledPoint by the compiled closure
        //     $$anonfun$3 (body not shown; presumably the mllib -> ml conversion).
        //   - The anonymous TypeCreator supplies the Scala TypeTag for that element type.
        Dataset<?> createDataFrame = SparkSession$.MODULE$.builder().sparkContext(rdd.context()).getOrCreate().createDataFrame(rdd.map(new LogisticRegressionWithLBFGS$$anonfun$3(this), ClassTag$.MODULE$.apply(org.apache.spark.ml.feature.LabeledPoint.class)), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(LogisticRegressionWithLBFGS.class.getClassLoader()), new TypeCreator(this) { // from class: org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.feature.LabeledPoint").asType().toTypeConstructor();
            }
        }));
        StorageLevel storageLevel = rdd.getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        // Second train() argument: true when the input RDD's storage level equals NONE.
        // The ternary is the decompiled form of Scala's null-safe `==`.
        // NOTE(review): presumably this lets train() persist the data itself when the
        // caller has not; confirm against ml's train signature.
        org.apache.spark.ml.classification.LogisticRegressionModel train = logisticRegression.train(createDataFrame, storageLevel != null ? storageLevel.equals(NONE) : NONE == null);
        // Convert the spark.ml coefficients and intercept back into an mllib model.
        return createModel(Vectors$.MODULE$.dense(train.coefficients().toArray()), train.intercept());
    }

    /**
     * Creates a trainer with these defaults:
     *   - feature scaling enabled,
     *   - an LBFGS optimizer using LogisticGradient and SquaredL2Updater,
     *   - a validator list containing only multiLabelValidator().
     *
     * NOTE(review): setFeatureScaling is called before the final fields are assigned;
     * this order mirrors the original Scala class body.
     */
    public LogisticRegressionWithLBFGS() {
        setFeatureScaling(true);
        this.optimizer = new LBFGS(new LogisticGradient(), new SquaredL2Updater());
        this.validators = List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Function1[]{multiLabelValidator()}));
    }
}
