/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.commonWalkingControlModules.capturePoint.lqrControl;

import java.util.List;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.AlgebraicS1Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.AlgebraicS2Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.AlgebraicS2Segment;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.LQRCommonValues;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.FramePoint3DReadOnly;
import us.ihmc.euclid.referenceFrame.interfaces.FrameTuple3DReadOnly;
import us.ihmc.euclid.referenceFrame.interfaces.FrameVector3DReadOnly;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DReadOnly;
import us.ihmc.robotics.math.trajectories.core.Polynomial3D;
import us.ihmc.robotics.math.trajectories.interfaces.Polynomial3DReadOnly;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFramePoint3D;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFrameVector3D;
import us.ihmc.yoVariables.providers.DoubleProvider;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoDouble;

/**
 * LQR-based controller that computes a 3D control input {@code u} for a 6x1 state made of the
 * center-of-mass position (rows 0-2) and velocity (rows 3-5), tracking a piecewise-polynomial VRP
 * reference trajectory.
 *
 * <p>All internal computations are expressed relative to the final VRP position of the reference
 * trajectory: the reference segments and the measured state are shifted by that point, and the
 * resulting VRP positions are shifted back before being exposed.
 *
 * <p>The control law is {@code u = K1 * x + k2}, where {@code x} is the relative state,
 * {@code K1 = -R1^-1 * Nb} is recomputed (together with the S1 cost Hessian) whenever the weights or
 * omega change, and {@code k2 = R1^-1 * DQ * r + (-0.5 * R1^-1 * B^T) * s2(t)} depends on the relative
 * reference VRP {@code r} and the time-varying term {@code s2(t)}.
 */
