/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.stacking;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.composition.stacking.StackedModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.AdaptableDatasetModel;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;

public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
extends DatasetTrainer<StackedModel<IS, IA, O, AM>, L> {
    private IgniteBinaryOperator<IA> aggregatingInputMerger;
    private IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter;
    private List<DatasetTrainer<IgniteModel<IS, IA>, L>> submodelsTrainers;
    private DatasetTrainer<AM, L> aggregatorTrainer;
    private IgniteFunction<Vector, IS> vector2SubmodelInputConverter;
    private IgniteFunction<IA, Vector> submodelOutput2VectorConverter;

    public StackedDatasetTrainer(DatasetTrainer<AM, L> aggregatorTrainer, IgniteBinaryOperator<IA> aggregatingInputMerger, IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter, List<DatasetTrainer<IgniteModel<IS, IA>, L>> submodelsTrainers, IgniteFunction<Vector, IS> vector2SubmodelInputConverter, IgniteFunction<IA, Vector> submodelOutput2VectorConverter) {
        this.aggregatorTrainer = aggregatorTrainer;
        this.aggregatingInputMerger = aggregatingInputMerger;
        this.submodelInput2AggregatingInputConverter = submodelInput2AggregatingInputConverter;
        this.submodelsTrainers = new ArrayList<DatasetTrainer<IgniteModel<IS, IA>, L>>(submodelsTrainers);
        this.vector2SubmodelInputConverter = vector2SubmodelInputConverter;
        this.submodelOutput2VectorConverter = submodelOutput2VectorConverter;
    }

    public StackedDatasetTrainer(DatasetTrainer<AM, L> aggregatorTrainer, IgniteBinaryOperator<IA> aggregatingInputMerger, IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter) {
        this(aggregatorTrainer, aggregatingInputMerger, submodelInput2AggregatingInputConverter, new ArrayList<DatasetTrainer<IgniteModel<IS, IA>, L>>(), null, null);
    }

    public StackedDatasetTrainer() {
        this(null, null, null, new ArrayList<DatasetTrainer<IgniteModel<IS, IA>, L>>(), null, null);
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesKept(IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter) {
        this.submodelInput2AggregatingInputConverter = submodelInput2AggregatingInputConverter;
        return this;
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesDropped() {
        this.submodelInput2AggregatingInputConverter = null;
        return this;
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withSubmodelOutput2VectorConverter(IgniteFunction<IA, Vector> submodelOutput2VectorConverter) {
        this.submodelOutput2VectorConverter = submodelOutput2VectorConverter;
        return this;
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withVector2SubmodelInputConverter(IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
        this.vector2SubmodelInputConverter = vector2SubmodelInputConverter;
        return this;
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorTrainer(DatasetTrainer<AM, L> aggregatorTrainer) {
        this.aggregatorTrainer = aggregatorTrainer;
        return this;
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorInputMerger(IgniteBinaryOperator<IA> merger) {
        this.aggregatingInputMerger = merger;
        return this;
    }

    public <M1 extends IgniteModel<IS, IA>> StackedDatasetTrainer<IS, IA, O, AM, L> addTrainer(DatasetTrainer<M1, L> trainer) {
        this.submodelsTrainers.add(CompositionUtils.unsafeCoerce(trainer));
        return this;
    }

    @Override
    public <K, V> StackedModel<IS, IA, O, AM> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return new StackedModel(this.getTrainer().fit(datasetBuilder, preprocessor));
    }

    @Override
    public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        this.learningEnvironment().initDeployingContext(preprocessor);
        return new StackedModel(this.getTrainer().update(mdl, datasetBuilder, preprocessor));
    }

    private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() {
        this.checkConsistency();
        ArrayList subs = new ArrayList();
        if (this.submodelInput2AggregatingInputConverter != null) {
            DatasetTrainer id = DatasetTrainer.identityTrainer();
            DatasetTrainer mappedId = CompositionUtils.unsafeCoerce(AdaptableDatasetTrainer.of(id).afterTrainedModel(this.submodelInput2AggregatingInputConverter));
            subs.add(mappedId);
        }
        subs.addAll(this.submodelsTrainers);
        TrainersParallelComposition composition = new TrainersParallelComposition(subs);
        IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> featureMapper = StackedDatasetTrainer.getFeatureExtractorForAggregator(this.submodelOutput2VectorConverter, this.vector2SubmodelInputConverter);
        return AdaptableDatasetTrainer.of(composition).afterTrainedModel(lst -> lst.stream().reduce(this.aggregatingInputMerger).get()).andThen(this.aggregatorTrainer, model -> new IgniteFunction<LabeledVector<L>, LabeledVector<L>>((AdaptableDatasetModel)model, featureMapper){
            final /* synthetic */ AdaptableDatasetModel val$model;
            final /* synthetic */ IgniteBiFunction val$featureMapper;
            {
                this.val$model = adaptableDatasetModel;
                this.val$featureMapper = igniteBiFunction;
            }

            @Override
            public LabeledVector<L> apply(LabeledVector<L> v) {
                List models = ((ModelsParallelComposition)this.val$model.innerModel()).submodels();
                return new LabeledVector((Vector)this.val$featureMapper.apply(models, v.features()), v.label());
            }
        }).unsafeSimplyTyped();
    }

    private void checkConsistency() {
        if (this.submodelInput2AggregatingInputConverter == null && this.submodelsTrainers.isEmpty()) {
            throw new IllegalStateException("There should be at least one way for submodels input to be propageted to aggregator.");
        }
        if (this.submodelOutput2VectorConverter == null || this.vector2SubmodelInputConverter == null) {
            throw new IllegalStateException("There should be a specified way to convert vectors to submodels input and submodels output to vector");
        }
        if (this.aggregatingInputMerger == null) {
            throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified");
        }
    }

    public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        this.submodelsTrainers = this.submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
        this.aggregatorTrainer = this.aggregatorTrainer.withEnvironmentBuilder(envBuilder);
        return this;
    }

    private static <IS, IA, K, V> IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> getFeatureExtractorForAggregator(IgniteFunction<IA, Vector> submodelOutput2VectorConverter, IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
        return (subMdls, v) -> {
            Vector[] vs = (Vector[])subMdls.stream().map(sm -> StackedDatasetTrainer.applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
            return VectorUtils.concat(vs);
        };
    }

    private static <IS, IA> Vector applyToVector(IgniteModel<IS, IA> mdl, IgniteFunction<IA, Vector> submodelOutput2VectorConverter, IgniteFunction<Vector, IS> vector2SubmodelInputConverter, Vector v) {
        return (Vector)vector2SubmodelInputConverter.andThen(mdl::predict).andThen(submodelOutput2VectorConverter).apply(v);
    }

    @Override
    protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        throw new IllegalStateException();
    }

    @Override
    public boolean isUpdateable(StackedModel<IS, IA, O, AM> mdl) {
        throw new IllegalStateException();
    }
}

