/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.regressions.linear;

import java.util.Arrays;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap;
import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer of {@link LinearRegressionModel} that fits the weights and the intercept with the LSQR iterative
 * least-squares solver ({@link LSQROnHeap}).
 * <p>
 * Every feature vector is extended with a constant {@code 1.0} column, so the last component of the LSQR solution is
 * used as the model intercept. Training can be warm-started from a previously trained model.
 */
public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<LinearRegressionModel> {
    /**
     * Trains a new model from scratch (no initial guess).
     *
     * @param datasetBuilder Dataset builder.
     * @param extractor Upstream preprocessor producing labeled feature vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained model.
     */
    @Override
    public <K, V> LinearRegressionModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> extractor) {
        return updateModel((LinearRegressionModel)null, datasetBuilder, extractor);
    }

    /**
     * Appends a constant {@code 1.0} feature (the bias column) to the features of {@code lb} and wraps its scalar
     * label into a one-element {@code double[]}, the label form this trainer feeds to
     * {@link SimpleLabeledDatasetDataBuilder}.
     *
     * @param lb Labeled vector with a scalar label.
     * @return New labeled vector with {@code features().size() + 1} features and an array label.
     */
    private static LabeledVector<double[]> extendLabeledVector(LabeledVector<Double> lb) {
        Vector src = lb.features();
        int n = src.size();

        double[] extended = new double[n + 1];
        System.arraycopy(src.asArray(), 0, extended, 0, n);
        // Bias column: the coefficient LSQR finds for it becomes the model intercept (see updateModel).
        extended[n] = 1.0;

        return VectorUtils.of(extended).labeled(new double[] {lb.label()});
    }

    /**
     * Builds the LSQR starting point from a previously trained model: its weights followed by its intercept.
     *
     * @param mdl Model to warm-start from, or {@code null} for a cold start.
     * @return Initial guess of length {@code weights.size() + 1}, or {@code null} when {@code mdl} is {@code null}.
     */
    private static double[] initialGuess(LinearRegressionModel mdl) {
        if (mdl == null) {
            return null;
        }

        Vector oldWeights = mdl.getWeights();
        // One extra slot at the end for the intercept (coefficient of the appended bias column).
        Vector x0 = oldWeights.like(oldWeights.size() + 1);
        // Only non-zero weights are copied; other entries keep the value like() initialised them with
        // (NOTE(review): assumed zero).
        oldWeights.nonZeroes().forEach(ith -> x0.set(ith.index(), ith.get()));
        x0.set(x0.size() - 1, mdl.getIntercept());

        return x0.asArray();
    }

    /**
     * Trains a model with LSQR, optionally warm-starting from {@code mdl}.
     *
     * @param mdl Previously trained model used as the initial guess, or {@code null} to start from scratch.
     * @param datasetBuilder Dataset builder.
     * @param extractor Upstream preprocessor producing labeled feature vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained model; when the solver returns no result, whatever
     *     {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} returns.
     * @throws RuntimeException Wrapping any exception raised while creating, running or closing the solver.
     */
    @Override
    protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> extractor) {
        // Append a constant 1.0 feature to every row so the intercept is solved for as an ordinary weight.
        Preprocessor<K, V> patchedPreprocessor =
            new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, extractor);

        LSQRResult res;
        try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<K, V>(datasetBuilder, envBuilder,
            new SimpleLabeledDatasetDataBuilder<>(patchedPreprocessor), learningEnvironment())) {
            // NOTE(review): argument meanings are defined by LSQROnHeap.solve — presumably damp, atol, btol,
            // conlim, iterLim (-1 = solver default), calcVar and x0; confirm against that class.
            res = lsqr.solve(0.0, 1.0E-12, 1.0E-12, 1.0E8, -1.0, false, initialGuess(mdl));
        }
        catch (Exception e) {
            // Presumably required because the solver's close() is declared to throw a checked exception.
            throw new RuntimeException(e);
        }

        if (res == null) {
            // Handled outside the try block so that whatever this helper throws propagates as-is instead of
            // being wrapped into a generic RuntimeException by the catch above.
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        double[] x = res.getX();
        // The last component is the coefficient of the appended bias column, i.e. the intercept.
        DenseVector weights = new DenseVector(Arrays.copyOfRange(x, 0, x.length - 1));
        return new LinearRegressionModel(weights, x[x.length - 1]);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Always {@code true}: any previous model can seed LSQR through its weights and intercept.
     */
    @Override
    public boolean isUpdateable(LinearRegressionModel mdl) {
        return true;
    }
}

