/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization.updatecalculators;

import java.io.Serializable;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;

/**
 * Serializable holder for the four vectors that make up one parameter update:
 * the previous iteration's updates, the previous iteration's gradient, the
 * current deltas and an updates mask.
 *
 * <p>Going by the class and package names this is the state used by the RProp
 * update calculator; the calculator itself is not part of this file.
 *
 * <p>The class also exposes three aggregation strategies ({@link #SUM},
 * {@link #AVG}, {@link #SUM_LOCAL}) that combine a list of updates into one.
 * All three skip {@code null} list elements and return {@code null} when the
 * list contains no non-null element.
 */
public class RPropParameterUpdate implements Serializable {
    /** Aggregates updates by summing masked updates, deltas and gradients of all non-null elements. */
    public static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate> SUM = RPropParameterUpdate::sum;

    /** Aggregates like {@link #SUM}, then divides updates, deltas and gradient by the number of non-null elements. */
    public static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate> AVG = RPropParameterUpdate::avg;

    /** Sums masked updates, but takes deltas and gradient from the last non-null element. */
    public static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate> SUM_LOCAL = RPropParameterUpdate::sumLocal;

    /** Serialization version id. */
    private static final long serialVersionUID = -165584242642323332L;

    /** Updates made on the previous iteration. */
    protected Vector prevIterationUpdates;

    /** Gradient from the previous iteration. */
    protected Vector prevIterationGradient;

    /** Per-parameter deltas. */
    protected Vector deltas;

    /** Mask multiplied element-wise with {@link #prevIterationUpdates} when updates are aggregated. */
    protected Vector updatesMask;

    /**
     * Creates an update whose four vectors are new {@link DenseVector}s of the given size.
     * The deltas vector is additionally initialised with {@code assign(initUpdate)}.
     *
     * @param paramsCnt Size of every vector (number of parameters).
     * @param initUpdate Value passed to {@code assign} for the deltas vector.
     */
    RPropParameterUpdate(int paramsCnt, double initUpdate) {
        this.prevIterationUpdates = new DenseVector(paramsCnt);
        this.prevIterationGradient = new DenseVector(paramsCnt);
        this.deltas = new DenseVector(paramsCnt).assign(initUpdate);
        this.updatesMask = new DenseVector(paramsCnt);
    }

    /**
     * Creates an update from the given vectors. The vectors are stored by reference, not copied.
     *
     * @param prevIterationUpdates Updates made on the previous iteration.
     * @param prevIterationGradient Gradient from the previous iteration.
     * @param deltas Per-parameter deltas.
     * @param updatesMask Mask applied to the updates during aggregation.
     */
    public RPropParameterUpdate(Vector prevIterationUpdates, Vector prevIterationGradient, Vector deltas, Vector updatesMask) {
        this.prevIterationUpdates = prevIterationUpdates;
        this.prevIterationGradient = prevIterationGradient;
        this.deltas = deltas;
        this.updatesMask = updatesMask;
    }

    /**
     * Returns the deltas vector.
     *
     * @return Deltas (the stored instance, not a copy).
     */
    Vector deltas() {
        return this.deltas;
    }

    /**
     * Returns the updates made on the previous iteration.
     *
     * @return Previous iteration updates (the stored instance, not a copy).
     */
    Vector prevIterationUpdates() {
        return this.prevIterationUpdates;
    }

    /**
     * Replaces the previous iteration updates.
     *
     * @param updates New previous-iteration updates.
     * @return This object, for chaining.
     */
    private RPropParameterUpdate setPrevIterationUpdates(Vector updates) {
        this.prevIterationUpdates = updates;
        return this;
    }

    /**
     * Returns the gradient from the previous iteration.
     *
     * @return Previous iteration gradient (the stored instance, not a copy).
     */
    Vector prevIterationGradient() {
        return this.prevIterationGradient;
    }

    /**
     * Replaces the previous iteration gradient.
     *
     * @param gradient New previous-iteration gradient.
     * @return This object, for chaining.
     */
    private RPropParameterUpdate setPrevIterationGradient(Vector gradient) {
        this.prevIterationGradient = gradient;
        return this;
    }

    /**
     * Returns the updates mask.
     *
     * @return Updates mask (the stored instance, not a copy).
     */
    public Vector updatesMask() {
        return this.updatesMask;
    }

    /**
     * Replaces the updates mask.
     *
     * @param updatesMask New updates mask.
     * @return This object, for chaining.
     */
    public RPropParameterUpdate setUpdatesMask(Vector updatesMask) {
        this.updatesMask = updatesMask;
        return this;
    }

    /**
     * Replaces the deltas.
     *
     * @param deltas New deltas.
     * @return This object, for chaining.
     */
    public RPropParameterUpdate setDeltas(Vector deltas) {
        this.deltas = deltas;
        return this;
    }

