/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.math.isolve.lsqr;

import com.github.fommil.netlib.BLAS;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR;
import org.apache.ignite.ml.math.isolve.lsqr.LSQRPartitionContext;

/**
 * LSQR solver backend whose data lives in a partitioned on-heap {@link Dataset}.
 * <p>
 * Every partition holds a slice of the feature matrix {@code A} (a flat array handed to BLAS as a
 * {@code rows x cols} matrix with leading dimension {@code rows}) together with the matching labels.
 * The LSQR work vector {@code u} is kept per partition in {@link LSQRPartitionContext}; quantities
 * needed globally (norms and {@code A^T u}) are computed per partition and combined by reducers.
 * <p>
 * The map/reduce lambdas only capture local values and call static helpers, so they never capture
 * {@code this}. NOTE(review): this presumably matters because the dataset may serialize them to
 * remote nodes - confirm against the {@code Dataset} implementation before adding instance calls.
 *
 * @param <K> Type of an upstream key.
 * @param <V> Type of an upstream value.
 */
public class LSQROnHeap<K, V>
extends AbstractLSQR
implements AutoCloseable {
    /** Dataset holding features, labels and the per-partition LSQR context. */
    private final Dataset<LSQRPartitionContext, SimpleLabeledDatasetData> dataset;

    /**
     * Builds the underlying dataset, creating a fresh {@link LSQRPartitionContext} for every partition.
     *
     * @param datasetBuilder Builder of the partitioned dataset.
     * @param partDataBuilder Builder of the partition data (features and labels).
     */
    public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
        PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partDataBuilder) {
        this.dataset = datasetBuilder.build((upstream, upstreamSize) -> new LSQRPartitionContext(), partDataBuilder);
    }

    /**
     * Initializes {@code u} with a copy of the labels in every partition and returns the Euclidean
     * norm of the whole label vector.
     *
     * @return Norm of the labels over all partitions, or {@code 0} if no partition holds labels.
     */
    @Override
    protected double bnorm() {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            double[] labels = data.getLabels();
            // Partitions without labels contribute nothing (null is the reducer's "no value").
            if (labels == null)
                return null;
            // u is later updated in place by beta()/iter(), so it must not alias the labels array.
            ctx.setU(Arrays.copyOf(labels, labels.length));
            return BLAS.getInstance().dnrm2(labels.length, labels, 1);
        }, (a, b) -> combineNorms(a, b));
        return norm == null ? 0.0 : norm;
    }

    /**
     * Updates {@code u := alfa * A * x + beta * u} in every partition and returns the Euclidean norm
     * of the updated {@code u} over all partitions.
     *
     * @param x Vector multiplied by the feature matrix.
     * @param alfa Scale applied to {@code A * x}.
     * @param beta Scale applied to the current {@code u}.
     * @return Norm of the updated {@code u}, or {@code 0} if no partition holds features.
     */
    @Override
    protected double beta(double[] x, double alfa, double beta) {
        Double norm = dataset.computeWithCtx((ctx, data) -> {
            if (!hasFeatures(data))
                return null;
            int rows = data.getRows();
            double[] u = ctx.getU();
            BLAS.getInstance().dgemv("N", rows, columns(data), alfa, data.getFeatures(), Math.max(1, rows), x, 1, beta, u, 1);
            return BLAS.getInstance().dnrm2(u.length, u, 1);
        }, (a, b) -> combineNorms(a, b));
        return norm == null ? 0.0 : norm;
    }

    /**
     * Scales {@code u} by {@code 1 / bnorm} in every partition, computes {@code A^T u} summed over all
     * partitions and adds it to {@code target}.
     *
     * @param bnorm Divisor applied to {@code u} before the product.
     * @param target Vector the summed {@code A^T u} is added to (modified in place).
     * @return {@code target}; unchanged if no partition holds features.
     */
    @Override
    protected double[] iter(double bnorm, double[] target) {
        double[] res = dataset.computeWithCtx((ctx, data) -> {
            if (!hasFeatures(data))
                return null;
            int rows = data.getRows();
            int cols = columns(data);
            double[] u = ctx.getU();
            BLAS.getInstance().dscal(u.length, 1.0 / bnorm, u, 1);
            double[] v = new double[cols];
            BLAS.getInstance().dgemv("T", rows, cols, 1.0, data.getFeatures(), Math.max(1, rows), u, 1, 0.0, v, 1);
            return v;
        }, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            // b := b + a, reusing b's array as the running sum.
            BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);
            return b;
        });
        // All partitions empty: nothing to add, avoid dereferencing a null result.
        if (res != null)
            BLAS.getInstance().daxpy(res.length, 1.0, res, 1, target, 1);
        return target;
    }

    /**
     * Returns the number of feature columns, taken from the partitions that hold features.
     *
     * @return Column count, or {@code 0} if no partition holds features.
     */
    @Override
    protected Integer getColumns() {
        Integer cols = dataset.compute(data -> hasFeatures(data) ? columns(data) : null, (a, b) -> {
            if (a == null)
                return b == null ? 0 : b;
            if (b == null)
                return a;
            return b;
        });
        return cols == null ? 0 : cols;
    }

    /** Closes the underlying dataset. */
    @Override
    public void close() throws Exception {
        dataset.close();
    }

    /**
     * Combines two partial Euclidean norms into the norm of the concatenated vector.
     * <p>
     * Written with {@code if} statements on purpose: the ternary form
     * {@code a == null ? b : b == null ? a : Math.sqrt(...)} is a numeric conditional of type
     * {@code double}, which unboxes {@code b} and throws NPE when both operands are {@code null}.
     *
     * @param a Partial norm, or {@code null} if absent.
     * @param b Partial norm, or {@code null} if absent.
     * @return Combined norm, or {@code null} if both are absent.
     */
    private static Double combineNorms(Double a, Double b) {
        if (a == null)
            return b;
        if (b == null)
            return a;
        return Math.sqrt(a * a + b * b);
    }

    /** Returns whether the partition holds a non-empty feature matrix. */
    private static boolean hasFeatures(SimpleLabeledDatasetData data) {
        return data.getFeatures() != null && data.getRows() > 0;
    }

    /** Returns the number of columns stored in the partition's flat feature array. */
    private static int columns(SimpleLabeledDatasetData data) {
        return data.getFeatures().length / data.getRows();
    }
}