public class LQRMomentumController {
    /** Upper bound on the time at which the last segment is evaluated to extract the final VRP. */
    private static final double sufficientlyLarge = 1000.0;
    private final YoRegistry registry = new YoRegistry(this.getClass().getSimpleName());
    private final YoFrameVector3D yoK2 = new YoFrameVector3D("k2", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFrameVector3D feedbackForce = new YoFrameVector3D("feedbackForce", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFramePoint3D relativeCoMPosition = new YoFramePoint3D("relativeCoMPosition", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFrameVector3D relativeCoMVelocity = new YoFrameVector3D("relativeCoMVelocity", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFramePoint3D finalVRPPosition = new YoFramePoint3D("finalVRPPosition", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFramePoint3D referenceVRPPosition = new YoFramePoint3D("referenceVRPPosition", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoFramePoint3D feedbackVRPPosition = new YoFramePoint3D("feedbackVRPPosition", ReferenceFrame.getWorldFrame(), this.registry);
    private final YoDouble omega = new YoDouble("omega", this.registry);
    static final double defaultVrpTrackingWeight = 100.0;
    static final double defaultMomentumRateWeight = 1.0E-4;
    private double vrpTrackingWeight = defaultVrpTrackingWeight;
    private double momentumRateWeight = defaultMomentumRateWeight;
    private final AlgebraicS1Function s1Function = new AlgebraicS1Function();
    private final AlgebraicS2Function s2Function = new AlgebraicS2Function();
    private final LQRCommonValues lqrCommonValues = new LQRCommonValues();
    private final DMatrixRMaj S1 = new DMatrixRMaj(6, 6);
    private final DMatrixRMaj s2 = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj K1 = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj k2 = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj u = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj R1InverseDQ = new DMatrixRMaj(3, 3);
    private final DMatrixRMaj R1InverseBTranspose = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj finalVRPState = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj relativeState = new DMatrixRMaj(6, 1);
    /** Reference VRP at the current time, relative to the final VRP. */
    private final DMatrixRMaj relativeDesiredVRP = new DMatrixRMaj(3, 1);
    /** Feedback VRP ({@code C * x + D * u}), relative to the final VRP. */
    private final DMatrixRMaj relativeFeedbackVRP = new DMatrixRMaj(3, 1);
    private final RecyclingArrayList<Polynomial3D> relativeVRPTrajectories = new RecyclingArrayList<>(() -> new Polynomial3D(4));
    private boolean shouldUpdateS1 = true;
    private final DMatrixRMaj zeroVector = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj currentState = new DMatrixRMaj(6, 1);

    /**
     * Creates a controller that is not attached to any parent registry.
     *
     * @param omega provider whose value is read once, at construction time.
     */
    public LQRMomentumController(DoubleProvider omega) {
        this(omega, null);
    }

    /**
     * Creates a controller using the current value of {@code omega}.
     *
     * <p>Only the provider's value at construction time is used; later changes of the provider are not
     * tracked. Use {@link #setOmega(double)} to update it.
     *
     * @param omega provider whose value is read once, at construction time.
     * @param parentRegistry registry to attach this controller's variables to, may be {@code null}.
     */
    public LQRMomentumController(DoubleProvider omega, YoRegistry parentRegistry) {
        this(omega.getValue(), parentRegistry);
    }

    /**
     * Creates a controller and computes the dynamics matrices and the S1 solution for the default weights.
     *
     * @param omega value used to compute the dynamics matrices.
     * @param parentRegistry registry to attach this controller's variables to, may be {@code null}.
     */
    public LQRMomentumController(double omega, YoRegistry parentRegistry) {
        this.omega.set(omega);
        this.computeDynamicsMatrix(this.omega.getDoubleValue());
        // Keep the dynamics matrices in sync with later changes of omega (e.g. through setOmega).
        this.omega.addListener(v -> this.computeDynamicsMatrix(this.omega.getDoubleValue()));
        this.computeS1();
        if (parentRegistry != null) {
            parentRegistry.addChild(this.registry);
        }
    }

    /** Sets the VRP tracking weight; S1 is recomputed lazily on the next control update. */
    public void setVRPTrackingWeight(double vrpTrackingWeight) {
        this.vrpTrackingWeight = vrpTrackingWeight;
        this.shouldUpdateS1 = true;
    }

    /** Sets the momentum rate weight; S1 is recomputed lazily on the next control update. */
    public void setMomentumRateWeight(double momentumRateWeight) {
        this.momentumRateWeight = momentumRateWeight;
        this.shouldUpdateS1 = true;
    }

    /**
     * Recomputes the dynamics matrices for the given {@code omega} and marks S1 as stale.
     *
     * @param omega value passed to the dynamics-matrix computation.
     */
    public void computeDynamicsMatrix(double omega) {
        this.lqrCommonValues.computeDynamicsMatrix(omega);
        this.shouldUpdateS1 = true;
    }

    /**
     * Stores the VRP reference trajectory, shifted so that it is expressed relative to its final point.
     *
     * <p>The final VRP is the position of the last segment evaluated at
     * {@code min(sufficientlyLarge, endTime)}. Note that this evaluates (calls {@code compute} on) the
     * last segment of the given list.
     *
     * @param vrpTrajectory consecutive VRP segments; must contain at least one segment.
     * @throws IllegalArgumentException if {@code vrpTrajectory} is empty. The previously stored
     *     trajectory is left untouched in that case.
     */
    public void setVRPTrajectory(List<? extends Polynomial3DReadOnly> vrpTrajectory) {
        if (vrpTrajectory.isEmpty()) {
            throw new IllegalArgumentException("The VRP trajectory must contain at least one segment.");
        }
        this.relativeVRPTrajectories.clear();

        Polynomial3DReadOnly lastTrajectory = vrpTrajectory.get(vrpTrajectory.size() - 1);
        lastTrajectory.compute(Math.min(sufficientlyLarge, lastTrajectory.getTimeInterval().getEndTime()));
        this.finalVRPPosition.set((Tuple3DReadOnly) lastTrajectory.getPosition());
        this.finalVRPPosition.get(this.finalVRPState);

        double finalX = this.finalVRPState.get(0, 0);
        double finalY = this.finalVRPState.get(1, 0);
        double finalZ = this.finalVRPState.get(2, 0);
        for (Polynomial3DReadOnly trajectory : vrpTrajectory) {
            Polynomial3D relativeTrajectory = this.relativeVRPTrajectories.add();
            relativeTrajectory.set(trajectory);
            relativeTrajectory.shiftTrajectory(-finalX, -finalY, -finalZ);
        }
    }

    /**
     * Recomputes the S1 cost Hessian and the constant gain matrices derived from it:
     * {@code K1 = -R1^-1 * Nb}, {@code R1InverseDQ = R1^-1 * DQ} and
     * {@code R1InverseBTranspose = -0.5 * R1^-1 * B^T}.
     */
    void computeS1() {
        this.lqrCommonValues.computeEquivalentCostValues(this.momentumRateWeight, this.vrpTrackingWeight);
        this.s1Function.set(this.lqrCommonValues);
        this.s1Function.compute(0.0, this.S1);
        this.lqrCommonValues.computeS2ConstantStateMatrices(this.S1);
        CommonOps_DDRM.mult(-1.0, this.lqrCommonValues.getR1Inverse(), this.lqrCommonValues.getNb(), this.K1);
        CommonOps_DDRM.mult(this.lqrCommonValues.getR1Inverse(), this.lqrCommonValues.getDQ(), this.R1InverseDQ);
        CommonOps_DDRM.multTransB(-0.5, this.lqrCommonValues.getR1Inverse(), this.lqrCommonValues.getB(), this.R1InverseBTranspose);
        this.shouldUpdateS1 = false;
    }

    /** Re-initializes the s2 function from the stored relative VRP segments and the current LQR values. */
    void computeS2Parameters() {
        this.s2Function.set(this.zeroVector, this.relativeVRPTrajectories, this.lqrCommonValues);
    }

    /**
     * Evaluates the reference VRP, s2 and the feed-forward term k2 at {@code time}.
     *
     * <p>If {@code time} is past the end of the stored trajectory, the end of the last segment is used,
     * so the reference holds at the final VRP instead of failing.
     *
     * @param time time since the start of the stored trajectory.
     * @throws IllegalStateException if no VRP trajectory has been set.
     */
    void computeS2(double time) {
        this.checkTrajectoryIsSet();
        int segment = this.getSegmentNumber(time);
        Polynomial3D segmentTrajectory = this.relativeVRPTrajectories.get(segment);
        // For in-range times this is a no-op; past the end it clamps to the last segment's end point.
        double timeInSegment = Math.min(this.computeTimeInSegment(time, segment), segmentTrajectory.getDuration());

        segmentTrajectory.compute(timeInSegment);
        this.referenceVRPPosition.set((Tuple3DReadOnly) segmentTrajectory.getPosition());
        this.referenceVRPPosition.get(this.relativeDesiredVRP);
        this.referenceVRPPosition.add(this.finalVRPPosition);

        this.s2Function.compute(segment, timeInSegment, this.s2);
        // k2 = R1^-1 * DQ * r + (-0.5 * R1^-1 * B^T) * s2
        CommonOps_DDRM.mult(this.R1InverseDQ, this.relativeDesiredVRP, this.k2);
        CommonOps_DDRM.multAdd(this.R1InverseBTranspose, this.s2, this.k2);
        this.yoK2.set(this.k2);
    }

    /** Returns the s2 segment with the given index. */
    public AlgebraicS2Segment getS2Segment(int segmentNumber) {
        return this.s2Function.getSegment(segmentNumber);
    }

    /**
     * Computes the control input from the CoM position and velocity.
     *
     * @param currentCoMPosition CoM position, copied into rows 0-2 of the state.
     * @param currentCoMVelocity CoM velocity, copied into rows 3-5 of the state.
     * @param time time since the start of the stored VRP trajectory.
     */
    public void computeControlInput(FramePoint3DReadOnly currentCoMPosition, FrameVector3DReadOnly currentCoMVelocity, double time) {
        currentCoMPosition.get(this.currentState);
        currentCoMVelocity.get(3, this.currentState);
        this.computeControlInput(this.currentState, time);
    }

    /**
     * Computes the control input {@code u = K1 * x + k2} and the resulting feedback VRP
     * {@code C * x + D * u} (shifted back by the final VRP), where {@code x} is {@code currentState}
     * expressed relative to the final VRP.
     *
     * @param currentState 6x1 state: CoM position (rows 0-2) followed by CoM velocity (rows 3-5).
     * @param time time since the start of the stored VRP trajectory.
     * @throws IllegalStateException if no VRP trajectory has been set.
     */
    public void computeControlInput(DMatrixRMaj currentState, double time) {
        this.checkTrajectoryIsSet();
        if (this.shouldUpdateS1) {
            this.computeS1();
        }
        this.computeS2Parameters();
        this.computeS2(time);

        // Express the position part of the state relative to the final VRP; velocity is unchanged.
        this.relativeState.set(currentState);
        for (int i = 0; i < 3; ++i) {
            this.relativeState.add(i, 0, -this.finalVRPState.get(i));
        }
        this.relativeCoMPosition.set(this.relativeState);
        this.relativeCoMVelocity.set(3, this.relativeState);

        CommonOps_DDRM.mult(this.K1, this.relativeState, this.u);
        this.feedbackForce.set(this.u);
        CommonOps_DDRM.addEquals(this.u, this.k2);

        CommonOps_DDRM.mult(this.lqrCommonValues.getC(), this.relativeState, this.relativeFeedbackVRP);
        CommonOps_DDRM.multAdd(this.lqrCommonValues.getD(), this.u, this.relativeFeedbackVRP);
        this.feedbackVRPPosition.set(this.relativeFeedbackVRP);
        this.feedbackVRPPosition.add(this.finalVRPPosition);
    }

    /** Returns the feedback VRP computed by the last control update. */
    public FramePoint3DReadOnly getFeedbackVRPPosition() {
        return this.feedbackVRPPosition;
    }

    /** Returns the reference VRP evaluated by the last control update. */
    public FramePoint3DReadOnly getReferenceVRPPosition() {
        return this.referenceVRPPosition;
    }

    /** Returns the 3x1 control input computed by the last control update. */
    public DMatrixRMaj getU() {
        return this.u;
    }

    /** Returns the 6x6 S1 matrix. */
    public DMatrixRMaj getCostHessian() {
        return this.S1;
    }

    /** Returns the 6x1 s2 vector from the last control update. */
    public DMatrixRMaj getCostJacobian() {
        return this.s2;
    }

    /** Returns the R1 matrix from the common LQR values. */
    public DMatrixRMaj getControlHessian() {
        return this.lqrCommonValues.getR1();
    }

    /** Returns the 3x1 feed-forward term k2 from the last control update. */
    public DMatrixRMaj getControlJacobian() {
        return this.k2;
    }

    /** Returns the Nb matrix from the common LQR values. */
    public DMatrixRMaj getStateDependentControlJacobian() {
        return this.lqrCommonValues.getNb();
    }

    /** Throws if no VRP trajectory has been stored yet. */
    private void checkTrajectoryIsSet() {
        if (this.relativeVRPTrajectories.isEmpty()) {
            throw new IllegalStateException("No VRP trajectory has been set; call setVRPTrajectory first.");
        }
    }

    /**
     * Returns the index of the segment containing {@code time}, the last segment index when
     * {@code time} is past the end of the trajectory, or {@code -1} when no segment is stored.
     */
    private int getSegmentNumber(double time) {
        double timeToStart = 0.0;
        for (int i = 0; i < this.relativeVRPTrajectories.size(); ++i) {
            double segmentDuration = this.relativeVRPTrajectories.get(i).getDuration();
            if (time - timeToStart <= segmentDuration) {
                return i;
            }
            timeToStart += segmentDuration;
        }
        return this.relativeVRPTrajectories.size() - 1;
    }

    /** Returns {@code time} minus the summed durations of all segments before {@code segment}. */
    private double computeTimeInSegment(double time, int segment) {
        double timeOffset = 0.0;
        for (int i = 0; i < segment; ++i) {
            timeOffset += this.relativeVRPTrajectories.get(i).getDuration();
        }
        return time - timeOffset;
    }

    /** Sets omega; the listener registered in the constructor recomputes the dynamics matrices. */
    public void setOmega(double omega) {
        this.omega.set(omega);
    }

    DMatrixRMaj getA() {
        return this.lqrCommonValues.getA();
    }

    DMatrixRMaj getB() {
        return this.lqrCommonValues.getB();
    }

    DMatrixRMaj getC() {
        return this.lqrCommonValues.getC();
    }

    DMatrixRMaj getD() {
        return this.lqrCommonValues.getD();
    }

    DMatrixRMaj getQ() {
        return this.lqrCommonValues.getQ();
    }

    DMatrixRMaj getR() {
        return this.lqrCommonValues.getR();
    }

    DMatrixRMaj getK1() {
        return this.K1;
    }

    DMatrixRMaj getK2() {
        return this.k2;
    }
}

