/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization;

import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.optimization.Updater;

/**
 * {@link Updater} that performs a gradient step whose step size is derived from the
 * two most recent weight vectors and gradients (Barzilai-Borwein style):
 *
 * <pre>
 *   learningRate = (w - wOld) . (g - gOld) / ||g - gOld||^2
 *   result       = w - learningRate * g
 * </pre>
 *
 * When no previous weights or gradient are available, or when the step size would be
 * undefined, {@link #INITIAL_LEARNING_RATE} is used instead.
 */
public class BarzilaiBorweinUpdater implements Updater {
    private static final long serialVersionUID = 5046575099408708472L;

    /** Step size used when the Barzilai-Borwein formula cannot be evaluated. */
    private static final double INITIAL_LEARNING_RATE = 1.0;

    /**
     * Computes the updated weights {@code weights - learningRate * gradient}.
     *
     * @param oldWeights  weights from the previous step; may be {@code null}
     * @param oldGradient gradient from the previous step; may be {@code null}
     * @param weights     current weights; must not be {@code null}
     * @param gradient    current gradient; must not be {@code null}
     * @param iteration   iteration number; not used by this updater
     * @return the new weights vector
     */
    @Override
    public Vector compute(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient, int iteration) {
        // Copies are passed so the arithmetic below can never alter the caller's vectors.
        // NOTE(review): these copies may be redundant if Vector operations are non-mutating - confirm.
        double learningRate = this.computeLearningRate(
            oldWeights != null ? oldWeights.copy() : null,
            oldGradient != null ? oldGradient.copy() : null,
            weights.copy(),
            gradient.copy());

        return weights.copy().minus(gradient.copy().times(learningRate));
    }

    /**
     * Evaluates the step size from the change in weights and the change in gradient.
     *
     * @param oldWeights  previous weights, or {@code null} if there are none
     * @param oldGradient previous gradient, or {@code null} if there is none
     * @param weights     current weights
     * @param gradient    current gradient
     * @return the computed step size, or {@link #INITIAL_LEARNING_RATE} when there is no
     *         history or the gradient did not change between the two steps
     */
    private double computeLearningRate(Vector oldWeights, Vector oldGradient, Vector weights, Vector gradient) {
        if (oldWeights == null || oldGradient == null) {
            return INITIAL_LEARNING_RATE;
        }

        Vector gradientDiff = gradient.minus(oldGradient);
        double squaredNorm = Math.pow(gradientDiff.kNorm(2.0), 2.0);

        // A zero (or NaN) denominator would turn the step size, and then every weight,
        // into NaN/Infinity; fall back to the default step instead.
        if (!(squaredNorm > 0.0)) {
            return INITIAL_LEARNING_RATE;
        }

        return weights.minus(oldWeights).dot(gradientDiff) / squaredNorm;
    }
}