    /**
     * Combines updates by summing the masked updates of all non-null elements while taking
     * deltas and gradient from the last non-null element. The returned object shares the
     * deltas and gradient instances of that last element (they are not copied) and carries
     * a mask created with {@code assign(1.0)}.
     *
     * @param updates Updates to combine; {@code null} elements are skipped.
     * @return Combined update, or {@code null} if there is no non-null element.
     */
    private static RPropParameterUpdate sumLocal(List<RPropParameterUpdate> updates) {
        List<RPropParameterUpdate> nonNullUpdates = nonNull(updates);
        if (nonNullUpdates.isEmpty()) {
            return null;
        }

        RPropParameterUpdate last = nonNullUpdates.get(nonNullUpdates.size() - 1);
        Vector newDeltas = last.deltas();
        Vector newGradient = last.prevIterationGradient();
        Vector totalUpdate = nonNullUpdates.stream()
            .map(RPropParameterUpdate::maskedUpdate)
            .reduce(Vector::plus)
            .orElse(null);

        return new RPropParameterUpdate(totalUpdate, newGradient, newDeltas, new DenseVector(newDeltas.size()).assign(1.0));
    }

    /**
     * Combines updates by summing masked updates, deltas and gradients of all non-null
     * elements. The result carries a mask created with {@code assign(1.0)}.
     *
     * @param updates Updates to combine; {@code null} elements are skipped.
     * @return Combined update, or {@code null} if there is no non-null element.
     */
    private static RPropParameterUpdate sum(List<RPropParameterUpdate> updates) {
        return sumNonNull(nonNull(updates));
    }

    /**
     * Combines updates like {@link #sum(List)} and then divides the summed gradient, updates
     * and deltas by the number of non-null elements. The mask produced by the summation is
     * left as is.
     *
     * @param updates Updates to combine; {@code null} elements are skipped.
     * @return Averaged update, or {@code null} if there is no non-null element.
     */
    private static RPropParameterUpdate avg(List<RPropParameterUpdate> updates) {
        // Filter once and reuse the list for both the sum and the divisor.
        List<RPropParameterUpdate> nonNullUpdates = nonNull(updates);
        RPropParameterUpdate total = sumNonNull(nonNullUpdates);
        if (total == null) {
            return null;
        }

        int cnt = nonNullUpdates.size();
        return total
            .setPrevIterationGradient(total.prevIterationGradient().divide(cnt))
            .setPrevIterationUpdates(total.prevIterationUpdates().divide(cnt))
            .setDeltas(total.deltas().divide(cnt));
    }

    /**
     * Sums a list that is already free of {@code null} elements.
     *
     * @param nonNullUpdates Updates to sum; must not contain {@code null}.
     * @return Summed update, or {@code null} if the list is empty.
     */
    private static RPropParameterUpdate sumNonNull(List<RPropParameterUpdate> nonNullUpdates) {
        if (nonNullUpdates.isEmpty()) {
            return null;
        }

        Vector totalUpdate = nonNullUpdates.stream()
            .map(RPropParameterUpdate::maskedUpdate)
            .reduce(Vector::plus)
            .orElse(null);
        Vector totalDelta = nonNullUpdates.stream()
            .map(RPropParameterUpdate::deltas)
            .reduce(Vector::plus)
            .orElse(null);
        Vector totalGradient = nonNullUpdates.stream()
            .map(RPropParameterUpdate::prevIterationGradient)
            .reduce(Vector::plus)
            .orElse(null);

        // The list is non-empty here, so the reductions above cannot be empty; the explicit
        // check only documents that for readers and static analysis.
        int paramsCnt = Objects.requireNonNull(totalDelta).size();
        return new RPropParameterUpdate(totalUpdate, totalGradient, totalDelta, new DenseVector(paramsCnt).assign(1.0));
    }

    /**
     * Returns the non-null elements of the given list, in their original order.
     *
     * @param updates List that may contain {@code null} elements.
     * @return New list holding only the non-null elements.
     */
    private static List<RPropParameterUpdate> nonNull(List<RPropParameterUpdate> updates) {
        return updates.stream().filter(Objects::nonNull).collect(Collectors.toList());
    }

    /**
     * Multiplies an update's previous-iteration updates element-wise by its mask.
     *
     * @param update Update to mask.
     * @return Result of {@code VectorUtils.elementWiseTimes} on a copy of the mask and the updates.
     */
    private static Vector maskedUpdate(RPropParameterUpdate update) {
        // The mask is copied first; presumably elementWiseTimes writes into its first
        // argument, so the copy keeps the update's own mask intact - confirm in VectorUtils.
        return VectorUtils.elementWiseTimes(update.updatesMask().copy(), update.prevIterationUpdates());
    }
}

