package org.apache.spark.ml.feature;

import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.feature.RFormulaBase;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: RFormula.scala */
@ScalaSignature(bytes = "\u0006\u000114A!\u0001\u0002\u0001\u001b\ti!KR8s[Vd\u0017-T8eK2T!a\u0001\u0003\u0002\u000f\u0019,\u0017\r^;sK*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001'\r\u0001a\u0002\u0006\t\u0004\u001fA\u0011R\"\u0001\u0003\n\u0005E!!!B'pI\u0016d\u0007CA\n\u0001\u001b\u0005\u0011\u0001CA\n\u0016\u0013\t1\"A\u0001\u0007S\r>\u0014X.\u001e7b\u0005\u0006\u001cX\r\u0003\u0005\u0019\u0001\t\u0015\r\u0011\"\u0011\u001a\u0003\r)\u0018\u000eZ\u000b\u00025A\u00111$\t\b\u00039}i\u0011!\b\u0006\u0002=\u0005)1oY1mC&\u0011\u0001%H\u0001\u0007!J,G-\u001a4\n\u0005\t\u001a#AB*ue&twM\u0003\u0002!;!AQ\u0005\u0001B\u0001B\u0003%!$\u0001\u0003vS\u0012\u0004\u0003\u0002C\u0014\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u0015\u0002\u001fI,7o\u001c7wK\u00124uN]7vY\u0006\u0004\"aE\u0015\n\u0005)\u0012!\u0001\u0005*fg>dg/\u001a3S\r>\u0014X.\u001e7b\u0011!a\u0003A!A!\u0002\u0013i\u0013!\u00049ja\u0016d\u0017N\\3N_\u0012,G\u000e\u0005\u0002\u0010]%\u0011q\u0006\u0002\u0002\u000e!&\u0004X\r\\5oK6{G-\u001a7\t\rE\u0002A\u0011\u0001\u00023\u0003\u0019a\u0014N\\5u}Q!!c\r\u001b6\u0011\u0015A\u0002\u00071\u0001\u001b\u0011\u00159\u0003\u00071\u0001)\u0011\u0015a\u0003\u00071\u0001.\u0011\u00159\u0004\u0001\"\u00119\u0003%!(/\u00198tM>\u0014X\u000e\u0006\u0002:\u007fA\u0011!(P\u0007\u0002w)\u0011AHB\u0001\u0004gFd\u0017B\u0001 
<\u0005%!\u0015\r^1Ge\u0006lW\rC\u0003Am\u0001\u0007\u0011(A\u0004eCR\f7/\u001a;\t\u000b\t\u0003A\u0011I\"\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$\"\u0001\u0012&\u0011\u0005\u0015CU\"\u0001$\u000b\u0005\u001d[\u0014!\u0002;za\u0016\u001c\u0018BA%G\u0005)\u0019FO];diRK\b/\u001a\u0005\u0006\u0017\u0006\u0003\r\u0001R\u0001\u0007g\u000eDW-\\1\t\u000b5\u0003A\u0011\t(\u0002\t\r|\u0007/\u001f\u000b\u0003%=CQ\u0001\u0015'A\u0002E\u000bQ!\u001a=ue\u0006\u0004\"AU+\u000e\u0003MS!\u0001\u0016\u0003\u0002\u000bA\f'/Y7\n\u0005Y\u001b&\u0001\u0003)be\u0006lW*\u00199\t\u000ba\u0003A\u0011I-\u0002\u0011Q|7\u000b\u001e:j]\u001e$\u0012A\u0007\u0005\u00067\u0002!I\u0001X\u0001\u000fiJ\fgn\u001d4pe6d\u0015MY3m)\tIT\fC\u0003A5\u0002\u0007\u0011\bC\u0003`\u0001\u0011%\u0001-A\tdQ\u0016\u001c7nQ1o)J\fgn\u001d4pe6$\"!\u00193\u0011\u0005q\u0011\u0017BA2\u001e\u0005\u0011)f.\u001b;\t\u000b-s\u0006\u0019\u0001#)\u0005\u00011\u0007CA4k\u001b\u0005A'BA5\u0007\u0003)\tgN\\8uCRLwN\\\u0005\u0003W\"\u0014A\"\u0012=qKJLW.\u001a8uC2\u0004")
@Experimental
/*
 * Model produced by fitting an RFormula (compiled from RFormula.scala).
 *
 * transform() runs the fitted feature PipelineModel over the input and, when the formula's
 * label column is present in the input and hasLabelCol() reports no label column yet, appends
 * a DoubleType column named by labelCol() holding the label cast to double.
 *
 * This is the Java view of a Scala class: the Cclass forwarders, the $_setter_$ methods and the
 * publicly visible mangled field follow the Scala compiler's trait/closure encoding and must
 * keep their names.
 */
public class RFormulaModel extends Model<RFormulaModel> implements RFormulaBase {
    private final String uid;

    /**
     * The resolved formula this model applies. Public with a mangled name; presumably read by
     * the compiler-generated RFormulaModel$$anonfun$* classes constructed below.
     */
    public final ResolvedRFormula org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula;

    /** Fitted pipeline that produces the features column. */
    private final PipelineModel pipelineModel;

    // Not final: these are assigned through the $_setter_$ methods below (invoked via the trait
    // $init$ calls in the constructor); Java forbids assigning a final field outside a constructor.
    private Param<String> labelCol;
    private Param<String> featuresCol;

    /** Delegates to the RFormulaBase trait implementation. */
    @Override // org.apache.spark.ml.feature.RFormulaBase
    public boolean hasLabelCol(StructType structType) {
        return RFormulaBase.Cclass.hasLabelCol(this, structType);
    }

    /** Returns the label-column param. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait setter for the label-column param; the inherited signature uses a raw Param. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    @SuppressWarnings("unchecked")
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Delegates to the HasLabelCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** Returns the features-column param. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    /** Trait setter for the features-column param; the inherited signature uses a raw Param. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    @SuppressWarnings("unchecked")
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** Delegates to the HasFeaturesCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    /** Returns this model's unique identifier. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /**
     * Validates the input schema, applies the fitted pipeline, then adds the label column if
     * applicable.
     *
     * @throws IllegalArgumentException if the schema fails checkCanTransform, or if the label
     *     column has to be added but its type is neither numeric nor boolean
     */
    @Override // org.apache.spark.ml.Transformer
    public DataFrame transform(DataFrame dataFrame) {
        checkCanTransform(dataFrame.schema());
        return transformLabel(this.pipelineModel.transform(dataFrame));
    }

    /**
     * Computes the output schema: the pipeline's output schema, plus a DoubleType label field
     * when hasLabelCol() reports none and the input contains the formula's label column.
     *
     * <p>The added field is non-nullable when the source label is numeric or boolean and
     * nullable otherwise.
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        checkCanTransform(structType);
        StructType withFeatures = this.pipelineModel.transformSchema(structType);
        // Presumably checks whether the input has a field named after the formula's label.
        boolean inputHasLabel = structType.exists(new RFormulaModel$$anonfun$transformSchema$1(this));
        if (hasLabelCol(withFeatures) || !inputHasLabel) {
            return withFeatures;
        }
        DataType labelType =
            structType.apply(this.org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula.label()).dataType();
        boolean nullable = !isNumericOrBoolean(labelType);
        StructField labelField = new StructField(
            (String) $(labelCol()), DoubleType$.MODULE$, nullable, StructField$.MODULE$.apply$default$4());
        return new StructType((StructField[]) Predef$.MODULE$.refArrayOps(withFeatures.fields())
            .$colon$plus(labelField, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
    }

    /**
     * Returns a copy with the same uid, resolved formula and fitted pipeline.
     *
     * <p>{@code paramMap} is forwarded to copyValues as the extra params, per the
     * Params.copy(extra) contract. Before this change it was ignored.
     */
    @Override // org.apache.spark.ml.Model, org.apache.spark.ml.Transformer, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public RFormulaModel copy(ParamMap paramMap) {
        RFormulaModel copied = new RFormulaModel(
            uid(), this.org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula, this.pipelineModel);
        return (RFormulaModel) copyValues(copied, paramMap);
    }

    /** Renders as {@code RFormulaModel(<formula>) (uid=<uid>)}. */
    @Override // org.apache.spark.ml.PipelineStage, org.apache.spark.ml.util.Identifiable
    public String toString() {
        return "RFormulaModel(" + this.org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula
            + ") (uid=" + uid() + ")";
    }

    /**
     * Adds the label column (formula label cast to double, named by labelCol()) when hasLabelCol()
     * reports none and the formula's label column is present. Otherwise returns the input unchanged.
     *
     * @throws IllegalArgumentException if the label column's type is neither numeric nor boolean
     */
    private DataFrame transformLabel(DataFrame dataFrame) {
        String label = this.org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula.label();
        StructType schema = dataFrame.schema();
        if (hasLabelCol(schema) || !schema.exists(new RFormulaModel$$anonfun$transformLabel$1(this, label))) {
            return dataFrame;
        }
        DataType labelType = schema.apply(label).dataType();
        if (!isNumericOrBoolean(labelType)) {
            throw new IllegalArgumentException("Unsupported type for label: " + labelType);
        }
        return dataFrame.withColumn((String) $(labelCol()), dataFrame.apply(label).cast(DoubleType$.MODULE$));
    }

    /** True when {@code dataType} is any NumericType or BooleanType, i.e. castable to a double label. */
    private static boolean isNumericOrBoolean(DataType dataType) {
        return dataType instanceof NumericType || BooleanType$.MODULE$.equals(dataType);
    }

    /**
     * Requires that the features column does not already exist in {@code structType}, and that an
     * existing label column, if any, is DoubleType.
     *
     * @throws IllegalArgumentException (via Predef.require) when either requirement fails
     */
    private void checkCanTransform(StructType structType) {
        // Presumably maps each field to its name.
        Seq columnNames = (Seq) structType.map(new RFormulaModel$$anonfun$3(this), Seq$.MODULE$.canBuildFrom());
        Predef$.MODULE$.require(!columnNames.contains($(featuresCol())), new RFormulaModel$$anonfun$checkCanTransform$1(this));
        boolean labelOk = true;
        if (columnNames.contains($(labelCol()))) {
            DataType labelType = structType.apply((String) $(labelCol())).dataType();
            labelOk = labelType != null && labelType.equals(DoubleType$.MODULE$);
        }
        Predef$.MODULE$.require(labelOk, new RFormulaModel$$anonfun$checkCanTransform$2(this));
    }

    /**
     * Creates the model.
     *
     * <p>The fields are assigned before the trait initializers run, and the trait $init$ order is
     * kept as emitted by the Scala compiler.
     */
    public RFormulaModel(String str, ResolvedRFormula resolvedRFormula, PipelineModel pipelineModel) {
        this.uid = str;
        this.org$apache$spark$ml$feature$RFormulaModel$$resolvedFormula = resolvedRFormula;
        this.pipelineModel = pipelineModel;
        HasFeaturesCol.Cclass.$init$(this);
        HasLabelCol.Cclass.$init$(this);
        RFormulaBase.Cclass.$init$(this);
    }
}
