/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization.updatecalculators;

import java.io.Serializable;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;

/**
 * Parameter update for simple gradient descent: a thin, serializable wrapper around a gradient {@link Vector}.
 */
public class SimpleGDParameterUpdate
implements Serializable {
    private static final long serialVersionUID = -8732955283436005621L;

    /** Gradient carried by this update; assigned once in the constructor. */
    private final Vector gradient;

    /**
     * Creates an update whose gradient is a new {@link DenseVector} of the given size.
     * NOTE(review): initial element values depend on {@code DenseVector(int)} (presumably zeros) — confirm.
     *
     * @param paramsCnt Number of parameters (size of the gradient vector).
     */
    public SimpleGDParameterUpdate(int paramsCnt) {
        this.gradient = new DenseVector(paramsCnt);
    }

    /**
     * Creates an update wrapping the given gradient (stored by reference, not copied).
     *
     * @param gradient Gradient vector.
     */
    public SimpleGDParameterUpdate(Vector gradient) {
        this.gradient = gradient;
    }

    /**
     * @return Gradient carried by this update.
     */
    public Vector gradient() {
        return this.gradient;
    }

    /**
     * Sums the gradients of all non-null updates in the list using {@link Vector#plus}.
     *
     * @param updates Updates to sum; {@code null} elements are skipped.
     * @return Update holding the summed gradient, or {@code null} if the list contains no non-null updates.
     * @throws NullPointerException If {@code updates} itself is {@code null}.
     */
    public static SimpleGDParameterUpdate sumLocal(List<SimpleGDParameterUpdate> updates) {
        Objects.requireNonNull(updates, "updates");

        return updates.stream()
            .filter(Objects::nonNull)
            .map(SimpleGDParameterUpdate::gradient)
            .reduce(Vector::plus)
            .map(SimpleGDParameterUpdate::new)
            .orElse(null);
    }

    /**
     * Averages the gradients of all non-null updates in the list.
     *
     * @param updates Updates to average; {@code null} elements are skipped and not counted.
     * @return Update holding the averaged gradient, or {@code null} if the list contains no non-null updates.
     * @throws NullPointerException If {@code updates} itself is {@code null}.
     */
    public static SimpleGDParameterUpdate avg(List<SimpleGDParameterUpdate> updates) {
        SimpleGDParameterUpdate sum = sumLocal(updates);

        if (sum == null) {
            return null;
        }

        // sum != null implies at least one non-null update, so the divisor is >= 1.
        // count() avoids materializing a throwaway list just to read its size.
        long nonNullCnt = updates.stream().filter(Objects::nonNull).count();

        return new SimpleGDParameterUpdate(sum.gradient().divide((double) nonNullCnt));
    }
}

