package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.util.PairList;
import java.util.List;

/* loaded from: input_file:ai/djl/pytorch/engine/PtNDArrayEx.class */
/**
 * PyTorch implementation of the {@link NDArrayEx} extension operations for a single
 * {@link PtNDArray}.
 *
 * <p>This source was recovered by a decompiler. Methods named {@code mNNN<op>} (for example
 * {@code m170rdiv}) are renames of the corresponding {@code <op>} methods. They keep those names so
 * existing call sites still resolve. Operations the PyTorch engine does not provide throw
 * {@link UnsupportedOperationException}.
 */
public class PtNDArrayEx implements NDArrayEx {
    /** Indexer instance shared by every {@code PtNDArrayEx}; returned by {@link #getIndexer()}. */
    private static final NDArrayIndexer INDEXER = new PtNDArrayIndexer();

    /** The array these extension operations act on; never reassigned after construction. */
    private final PtNDArray array;

    /**
     * Creates an extension view over {@code ptNDArray}.
     *
     * <p>The decompiler reports this constructor as originally package-private. It stays
     * {@code public} here so existing callers keep compiling.
     *
     * @param ptNDArray the array the operations in this class act on
     */
    public PtNDArrayEx(PtNDArray ptNDArray) {
        this.array = ptNDArray;
    }

    // ----- Reverse arithmetic ------------------------------------------------------------------

    /**
     * {@code rdiv(Number)}: divides {@code number} by the array.
     *
     * <p>The scalar is wrapped in a temporary array. That temporary is closed as soon as the
     * quotient exists, rather than staying attached to the array's manager until the manager is
     * closed.
     *
     * @param number the dividend
     * @return {@code number / array}
     */
    public PtNDArray m170rdiv(Number number) {
        try (NDArray dividend = this.array.m135getManager().create(number)) {
            return m169rdiv(dividend);
        }
    }

    /**
     * {@code rdiv(NDArray)}: divides {@code nDArray} by the array.
     *
     * @param nDArray the dividend
     * @return {@code nDArray / array}
     */
    public PtNDArray m169rdiv(NDArray nDArray) {
        return (PtNDArray) nDArray.div(this.array);
    }

    /** {@code rdivi(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m168rdivi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rdivi(NDArray)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m167rdivi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rsub(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m166rsub(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rsub(NDArray)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m165rsub(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rsubi(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m164rsubi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rsubi(NDArray)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m163rsubi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rmod(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m162rmod(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rmod(NDArray)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m161rmod(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rmodi(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m160rmodi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rmodi(NDArray)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m159rmodi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rpow(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m158rpow(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rpowi(Number)}: not implemented; always throws {@link UnsupportedOperationException}. */
    public PtNDArray m157rpowi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    // ----- Activations -------------------------------------------------------------------------

    /** {@code relu()}: returns {@code JniUtils.relu} applied to the array. */
    public PtNDArray m156relu() {
        return JniUtils.relu(this.array);
    }

    /** {@code sigmoid()}: returns {@code JniUtils.sigmoid} applied to the array. */
    public PtNDArray m155sigmoid() {
        return JniUtils.sigmoid(this.array);
    }

    /** {@code tanh()}: returns {@code JniUtils.tanh} applied to the array. */
    public PtNDArray m154tanh() {
        return JniUtils.tanh(this.array);
    }

    /** {@code softPlus()}: returns {@code JniUtils.softPlus} applied to the array. */
    public PtNDArray m153softPlus() {
        return JniUtils.softPlus(this.array);
    }

    /** {@code softSign()}: returns {@code JniUtils.softSign} applied to the array. */
    public PtNDArray m152softSign() {
        return JniUtils.softSign(this.array);
    }

    /**
     * {@code leakyRelu(float)}: returns {@code JniUtils.leakyRelu(array, f)}.
     *
     * @param f forwarded unchanged; presumably the negative-side slope (TODO confirm)
     */
    public PtNDArray m151leakyRelu(float f) {
        return JniUtils.leakyRelu(this.array, f);
    }

    /**
     * {@code elu(float)}: returns {@code JniUtils.elu(array, f)}.
     *
     * @param f forwarded unchanged; presumably the ELU alpha (TODO confirm)
     */
    public PtNDArray m150elu(float f) {
        return JniUtils.elu(this.array, f);
    }

    /** {@code selu()}: returns {@code JniUtils.selu} applied to the array. */
    public PtNDArray m149selu() {
        return JniUtils.selu(this.array);
    }

