/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization.updatecalculators;

import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.util.MatrixUtil;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;

/**
 * Parameter update calculator implementing the resilient backpropagation (RPROP) update rule.
 *
 * <p>An individual step size ("delta") is kept for every model parameter. On each iteration the sign of the
 * product of the previous and the current gradient component decides what happens to that parameter:
 * <ul>
 *     <li>positive: the delta is multiplied by {@code accelerationRate} (capped at {@link #UPDATE_MAX}) and the
 *     parameter is moved against the sign of the current gradient component;</li>
 *     <li>negative: the delta is multiplied by {@code deaccelerationRate} (floored at {@link #UPDATE_MIN}), the
 *     previous update of the parameter is reverted (mask value {@code -1}) and the stored gradient component is
 *     zeroed;</li>
 *     <li>zero: the delta is kept and the parameter is moved against the sign of the current gradient
 *     component.</li>
 * </ul>
 * On the first iteration there is no previous gradient, so every component is treated as having a positive sign.
 */
public class RPropUpdateCalculator
implements ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> {
    private static final long serialVersionUID = -5156816330041409864L;

    /** Default initial step size for every parameter. */
    private static final double DFLT_INIT_UPDATE = 0.1;

    /** Default factor applied to a step size when the gradient keeps its sign. */
    private static final double DFLT_ACCELERATION_RATE = 1.2;

    /** Default factor applied to a step size when the gradient changes its sign. */
    private static final double DFLT_DEACCELERATION_RATE = 0.5;

    /** Initial step size for every parameter. */
    private final double initUpdate;

    /** Factor applied to a step size when the gradient keeps its sign. */
    private final double accelerationRate;

    /** Factor applied to a step size when the gradient changes its sign. */
    private final double deaccelerationRate;

    /** Upper bound for a step size. */
    private static final double UPDATE_MAX = 50.0;

    /** Lower bound for a step size. */
    private static final double UPDATE_MIN = 1.0E-6;

    /** Loss function provider; assigned in {@link #init}. */
    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /**
     * Creates a calculator with explicit hyper-parameters.
     *
     * @param initUpdate Initial step size for every parameter.
     * @param accelerationRate Factor applied to a step size when the gradient keeps its sign.
     * @param deaccelerationRate Factor applied to a step size when the gradient changes its sign.
     */
    public RPropUpdateCalculator(double initUpdate, double accelerationRate, double deaccelerationRate) {
        this.initUpdate = initUpdate;
        this.accelerationRate = accelerationRate;
        this.deaccelerationRate = deaccelerationRate;
    }

    /**
     * Creates a calculator with the default hyper-parameters (initial update {@code 0.1}, acceleration rate
     * {@code 1.2}, deacceleration rate {@code 0.5}).
     */
    public RPropUpdateCalculator() {
        this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE);
    }

    /**
     * Computes the next RPROP update for the given model.
     *
     * @param mdl Model whose loss gradient with respect to its parameters is taken.
     * @param updaterParams Updater state produced by the previous iteration (or by {@link #init}).
     * @param iteration Iteration number (not used by this calculator).
     * @param inputs Input samples.
     * @param groundTruth Expected outputs for {@code inputs}.
     * @return New updater state holding the parameter updates, the (partially zeroed) gradient, the adapted step
     * sizes and the update mask.
     */
    @Override
    public RPropParameterUpdate calculateNewUpdate(SmoothParametrized mdl, RPropParameterUpdate updaterParams, int iteration, Matrix inputs, Matrix groundTruth) {
        Vector gradient = mdl.differentiateByParameters(this.loss, inputs, groundTruth);
        Vector prevGradient = updaterParams.prevIterationGradient();

        // Per-component sign of g_{t-1} * g_t. Without a previous gradient every component counts as "same sign".
        Vector derSigns = prevGradient != null
            ? VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y))
            : gradient.like(gradient.size()).assign(1.0);

        // Adapt step sizes: grow while the sign holds, shrink after a sign change, clamp to [UPDATE_MIN, UPDATE_MAX].
        Vector newDeltas = updaterParams.deltas().copy().map(derSigns, (prevDelta, sign) -> {
            if (sign > 0.0) {
                return Math.min(prevDelta * this.accelerationRate, UPDATE_MAX);
            }
            if (sign < 0.0) {
                return Math.max(prevDelta * this.deaccelerationRate, UPDATE_MIN);
            }
            return prevDelta;
        });

        // Step against the gradient sign where the sign did not flip; otherwise keep the previous update so that
        // the mask below (-1) reverts it.
        // NOTE(review): the step uses the previous iteration's deltas, not newDeltas; textbook RPROP uses the
        // freshly adapted step size. Kept as-is to preserve existing training behavior — confirm intent.
        Vector newPrevIterationUpdates = MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> {
            if (derSigns.getX((int)i) >= 0.0) {
                return -Math.signum(der) * delta;
            }
            return updaterParams.prevIterationUpdates().getX((int)i);
        });

        // +1 applies the update, -1 reverts the previous one (sign flip). NaN signs also map to -1, as before.
        Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(),
            (sign, upd, i) -> sign >= 0.0 ? 1.0 : -1.0);

        // After a sign flip the stored gradient component is zeroed, so the next iteration sees sign 0 there.
        // Done explicitly here instead of as a side effect inside the mask lambda.
        for (int i = 0; i < derSigns.size(); i++) {
            if (derSigns.getX(i) < 0.0) {
                gradient.setX(i, 0.0);
            }
        }

        return new RPropParameterUpdate(newPrevIterationUpdates, gradient.copy(), newDeltas, updatesMask);
    }

    /**
     * Stores the loss function and creates the initial updater state.
     *
     * @param mdl Model to be trained; only its parameter count is read.
     * @param loss Loss function provider used by subsequent {@link #calculateNewUpdate} calls.
     * @return Initial updater state with every step size set to the configured initial update.
     */
    @Override
    public RPropParameterUpdate init(SmoothParametrized mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;
        return new RPropParameterUpdate(mdl.parametersCount(), this.initUpdate);
    }

    /**
     * Applies an update to the model: parameters += updatesMask * prevIterationUpdates (component-wise).
     *
     * @param obj Model to update.
     * @param update Updater state produced by {@link #calculateNewUpdate}.
     * @param <M1> Concrete model type.
     * @return The model returned by {@code setParameters}, cast to {@code M1}.
     */
    @Override
    @SuppressWarnings("unchecked") // setParameters is expected to return the same model instance/type.
    public <M1 extends SmoothParametrized> M1 update(M1 obj, RPropParameterUpdate update) {
        // Copy the mask so the element-wise product does not alter the update's own mask vector.
        Vector updatesToAdd = VectorUtils.elementWiseTimes(update.updatesMask().copy(), update.prevIterationUpdates());
        return (M1)((SmoothParametrized)obj.setParameters(obj.parameters().plus(updatesToAdd)));
    }
}

