/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization.updatecalculators;

import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;

/**
 * {@link ParameterUpdateCalculator} implementing plain gradient descent.
 *
 * <p>Each step computes the gradient of the loss with respect to the model parameters and moves the
 * parameters against it: {@code params := params - learningRate * gradient}. No state is carried
 * between steps; the previous update and the iteration number passed to
 * {@link #calculateNewUpdate} are ignored.
 */
public class SimpleGDUpdateCalculator
implements ParameterUpdateCalculator<SmoothParametrized, SimpleGDParameterUpdate> {
    private static final long serialVersionUID = -4237332083320879334L;

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.1;

    /** Step size the gradient is multiplied by in {@link #update}. */
    private final double learningRate;

    /**
     * Loss function captured by the most recent {@link #init} call ({@code null} before that).
     * It is handed to {@code differentiateByParameters} when gradients are computed.
     */
    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Creates a calculator that uses the default learning rate ({@code 0.1}). */
    public SimpleGDUpdateCalculator() {
        this(DEFAULT_LEARNING_RATE);
    }

    /**
     * Creates a calculator with the given learning rate.
     *
     * @param learningRate Step size applied to the gradient; used as-is, no validation is done.
     */
    public SimpleGDUpdateCalculator(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Remembers {@code loss} for later gradient computations and creates the initial update.
     *
     * @param mdl Model whose {@code parametersCount()} sizes the initial update.
     * @param loss Loss function used by subsequent {@link #calculateNewUpdate} calls.
     * @return A new {@link SimpleGDParameterUpdate} built from {@code mdl.parametersCount()}.
     */
    @Override
    public SimpleGDParameterUpdate init(SmoothParametrized mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;
        return new SimpleGDParameterUpdate(mdl.parametersCount());
    }

    /**
     * Computes the gradient of the stored loss on the given batch and wraps it in a new update.
     *
     * <p>Relies on {@link #loss} having been set by {@link #init}; it is passed through as-is.
     *
     * @param mlp Model to differentiate.
     * @param updaterParameters Previous update; unused because plain gradient descent is stateless.
     * @param iteration Current iteration number; unused.
     * @param inputs Batch inputs, forwarded to {@code differentiateByParameters}.
     * @param groundTruth Expected outputs for {@code inputs}, forwarded to
     *     {@code differentiateByParameters}.
     * @return Update holding the gradient returned by {@code mlp.differentiateByParameters}.
     */
    @Override
    public SimpleGDParameterUpdate calculateNewUpdate(SmoothParametrized mlp, SimpleGDParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) {
        return new SimpleGDParameterUpdate(mlp.differentiateByParameters(loss, inputs, groundTruth));
    }

    /**
     * Applies one gradient-descent step: sets {@code parameters - learningRate * gradient} on the
     * model via {@code setParameters} and returns its result.
     *
     * <p>NOTE(review): whether {@code setParameters} mutates {@code obj} in place or returns a new
     * model is up to the model implementation; confirm before relying on either.
     *
     * @param obj Model to update.
     * @param update Update whose gradient is applied.
     * @param <M1> Concrete model type.
     * @return The value returned by {@code obj.setParameters(...)}, cast to {@code M1}.
     */
    @Override
    @SuppressWarnings("unchecked") // setParameters is reached through the raw SmoothParametrized type.
    public <M1 extends SmoothParametrized> M1 update(M1 obj, SimpleGDParameterUpdate update) {
        Vector params = obj.parameters();
        Vector step = update.gradient().times(learningRate);
        return (M1)obj.setParameters(params.minus(step));
    }

    /**
     * Returns a new calculator with a different learning rate; this instance is left unchanged.
     *
     * <p>The loss captured by a prior {@link #init} call is carried over. The copy is therefore
     * exactly "this calculator with another learning rate", not a half-initialized instance.
     *
     * @param learningRate Learning rate for the new calculator.
     * @return New calculator using {@code learningRate} and this instance's {@link #loss}.
     */
    public SimpleGDUpdateCalculator withLearningRate(double learningRate) {
        SimpleGDUpdateCalculator copy = new SimpleGDUpdateCalculator(learningRate);
        copy.loss = this.loss;
        return copy;
    }
}