    /** {@code gelu()}: returns {@code JniUtils.gelu} applied to the array. */
    public PtNDArray m148gelu() {
        return JniUtils.gelu(this.array);
    }

    // ----- Pooling -----------------------------------------------------------------------------

    /**
     * {@code maxPool}: forwards every argument unchanged to {@code JniUtils.maxPool}.
     *
     * <p>NOTE(review): the three shapes are presumably kernel, stride and padding, and {@code z}
     * the ceil mode. Confirm against {@link NDArrayEx}.
     */
    public PtNDArray m147maxPool(Shape shape, Shape shape2, Shape shape3, boolean z) {
        return JniUtils.maxPool(this.array, shape, shape2, shape3, z);
    }

    /**
     * {@code globalMaxPool()}: calls {@code JniUtils.adaptiveMaxPool} with the all-ones shape from
     * {@link #getPoolShape}. The result is then reshaped to the first two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input does not have 3 to 5 dimensions
     */
    public PtNDArray m146globalMaxPool() {
        return reshapeToLeadingDimsAndClose(
                JniUtils.adaptiveMaxPool(this.array, getPoolShape(this.array)));
    }

    /**
     * {@code avgPool}: forwards every argument unchanged to {@code JniUtils.avgPool}.
     *
     * <p>NOTE(review): the shapes are presumably kernel, stride and padding. The booleans are
     * presumably ceil mode and whether padding counts toward the average. Confirm against
     * {@link NDArrayEx}.
     */
    public PtNDArray m145avgPool(Shape shape, Shape shape2, Shape shape3, boolean z, boolean z2) {
        return JniUtils.avgPool(this.array, shape, shape2, shape3, z, z2);
    }

    /**
     * {@code globalAvgPool()}: calls {@code JniUtils.adaptiveAvgPool} with the all-ones shape from
     * {@link #getPoolShape}. The result is then reshaped to the first two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input does not have 3 to 5 dimensions
     */
    public PtNDArray m144globalAvgPool() {
        return reshapeToLeadingDimsAndClose(
                JniUtils.adaptiveAvgPool(this.array, getPoolShape(this.array)));
    }

    /**
     * {@code lpPool}: forwards to {@code JniUtils.lpPool}. That call has no padding argument, so
     * any non-zero padding is rejected.
     *
     * @param f forwarded unchanged; presumably the norm order p (TODO confirm)
     * @param shape forwarded unchanged (presumably the kernel shape)
     * @param shape2 forwarded unchanged (presumably the stride)
     * @param shape3 the padding; every entry must be 0
     * @param z forwarded unchanged (presumably ceil mode)
     * @throws IllegalArgumentException if any entry of {@code shape3} is non-zero
     */
    public PtNDArray m143lpPool(float f, Shape shape, Shape shape2, Shape shape3, boolean z) {
        // Check each entry rather than Shape.size(): size() is the product of the dimensions.
        // It is therefore 0 whenever any entry is 0, which would let e.g. (1, 0) through and ignore it.
        for (int d = 0; d < shape3.dimension(); d++) {
            if (shape3.get(d) != 0) {
                throw new IllegalArgumentException("padding is not supported for PyTorch engine");
            }
        }
        return JniUtils.lpPool(this.array, f, shape, shape2, z);
    }

    /**
     * {@code globalLpPool(float)}: calls {@code JniUtils.lpPool} with three arguments:
     *
     * <ul>
     *   <li>the input's dimensions from index 2 onward;
     *   <li>the all-ones shape from {@link #getPoolShape};
     *   <li>{@code false}.
     * </ul>
     *
     * <p>These are presumably kernel, stride and ceil mode, so the window covers every trailing
     * dimension. The result is then reshaped to the first two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input does not have 3 to 5 dimensions
     */
    public PtNDArray m142globalLpPool(float f) {
        return reshapeToLeadingDimsAndClose(
                JniUtils.lpPool(
                        this.array,
                        f,
                        this.array.getShape().slice(2),
                        getPoolShape(this.array),
                        false));
    }

    /**
     * Shared tail of the global pooling methods. Reshapes {@code pooled} to the first two
     * dimensions of the input, then closes {@code pooled}, including when the reshape throws.
     */
    private PtNDArray reshapeToLeadingDimsAndClose(PtNDArray pooled) {
        try (PtNDArray temp = pooled) {
            return (PtNDArray) temp.reshape(this.array.getShape().slice(0, 2));
        }
    }

    // ----- Optimizer updates -------------------------------------------------------------------

