/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.commonWalkingControlModules.capturePoint.lqrControl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.AlgebraicS1Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.AlgebraicS2Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.DifferentialS1Segment;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.DifferentialS2Segment;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.FlightS1Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.FlightS2Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.LQRCommonValues;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.S1Function;
import us.ihmc.commonWalkingControlModules.capturePoint.lqrControl.S2Segment;
import us.ihmc.commonWalkingControlModules.controllerCore.command.inverseDynamics.LinearMomentumRateCostCommand;
import us.ihmc.commonWalkingControlModules.dynamicPlanning.comPlanning.ContactStateProvider;
import us.ihmc.commonWalkingControlModules.dynamicPlanning.comPlanning.SettableContactStateProvider;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.FramePoint3DReadOnly;
import us.ihmc.euclid.referenceFrame.interfaces.FrameTuple3DReadOnly;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DReadOnly;
import us.ihmc.robotics.math.trajectories.core.Polynomial3D;
import us.ihmc.robotics.math.trajectories.interfaces.Polynomial3DBasics;
import us.ihmc.robotics.math.trajectories.interfaces.Polynomial3DReadOnly;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFramePoint3D;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFrameVector3D;
import us.ihmc.yoVariables.providers.DoubleProvider;
import us.ihmc.yoVariables.registry.YoRegistry;

/**
 * LQR-based momentum controller that tracks a piecewise-polynomial VRP trajectory made of load-bearing (contact) and
 * non-load-bearing (flight) segments.
 *
 * <p>All internal computations use a 6-element state (rows 0-2: CoM position, rows 3-5: CoM velocity) expressed
 * relative to the final VRP position. The control input is {@code u = K1 * x + k2}, where {@code K1} comes from the S1
 * (quadratic cost-to-go) functions and {@code k2} from the S2 (linear cost-to-go) functions. Both are computed
 * backwards in time over the segments by {@link #setVRPTrajectory(List, List)}.
 */
public class LQRJumpMomentumController {
    /** Integration step handed to the numerically integrated S1/S2 segments used during contact phases. */
    private static final double discreteDt = 0.005;
    /** Vertical gravitational acceleration handed to {@link FlightS2Function} for flight phases. */
    private static final double gravityZ = -9.81;
    /**
     * Upper bound on the time at which a VRP segment is evaluated.
     * NOTE(review): presumably guards against segments with an unbounded (e.g. infinite) end time — confirm.
     */
    private static final double maxSegmentEvaluationTime = 10.0;

    static final double defaultVrpTrackingWeight = 1000000.0;
    static final double defaultMomentumRateWeight = 1.0E-5;