    /** Always throws: AdaDelta is not available on the PyTorch engine. */
    public void adadeltaUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException(
                "AdaDelta optimizer is not supported for PyTorch engine!");
    }

    /** {@code adagradUpdate}: not implemented; always throws {@link UnsupportedOperationException}. */
    public void adagradUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@code adamUpdate}: passes the first four arrays of {@code nDList} and the seven float
     * arguments to {@code JniUtils.adamUpdate}. It then calls {@code JniUtils.zeroGrad} on the
     * single array in {@code nDList2}. The boolean {@code z} is not used.
     */
    public void adamUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7,
            boolean z) {
        JniUtils.adamUpdate(
                (PtNDArray) nDList.get(0),
                (PtNDArray) nDList.get(1),
                (PtNDArray) nDList.get(2),
                (PtNDArray) nDList.get(3),
                f, f2, f3, f4, f5, f6, f7);
        JniUtils.zeroGrad((PtNDArray) nDList2.singletonOrThrow());
    }

    /** {@code nagUpdate}: not implemented; always throws {@link UnsupportedOperationException}. */
    public void nagUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code rmspropUpdate}: not implemented; always throws {@link UnsupportedOperationException}. */
    public void rmspropUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7,
            boolean z) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@code sgdUpdate}: passes {@code nDList[0]}, {@code nDList[1]} and the five float arguments
     * to {@code JniUtils.sgdUpdate}. It then calls {@code JniUtils.zeroGrad} on the single array in
     * {@code nDList2}.
     *
     * <p>{@code nDList[2]} is only read when {@code f5} is non-zero; otherwise {@code null} is
     * passed in its place. {@code f5} is presumably the momentum (TODO confirm). The boolean
     * {@code z} is not used.
     */
    public void sgdUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            boolean z) {
        JniUtils.sgdUpdate(
                (PtNDArray) nDList.get(0),
                (PtNDArray) nDList.get(1),
                f5 == 0.0f ? null : (PtNDArray) nDList.get(2),
                f, f2, f3, f4, f5);
        JniUtils.zeroGrad((PtNDArray) nDList2.singletonOrThrow());
    }

    // ----- Neural-network layers ---------------------------------------------------------------

    /** {@code convolution}: wraps the result of {@code JniUtils.convolution} in an {@link NDList}. */
    public NDList convolution(
            NDArray nDArray,
            NDArray nDArray2,
            NDArray nDArray3,
            Shape shape,
            Shape shape2,
            Shape shape3,
            int i) {
        return new NDList(
                JniUtils.convolution(
                        (PtNDArray) nDArray,
                        (PtNDArray) nDArray2,
                        (PtNDArray) nDArray3,
                        shape,
                        shape2,
                        shape3,
                        i));
    }

    /** {@code linear}: wraps the result of {@code JniUtils.linear} in an {@link NDList}. */
    public NDList linear(NDArray nDArray, NDArray nDArray2, NDArray nDArray3) {
        return new NDList(
                JniUtils.linear((PtNDArray) nDArray, (PtNDArray) nDArray2, (PtNDArray) nDArray3));
    }

    /** {@code embedding}: not implemented; always throws {@link UnsupportedOperationException}. */
    public NDList embedding(
            NDList nDList,
            int i,
            int i2,
            boolean z,
            DataType dataType,
            PairList<String, Object> pairList) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code prelu}: not implemented; always throws {@link UnsupportedOperationException}. */
    public NDList prelu(NDArray nDArray, NDArray nDArray2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code dropout}: wraps the result of {@code JniUtils.dropout} in an {@link NDList}. */
    public NDList dropout(NDArray nDArray, float f, boolean z) {
        return new NDList(JniUtils.dropout((PtNDArray) nDArray, f, z));
    }

    /**
     * {@code batchNorm}: runs {@code JniUtils.batchNorm} along axis {@code i}.
     *
     * <p>When {@code i == -1}, the arrays are passed straight through. Otherwise axis {@code i} is
     * swapped into position 1 before the native call and swapped back afterwards. The swap
     * suggests the native kernel normalises axis 1; TODO confirm.
     *
     * <p>The swapped intermediates are created under a temporary sub-manager and freed when it
     * closes. The result is moved to the input's original manager. The input itself is always
     * re-attached to its original manager in {@code finally}, so a failing native call does not
     * close the caller's array together with the sub-manager.
     *
     * @return a single-element list holding the normalised array
     */
    public NDList batchNorm(
            NDArray nDArray,
            NDArray nDArray2,
            NDArray nDArray3,
            NDArray nDArray4,
            NDArray nDArray5,
            int i,
            float f,
            float f2,
            boolean z) {
        if (i == -1) {
            return new NDList(
                    nativeBatchNorm(nDArray, nDArray2, nDArray3, nDArray4, nDArray5, f, f2, z));
        }
        try (NDManager subManager = nDArray.getManager().newSubManager()) {
            NDManager parent = subManager.getParentManager();
            nDArray.attach(subManager);
            try {
                NDArray result =
                        nativeBatchNorm(
                                        nDArray.swapAxes(1, i),
                                        nDArray2,
                                        nDArray3,
                                        nDArray4,
                                        nDArray5,
                                        f,
                                        f2,
                                        z)
                                .swapAxes(1, i);
                result.attach(parent);
                return new NDList(result);
            } finally {
                nDArray.attach(parent);
            }
        }
    }

    /**
     * Calls {@code JniUtils.batchNorm} with the casts and argument order shared by both
     * {@code batchNorm} paths.
     *
     * <p>NOTE(review): {@code f} is passed as {@code 1 - f}. This is presumably a conversion
     * between momentum conventions; confirm against the native binding.
     */
    private static PtNDArray nativeBatchNorm(
            NDArray input,
            NDArray nDArray2,
            NDArray nDArray3,
            NDArray nDArray4,
            NDArray nDArray5,
            float f,
            float f2,
            boolean z) {
        return JniUtils.batchNorm(
                (PtNDArray) input,
                (PtNDArray) nDArray2,
                (PtNDArray) nDArray3,
                (PtNDArray) nDArray4,
                (PtNDArray) nDArray5,
                z,
                1.0f - f,
                f2);
    }

    /** {@code rnn}: not implemented; always throws {@link UnsupportedOperationException}. */
    public NDList rnn(
            NDList nDList,
            String str,
            long j,
            float f,
            int i,
            boolean z,
            boolean z2,
            boolean z3,
            PairList<String, Object> pairList) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code lstm}: not implemented; always throws {@link UnsupportedOperationException}. */
    public NDList lstm(
            NDList nDList,
            long j,
            float f,
            int i,
            boolean z,
            boolean z2,
            boolean z3,
            double d,
            double d2,
            PairList<String, Object> pairList) {
        throw new UnsupportedOperationException("Not implemented");
    }

    // ----- Image operations --------------------------------------------------------------------

    /**
     * {@code resize(int, int, int)}: resizes the array with {@code JniUtils.interpolate}.
     *
     * <p>Steps:
     *
     * <ol>
     *   <li>The array is converted to FLOAT32 if needed.
     *   <li>A 3-D input gets a leading axis added, which is removed again afterwards.
     *   <li>Axes are permuted {@code (0, 3, 1, 2)} before interpolation and {@code (0, 2, 3, 1)}
     *       after.
     * </ol>
     *
     * <p>The permutation suggests channels-last image input; TODO confirm. The target size is
     * passed as {@code {i2, i}}, so {@code i} is presumably the width and {@code i2} the height.
     *
     * <p>Inputs are validated before any manager re-attachment. The array is always re-attached to
     * its original manager in {@code finally}. As a result, a rejected or failing call leaves the
     * caller's array open; only intermediates created under the temporary sub-manager are freed.
     *
     * @param i3 interpolation code in {@code [0, 3]}, mapped by {@link #getInterpolationMode}
     * @return the resized FLOAT32 array, attached to the input's original manager
     * @throws IllegalArgumentException if the array is empty
     * @throws UnsupportedOperationException if {@code i3} is not a supported interpolation code
     */
    public PtNDArray m141resize(int i, int i2, int i3) {
        if (this.array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        int mode = getInterpolationMode(i3);
        try (NDManager subManager = this.array.m135getManager().mo173newSubManager()) {
            NDManager parent = subManager.getParentManager();
            // Arrays derived from this.array while it belongs to subManager are freed with it.
            this.array.attach(subManager);
            try {
                NDArray working = this.array;
                if (working.getDataType() != DataType.FLOAT32) {
                    working = working.toType(DataType.FLOAT32, true);
                }
                int dimension = working.getShape().dimension();
                if (dimension == 3) {
                    working = working.expandDims(0);
                }
                PtNDArray result =
                        JniUtils.interpolate(
                                        (PtNDArray) working.transpose(0, 3, 1, 2),
                                        new long[] {i2, i},
                                        mode,
                                        false)
                                .m17transpose(0, 2, 3, 1);
                if (dimension == 3) {
                    result = (PtNDArray) result.squeeze(0);
                }
                result.attach(parent);
                return result;
            } finally {
                this.array.attach(parent);
            }
        }
    }

    /** {@code randomFlipLeftRight}: not implemented; always throws. */
    public NDArray randomFlipLeftRight() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code randomFlipTopBottom}: not implemented; always throws. */
    public NDArray randomFlipTopBottom() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code randomBrightness}: not implemented; always throws. */
    public NDArray randomBrightness(float f) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code randomHue}: not implemented; always throws. */
    public NDArray randomHue(float f) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code randomColorJitter}: not implemented; always throws. */
    public NDArray randomColorJitter(float f, float f2, float f3, float f4) {
        throw new UnsupportedOperationException("Not implemented");
    }

    // ----- Indexing, selection and joining -----------------------------------------------------

    /** Returns the indexer shared by all {@code PtNDArrayEx} instances. */
    public NDArrayIndexer getIndexer() {
        return INDEXER;
    }

    /**
     * {@code where}: returns {@code JniUtils.where(nDArray, array, nDArray2)}, where
     * {@code nDArray} is the condition.
     *
     * @throws UnsupportedOperationException if the condition's shape differs from the array's
     *     shape (broadcasting is not supported)
     */
    public PtNDArray m140where(NDArray nDArray, NDArray nDArray2) {
        if (!nDArray.getShape().equals(this.array.getShape())) {
            throw new UnsupportedOperationException(
                    "condition and self shape mismatch, broadcast is not supported");
        }
        return JniUtils.where((PtNDArray) nDArray, this.array, (PtNDArray) nDArray2);
    }

    /** {@code stack}: stacks the array followed by {@code nDList} along axis {@code i}. */
    public PtNDArray m139stack(NDList nDList, int i) {
        return JniUtils.stack(withSelfFirst(nDList), i);
    }

    /**
     * {@code concat}: concatenates the array followed by {@code nDList} along axis {@code i}.
     * {@code nDList} is validated first with {@code NDUtils.checkConcatInput}.
     */
    public PtNDArray m138concat(NDList nDList, int i) {
        NDUtils.checkConcatInput(nDList);
        return JniUtils.cat(withSelfFirst(nDList), i);
    }

    /** Returns a new array holding {@code this.array} followed by the elements of {@code others}. */
    private NDArray[] withSelfFirst(NDList others) {
        NDArray[] all = new NDArray[others.size() + 1];
        all[0] = this.array;
        System.arraycopy(others.toArray(new NDArray[0]), 0, all, 1, others.size());
        return all;
    }

    /** {@code multiBoxTarget}: not implemented; always throws. */
    public NDList multiBoxTarget(NDList nDList, float f, float f2, float f3, float f4, int i) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code multiBoxPrior}: not implemented; always throws. */
    public NDList multiBoxPrior(
            List<Float> list,
            List<Float> list2,
            List<Float> list3,
            List<Float> list4,
            boolean z) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code multiBoxDetection}: not implemented; always throws. */
    public NDList multiBoxDetection(
            NDList nDList, boolean z, float f, int i, float f2, boolean z2, int i2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@code getArray()}: returns the array this instance operates on. */
    public PtNDArray m137getArray() {
        return this.array;
    }

    // ----- Helpers -----------------------------------------------------------------------------

    /**
     * Returns an all-ones {@link Shape} with one entry per dimension of {@code nDArray} after the
     * first two.
     *
     * @throws IllegalArgumentException unless {@code nDArray} has 3, 4 or 5 dimensions
     */
    private Shape getPoolShape(NDArray nDArray) {
        switch (nDArray.getShape().dimension() - 2) {
            case 1:
                return new Shape(1L);
            case 2:
                return new Shape(1L, 1L);
            case 3:
                return new Shape(1L, 1L, 1L);
            default:
                throw new IllegalArgumentException("the input dimension should be in [3, 5]");
        }
    }

    /**
     * Maps the interpolation code accepted by {@code resize} to the integer mode passed to
     * {@code JniUtils.interpolate}: 0 to 0, 1 to 2, 2 to 5, and 3 to 3.
     *
     * <p>NOTE(review): this presumably translates DJL interpolation ordinals into native
     * PyTorch mode ids. Confirm against the JNI side.
     *
     * @throws UnsupportedOperationException for any other code
     */
    private int getInterpolationMode(int i) {
        switch (i) {
            case 0:
                return 0;
            case 1:
                return 2;
            case 2:
                return 5;
            case 3:
                return 3;
            default:
                throw new UnsupportedOperationException(
                        "The kind of interpolation is not supported.");
        }
    }
}