    private final double totalMass;
    private final YoRegistry registry = new YoRegistry(getClass().getSimpleName());
    private final YoFrameVector3D yoK2 = new YoFrameVector3D("k2", ReferenceFrame.getWorldFrame(), registry);
    private final YoFrameVector3D feedbackForce = new YoFrameVector3D("feedbackForce", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D relativeCoMPosition = new YoFramePoint3D("relativeCoMPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFrameVector3D relativeCoMVelocity = new YoFrameVector3D("relativeCoMVelocity", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D finalVRPPosition = new YoFramePoint3D("finalVRPPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D referenceVRPPosition = new YoFramePoint3D("referenceVRPPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D feedbackVRPPosition = new YoFramePoint3D("feedbackVRPPosition", ReferenceFrame.getWorldFrame(), registry);

    private double vrpTrackingWeight = defaultVrpTrackingWeight;
    private double momentumRateWeight = defaultMomentumRateWeight;

    private final LQRCommonValues lqrCommonValues = new LQRCommonValues();
    private final AlgebraicS1Function finalS1Function = new AlgebraicS1Function();

    private final DMatrixRMaj Nb = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj S1 = new DMatrixRMaj(6, 6);
    private final DMatrixRMaj s2 = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj K1 = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj k2 = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj u = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj R1InverseDQ = new DMatrixRMaj(3, 3);
    private final DMatrixRMaj R1InverseBTranspose = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj finalVRPState = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj relativeState = new DMatrixRMaj(6, 1);
    // Holds the desired relative VRP inside computeS2AndK2, then the feedback relative VRP inside computeControlInput.
    private final DMatrixRMaj relativeDesiredVRP = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj linearMomentumRateGradient = new DMatrixRMaj(1, 3);
    private final DMatrixRMaj linearMomentumRateHessian = new DMatrixRMaj(3, 3);

    /** VRP segments shifted so that the final VRP position is the origin. */
    final RecyclingArrayList<Polynomial3D> relativeVRPTrajectories = new RecyclingArrayList<>(() -> new Polynomial3D(4));
    /** Contact state of each segment; index-aligned with {@link #relativeVRPTrajectories}. */
    final RecyclingArrayList<SettableContactStateProvider> contactStateProviders = new RecyclingArrayList<>(SettableContactStateProvider::new);

    private boolean shouldUpdateP = true;
    private boolean shouldUpdateCosts = true;

    private final List<S1Function> reversedS1FunctionList = new ArrayList<>();
    /** S1 function of each segment; index-aligned with {@link #relativeVRPTrajectories}. */
    private final List<S1Function> s1FunctionList = new ArrayList<>();
    private final List<S2Segment> reversedS2FunctionList = new ArrayList<>();
    /** S2 function of each segment; index-aligned with {@link #relativeVRPTrajectories}. */
    private final List<S2Segment> s2FunctionList = new ArrayList<>();
    private final LinearMomentumRateCostCommand momentumRateCostCommand = new LinearMomentumRateCostCommand();

    /**
     * Creates a controller whose YoVariables are not attached to any parent registry.
     *
     * @param omega provider read once, at construction, to build the dynamics matrices
     * @param totalMass total mass used to scale the momentum-rate cost
     */
    public LQRJumpMomentumController(DoubleProvider omega, double totalMass) {
        this(omega, totalMass, null);
    }

    /**
     * Creates the controller, builds its dynamics matrices, and computes the algebraic final S1 function.
     *
     * @param omega provider read once, at construction, to build the dynamics matrices
     * @param totalMass total mass used to scale the momentum-rate cost
     * @param parentRegistry registry to attach this controller's registry to; may be {@code null}
     */
    public LQRJumpMomentumController(DoubleProvider omega, double totalMass, YoRegistry parentRegistry) {
        this.totalMass = totalMass;
        computeDynamicsMatrix(omega.getValue());
        computeP();
        if (parentRegistry != null) {
            parentRegistry.addChild(registry);
        }
    }

    /**
     * Sets the VRP tracking weight. The new weight takes effect on the next call to
     * {@link #setVRPTrajectory(List, List)}.
     */
    public void setVRPTrackingWeight(double vrpTrackingWeight) {
        this.vrpTrackingWeight = vrpTrackingWeight;
        shouldUpdateP = true;
    }

    /**
     * Sets the momentum-rate weight. The new weight takes effect on the next call to
     * {@link #setVRPTrajectory(List, List)}.
     */
    public void setMomentumRateWeight(double momentumRateWeight) {
        this.momentumRateWeight = momentumRateWeight;
        shouldUpdateP = true;
    }

    private void computeDynamicsMatrix(double omega) {
        lqrCommonValues.computeDynamicsMatrix(omega);
        shouldUpdateP = true;
    }

    /**
     * Stores the VRP trajectory relative to its final position and recomputes all S1 and S2 segments.
     *
     * @param vrpTrajectory VRP segments, in time order; must not be empty
     * @param contactStateProviders contact state of each segment; must be the same size as {@code vrpTrajectory}
     * @throws IllegalArgumentException if the two lists differ in size or are empty
     */
    public void setVRPTrajectory(List<? extends Polynomial3DReadOnly> vrpTrajectory, List<? extends ContactStateProvider> contactStateProviders) {
        if (vrpTrajectory.size() != contactStateProviders.size()) {
            throw new IllegalArgumentException("The contacts don't match the trajectory.");
        }
        // Validate before clearing so a bad call leaves the previous trajectory intact.
        if (vrpTrajectory.isEmpty()) {
            throw new IllegalArgumentException("The VRP trajectory must contain at least one segment.");
        }

        relativeVRPTrajectories.clear();
        this.contactStateProviders.clear();

        Polynomial3DReadOnly lastTrajectory = vrpTrajectory.get(vrpTrajectory.size() - 1);
        lastTrajectory.compute(Math.min(maxSegmentEvaluationTime, lastTrajectory.getTimeInterval().getEndTime()));
        finalVRPPosition.set((Tuple3DReadOnly) lastTrajectory.getPosition());
        finalVRPPosition.get(finalVRPState);

        for (int i = 0; i < vrpTrajectory.size(); ++i) {
            Polynomial3DBasics relativeTrajectory = relativeVRPTrajectories.add();
            relativeTrajectory.set(vrpTrajectory.get(i));
            relativeTrajectory.shiftTrajectory(-finalVRPPosition.getX(), -finalVRPPosition.getY(), -finalVRPPosition.getZ());
            this.contactStateProviders.add().set(contactStateProviders.get(i));
        }

        if (shouldUpdateP) {
            computeP();
        }
        computeS1Segments();
        computeS2Segments();
    }

    /** Recomputes the equivalent cost values from the current weights and refreshes the algebraic final S1 function. */
    void computeP() {
        lqrCommonValues.computeEquivalentCostValues(momentumRateWeight, vrpTrackingWeight);
        finalS1Function.set(lqrCommonValues);
        shouldUpdateP = false;
    }

    /**
     * Builds one S1 function per segment by sweeping backwards from the last segment.
     *
     * <p>The last segment, and any contact segments before it that are not preceded (in backward order) by a flight
     * segment, use the algebraic final S1 function. Once a flight segment has been crossed, each contact segment
     * integrates S1 numerically, starting from the value at the start of the following segment.
     */
    void computeS1Segments() {
        reversedS1FunctionList.clear();
        s1FunctionList.clear();

        int lastSegmentIndex = relativeVRPTrajectories.size() - 1;
        reversedS1FunctionList.add(finalS1Function);
        if (lastSegmentIndex >= 0) {
            // Seed S1 with the value at the start of the last segment.
            finalS1Function.compute(0.0, S1);

            boolean hasHadSwitch = false;
            for (int j = lastSegmentIndex - 1; j >= 0; --j) {
                Polynomial3D vrpSegment = relativeVRPTrajectories.get(j);
                S1Function segmentS1Function;
                if (isLoadBearing(j)) {
                    if (hasHadSwitch) {
                        DifferentialS1Segment s1Segment = new DifferentialS1Segment(discreteDt);
                        s1Segment.set(lqrCommonValues, S1, vrpSegment.getTimeInterval().getDuration());
                        segmentS1Function = s1Segment;
                    } else {
                        segmentS1Function = finalS1Function;
                    }
                } else {
                    hasHadSwitch = true;
                    FlightS1Function flightS1Function = new FlightS1Function();
                    flightS1Function.set(S1, vrpSegment.getDuration());
                    segmentS1Function = flightS1Function;
                }
                // Propagate S1 to the start of this segment; it is the end condition for the previous one.
                segmentS1Function.compute(0.0, S1);
                reversedS1FunctionList.add(segmentS1Function);
            }
        }

        for (int i = reversedS1FunctionList.size() - 1; i >= 0; --i) {
            s1FunctionList.add(reversedS1FunctionList.get(i));
        }
    }

    /**
     * Builds one S2 function per segment by sweeping backwards.
     *
     * <p>The trailing run of contact segments shares a single algebraic S2 function that starts from {@code s2 = 0}.
     * Earlier segments use numerically integrated S2 (contact) or {@link FlightS2Function} (flight). Requires
     * {@link #computeS1Segments()} to have been run on the same trajectory, because the S1 functions are read by
     * segment index.
     */
    void computeS2Segments() {
        reversedS2FunctionList.clear();
        s2FunctionList.clear();

        int numberOfSegments = relativeVRPTrajectories.size();
        int numberOfEndingContactSegments = 0;
        for (int j = numberOfSegments - 1; j >= 0 && isLoadBearing(j); --j) {
            ++numberOfEndingContactSegments;
        }
        int firstEndingContactSegment = numberOfSegments - numberOfEndingContactSegments;

        List<Polynomial3D> endingContactVRPs = new ArrayList<>(numberOfEndingContactSegments);
        for (int j = firstEndingContactSegment; j < numberOfSegments; ++j) {
            endingContactVRPs.add(relativeVRPTrajectories.get(j));
        }

        s2.zero();
        AlgebraicS2Function endingS2Function = new AlgebraicS2Function();
        finalS1Function.compute(0.0, S1);
        lqrCommonValues.computeS2ConstantStateMatrices(S1);
        endingS2Function.set(s2, endingContactVRPs, lqrCommonValues);
        // s2 now holds the value at the start of the trailing contact run; it seeds the backward sweep below.
        endingS2Function.compute(0, 0.0, s2);
        for (int j = numberOfEndingContactSegments - 1; j >= 0; --j) {
            reversedS2FunctionList.add(endingS2Function.getSegment(j));
        }

        for (int j = firstEndingContactSegment - 1; j >= 0; --j) {
            Polynomial3D vrpSegment = relativeVRPTrajectories.get(j);
            // Look up by index rather than via a map keyed on Polynomial3D, so the result does not depend on
            // Polynomial3D's equals/hashCode.
            S1Function segmentS1Function = s1FunctionList.get(j);
            if (isLoadBearing(j)) {
                DifferentialS2Segment s2Segment = new DifferentialS2Segment(discreteDt);
                s2Segment.set(segmentS1Function, vrpSegment, lqrCommonValues, s2);
                s2Segment.compute(0.0, s2);
                reversedS2FunctionList.add(s2Segment);
            } else {
                double duration = vrpSegment.getTimeInterval().getDuration();
                segmentS1Function.compute(duration, S1);
                FlightS2Function flightS2Function = new FlightS2Function(gravityZ);
                flightS2Function.set(S1, s2, duration);
                flightS2Function.compute(0.0, s2);
                reversedS2FunctionList.add(flightS2Function);
            }
        }

        for (int i = reversedS2FunctionList.size() - 1; i >= 0; --i) {
            s2FunctionList.add(reversedS2FunctionList.get(i));
        }
    }

    /** Returns whether the segment at {@code segmentIndex} is a load-bearing (contact) segment. */
    private boolean isLoadBearing(int segmentIndex) {
        return contactStateProviders.get(segmentIndex).getContactState().isLoadBearing();
    }

    /**
     * Evaluates S1 at {@code time} and computes {@code Nb = N^T + B^T S1} and {@code K1 = -R1^-1 Nb}.
     */
    void computeS1AndK1(double time) {
        int segmentNumber = getSegmentNumber(time);
        double timeInSegment = computeTimeInSegment(time, segmentNumber);
        s1FunctionList.get(segmentNumber).compute(timeInSegment, S1);

        Nb.set(lqrCommonValues.getNTranspose());
        CommonOps_DDRM.multAddTransA(lqrCommonValues.getB(), S1, Nb);
        CommonOps_DDRM.mult(-1.0, lqrCommonValues.getR1Inverse(), Nb, K1);
    }

    /**
     * Evaluates the reference VRP and S2 at {@code time} and computes
     * {@code k2 = R1^-1 DQ vrp_rel - 0.5 R1^-1 B^T s2}.
     */
    void computeS2AndK2(double time) {
        int segmentNumber = getSegmentNumber(time);
        double timeInSegment = Math.min(maxSegmentEvaluationTime, computeTimeInSegment(time, segmentNumber));

        Polynomial3D relativeVRPSegment = relativeVRPTrajectories.get(segmentNumber);
        relativeVRPSegment.compute(timeInSegment);
        referenceVRPPosition.set((Tuple3DReadOnly) relativeVRPSegment.getPosition());
        referenceVRPPosition.get(relativeDesiredVRP);
        // The Yo variable is reported in absolute terms; the matrix stays relative.
        referenceVRPPosition.add(finalVRPPosition);

        s2FunctionList.get(segmentNumber).compute(timeInSegment, s2);

        CommonOps_DDRM.mult(lqrCommonValues.getR1Inverse(), lqrCommonValues.getDQ(), R1InverseDQ);
        CommonOps_DDRM.multTransB(-0.5, lqrCommonValues.getR1Inverse(), lqrCommonValues.getB(), R1InverseBTranspose);
        CommonOps_DDRM.mult(R1InverseDQ, relativeDesiredVRP, k2);
        CommonOps_DDRM.multAdd(R1InverseBTranspose, s2, k2);
        yoK2.set(k2);
    }

    /**
     * Computes the control input {@code u = K1 x + k2} for the given state and time. Also updates the feedback VRP
     * ({@code C x + D u}, reported in absolute terms) and the momentum-rate cost command.
     *
     * @param currentState 6x1 state in absolute coordinates (rows 0-2 position, rows 3-5 velocity)
     * @param time time since the start of the VRP trajectory
     */
    public void computeControlInput(DMatrixRMaj currentState, double time) {
        // Mark the cost command stale up front so a failure part-way through never leaves it flagged as fresh.
        shouldUpdateCosts = true;
        computeS1AndK1(time);
        computeS2AndK2(time);

        // Only the position rows are shifted into the final-VRP-relative frame.
        relativeState.set(currentState);
        for (int i = 0; i < 3; ++i) {
            relativeState.add(i, 0, -finalVRPState.get(i));
        }
        relativeCoMPosition.set(relativeState);
        relativeCoMVelocity.set(3, relativeState);

        CommonOps_DDRM.mult(K1, relativeState, u);
        // Log the state-feedback part (K1 x) before the feedforward term k2 is added.
        feedbackForce.set(u);
        CommonOps_DDRM.addEquals(u, k2);

        CommonOps_DDRM.mult(lqrCommonValues.getC(), relativeState, relativeDesiredVRP);
        CommonOps_DDRM.multAdd(lqrCommonValues.getD(), u, relativeDesiredVRP);
        feedbackVRPPosition.set(relativeDesiredVRP);
        feedbackVRPPosition.add(finalVRPPosition);

        computeMomentumRateCostCommand();
        shouldUpdateCosts = false;
    }

    /** Returns the control input computed by the last {@link #computeControlInput} call (live internal matrix). */
    public DMatrixRMaj getU() {
        return u;
    }

    /** Returns the most recently evaluated S1 matrix (live internal matrix). */
    public DMatrixRMaj getCostHessian() {
        return S1;
    }

    /** Returns the most recently evaluated s2 vector (live internal matrix). */
    public DMatrixRMaj getCostJacobian() {
        return s2;
    }

    /**
     * Returns the index of the segment containing {@code time}. Times past the end of the trajectory map to the last
     * segment. Returns -1 only when no trajectory is set.
     */
    private int getSegmentNumber(double time) {
        int numberOfSegments = relativeVRPTrajectories.size();
        double segmentStartTime = 0.0;
        for (int i = 0; i < numberOfSegments; ++i) {
            double segmentDuration = relativeVRPTrajectories.get(i).getDuration();
            if (time - segmentStartTime <= segmentDuration) {
                return i;
            }
            segmentStartTime += segmentDuration;
        }
        // Past the horizon: hold the last segment instead of returning an index that would crash the list lookups.
        return numberOfSegments - 1;
    }

    /**
     * Returns {@code time} minus the start time of {@code segment}, clamped to that segment's duration. The clamp only
     * applies past the end of the trajectory; in range, the value is already within the segment.
     */
    private double computeTimeInSegment(double time, int segment) {
        double timeOffset = 0.0;
        for (int i = 0; i < segment; ++i) {
            timeOffset += relativeVRPTrajectories.get(i).getDuration();
        }
        double timeInSegment = time - timeOffset;
        if (segment >= 0 && segment < relativeVRPTrajectories.size()) {
            timeInSegment = Math.min(timeInSegment, relativeVRPTrajectories.get(segment).getDuration());
        }
        return timeInSegment;
    }

    DMatrixRMaj getA() {
        return lqrCommonValues.getA();
    }

    DMatrixRMaj getB() {
        return lqrCommonValues.getB();
    }

    DMatrixRMaj getC() {
        return lqrCommonValues.getC();
    }

    DMatrixRMaj getD() {
        return lqrCommonValues.getD();
    }

    DMatrixRMaj getQ() {
        return lqrCommonValues.getQ();
    }

    DMatrixRMaj getR() {
        return lqrCommonValues.getR();
    }

    DMatrixRMaj getK1() {
        return K1;
    }

    DMatrixRMaj getK2() {
        return k2;
    }

    S1Function getS1Segment(int segment) {
        return s1FunctionList.get(segment);
    }

    S2Segment getS2Segment(int segment) {
        return s2FunctionList.get(segment);
    }

    /** Returns the VRP implied by the last computed control input, in absolute coordinates. */
    public FramePoint3DReadOnly getFeedbackVRPPosition() {
        return feedbackVRPPosition;
    }

    /**
     * Fills the momentum-rate cost command with gradient {@code (2/m)(x^T Nb^T - k2^T R1)} and Hessian
     * {@code (2/m^2) R1}, using the current relative state, {@code Nb}, and {@code k2}.
     */
    private void computeMomentumRateCostCommand() {
        double massInverse = 1.0 / totalMass;
        CommonOps_DDRM.multTransA(-2.0 * massInverse, k2, lqrCommonValues.getR1(), linearMomentumRateGradient);
        CommonOps_DDRM.multAddTransAB(2.0 * massInverse, relativeState, Nb, linearMomentumRateGradient);
        CommonOps_DDRM.scale(2.0 * massInverse * massInverse, lqrCommonValues.getR1(), linearMomentumRateHessian);
        momentumRateCostCommand.setLinearMomentumRateGradient(linearMomentumRateGradient);
        momentumRateCostCommand.setLinearMomentumRateHessian(linearMomentumRateHessian);
    }

    /**
     * Returns the momentum-rate cost command. It is computed lazily if {@link #computeControlInput} has not yet
     * produced an up-to-date one.
     */
    public LinearMomentumRateCostCommand getMomentumRateCostCommand() {
        if (shouldUpdateCosts) {
            shouldUpdateCosts = false;
            computeMomentumRateCostCommand();
        }
        return momentumRateCostCommand;
    }
}

