package ai.djl.pytorch.jni;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNull;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.engine.PtDeviceType;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/pytorch/jni/JniUtils.class */
public final class JniUtils {
    private static final Logger logger = LoggerFactory.getLogger(JniUtils.class);
    // Lazily populated by getFeatures(); in the visible code it is only read and written inside
    // that static synchronized method.
    private static Set<String> configs;
    // Value 0. NOTE(review): in this (decompiled) source it is also used as a plain int 0 for array
    // indices, loop starts and layout codes, not only as a null native pointer.
    private static final int NULL_PTR = 0;
    // 4194304 == 4 MiB. NOTE(review): not referenced in the visible part of the file; presumably a
    // buffer/chunk size for the stream-based helpers further down — confirm.
    private static final int BYTE_LENGTH = 4194304;

    // Static utility holder; never instantiated.
    private JniUtils() {
    }

    /**
     * Translates a DJL {@link SparseFormat} into the integer layout code passed to native calls.
     *
     * <p>Codes produced: {@code 1} for {@link SparseFormat#COO}; for {@link SparseFormat#DENSE},
     * {@code 2} when the system property {@code ai.djl.pytorch.use_mkldnn} is {@code true} and the
     * device is not a GPU, otherwise {@code 0}.
     *
     * @param sparseFormat requested storage format
     * @param device target device; only inspected when the MKLDNN property is enabled
     * @return native layout code (0, 1 or 2)
     * @throws IllegalArgumentException for any format other than DENSE or COO (including null)
     */
    private static int layoutMapper(SparseFormat sparseFormat, Device device) {
        if (sparseFormat == SparseFormat.COO) {
            return 1;
        }
        if (sparseFormat != SparseFormat.DENSE) {
            throw new IllegalArgumentException("Current PyTorch only support SparseFormat.DENSE and SparseFormat.COO");
        }
        // The MKLDNN layout is opt-in and must never be used on a GPU. Checking isGpu() (rather than
        // equality with Device.gpu()) excludes every GPU id, not just the default one.
        if (Boolean.getBoolean("ai.djl.pytorch.use_mkldnn") && !device.isGpu()) {
            return 2;
        }
        return 0;
    }

    /**
     * Reads the native grad-mode flag.
     *
     * @return result of native {@code torchIsGradMode()}
     */
    public static boolean isGradMode() {
        return PyTorchLibrary.LIB.torchIsGradMode();
    }

    /**
     * Writes the native grad-mode flag.
     *
     * @param enabled value forwarded to native {@code torchSetGradMode}
     */
    public static void setGradMode(boolean enabled) {
        PyTorchLibrary.LIB.torchSetGradMode(enabled);
    }

    /**
     * Reads the native thread count.
     *
     * @return result of native {@code torchGetNumThreads()}
     */
    public static int getNumThreads() {
        return PyTorchLibrary.LIB.torchGetNumThreads();
    }

    /**
     * Writes the native thread count.
     *
     * @param threadCount value forwarded to native {@code torchSetNumThreads}
     */
    public static void setNumThreads(int threadCount) {
        PyTorchLibrary.LIB.torchSetNumThreads(threadCount);
    }

    /**
     * Reads the native inter-op thread count.
     *
     * @return result of native {@code torchGetNumInteropThreads()}
     */
    public static int getNumInteropThreads() {
        return PyTorchLibrary.LIB.torchGetNumInteropThreads();
    }

    /**
     * Writes the native inter-op thread count.
     *
     * @param threadCount value forwarded to native {@code torchSetNumInteropThreads}
     */
    public static void setNumInteropThreads(int threadCount) {
        PyTorchLibrary.LIB.torchSetNumInteropThreads(threadCount);
    }

    /**
     * Toggles the native cuDNN benchmark flag.
     *
     * @param enabled value forwarded to native {@code torchSetBenchmarkCuDNN}
     */
    public static void setBenchmarkCuDNN(boolean enabled) {
        PyTorchLibrary.LIB.torchSetBenchmarkCuDNN(enabled);
    }

    /**
     * Returns the feature strings reported by native {@code torchShowConfig}.
     *
     * <p>The set is filled on the first call and cached in {@code configs}. Later calls return the
     * same instance. The method is {@code static synchronized}, so the lazy initialization happens at
     * most once.
     *
     * @return the cached feature set (the same mutable instance on every call)
     */
    public static synchronized Set<String> getFeatures() {
        if (configs == null) {
            // Typed set instead of the raw HashSet the decompiled code used.
            Set<String> features = new HashSet<>();
            PyTorchLibrary.LIB.torchShowConfig(features);
            configs = features;
        }
        return configs;
    }

    /**
     * Seeds the native random generator.
     *
     * @param seed value forwarded to native {@code torchManualSeed}
     */
    public static void setSeed(long seed) {
        PyTorchLibrary.LIB.torchManualSeed(seed);
    }

    /**
     * Starts native profiling via {@code torchStartProfile}. Both this method and
     * {@link #stopProfile} are {@code static synchronized}, so they serialize on the class lock.
     *
     * <p>NOTE(review): the parameter names reflect the presumed meaning of the three native flags
     * (CUDA, shape recording, memory profiling) — confirm against the native binding.
     *
     * @param useCuda first flag forwarded unchanged
     * @param recordShape second flag forwarded unchanged
     * @param profileMemory third flag forwarded unchanged
     */
    public static synchronized void startProfile(
            boolean useCuda, boolean recordShape, boolean profileMemory) {
        PyTorchLibrary.LIB.torchStartProfile(useCuda, recordShape, profileMemory);
    }

    /**
     * Stops native profiling via {@code torchStopProfile}.
     *
     * @param outputFile string forwarded unchanged (presumably the output path — confirm)
     */
    public static synchronized void stopProfile(String outputFile) {
        PyTorchLibrary.LIB.torchStopProfile(outputFile);
    }

    /**
     * Builds a native array from {@code byteBuffer} via {@code torchFromBlob}.
     *
     * <p>When the layout is plain dense (neither 1 nor 2 from {@link #layoutMapper}) and the device
     * is not a GPU, the returned array also keeps a reference to {@code byteBuffer}. Presumably the
     * native tensor aliases the buffer's memory, so the buffer must stay reachable — confirm.
     *
     * @param manager manager that will own the result
     * @param byteBuffer source data
     * @param shape shape of the result
     * @param dataType element type, passed by ordinal
     * @param sparseFormat storage format
     * @param device target device
     * @return the wrapping array
     */
    public static PtNDArray createNdFromByteBuffer(
            PtNDManager manager,
            ByteBuffer byteBuffer,
            Shape shape,
            DataType dataType,
            SparseFormat sparseFormat,
            Device device) {
        int layout = layoutMapper(sparseFormat, device);
        long handle =
                PyTorchLibrary.LIB.torchFromBlob(
                        byteBuffer,
                        shape.getShape(),
                        dataType.ordinal(),
                        layout,
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        boolean retainBuffer = layout != 1 && layout != 2 && !device.isGpu();
        if (retainBuffer) {
            return new PtNDArray(manager, handle, byteBuffer);
        }
        return new PtNDArray(manager, handle);
    }

    /** Calls native {@code torchCudaEmptyCache}. */
    public static void emptyCudaCache() {
        PyTorchLibrary.LIB.torchCudaEmptyCache();
    }

    /**
     * Creates an array through native {@code torchEmpty}.
     *
     * <p>Every creation helper below passes the device as {@code {deviceType, deviceId}} and ends
     * with a constant {@code false} flag (presumably "requires grad" — confirm).
     *
     * @param manager manager that will own the result
     * @param shape result shape
     * @param dataType element type, passed by ordinal
     * @param device target device
     * @param sparseFormat storage format, mapped by {@link #layoutMapper}
     * @return the new array
     */
    public static PtNDArray createEmptyNdArray(
            PtNDManager manager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchEmpty(
                        shape.getShape(),
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through native {@code torchZeros}; arguments as in {@link #createEmptyNdArray}. */
    public static PtNDArray createZerosNdArray(
            PtNDManager manager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchZeros(
                        shape.getShape(),
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through native {@code torchOnes}; arguments as in {@link #createEmptyNdArray}. */
    public static PtNDArray createOnesNdArray(
            PtNDManager manager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchOnes(
                        shape.getShape(),
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /**
     * Creates an array through native {@code torchFull}.
     *
     * @param fillValue value forwarded to native code
     */
    public static PtNDArray full(
            PtNDManager manager,
            Shape shape,
            double fillValue,
            DataType dataType,
            Device device,
            SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchFull(
                        shape.getShape(),
                        fillValue,
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array via native {@code torchZerosLike}, owned by {@code template}'s manager. */
    public static PtNDArray zerosLike(
            PtNDArray template, DataType dataType, Device device, SparseFormat sparseFormat) {
        PtNDManager manager = template.m138getManager();
        long handle =
                PyTorchLibrary.LIB.torchZerosLike(
                        (Long) template.getHandle(),
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array via native {@code torchOnesLike}, owned by {@code template}'s manager. */
    public static PtNDArray onesLike(
            PtNDArray template, DataType dataType, Device device, SparseFormat sparseFormat) {
        PtNDManager manager = template.m138getManager();
        long handle =
                PyTorchLibrary.LIB.torchOnesLike(
                        (Long) template.getHandle(),
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through native {@code torchArange(start, stop, step, ...)}. */
    public static PtNDArray arange(
            PtNDManager manager,
            float start,
            float stop,
            float step,
            DataType dataType,
            Device device,
            SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchArange(
                        start,
                        stop,
                        step,
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /** Creates an array through native {@code torchLinspace(start, stop, count, ...)}. */
    public static PtNDArray linspace(
            PtNDManager manager,
            float start,
            float stop,
            int count,
            DataType dataType,
            Device device,
            SparseFormat sparseFormat) {
        long handle =
                PyTorchLibrary.LIB.torchLinspace(
                        start,
                        stop,
                        count,
                        dataType.ordinal(),
                        layoutMapper(sparseFormat, device),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                        false);
        return new PtNDArray(manager, handle);
    }

    /**
     * Builds a sparse array via native {@code torchSparseCoo}. The result is owned by
     * {@code values}' manager.
     *
     * @param indices first array handle forwarded (presumably COO indices — confirm)
     * @param values second array handle forwarded (presumably COO values — confirm)
     * @param shape dense shape forwarded to native code
     */
    public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) {
        PtNDManager owner = values.m138getManager();
        long handle =
                PyTorchLibrary.LIB.torchSparseCoo(
                        shape.getShape(), (Long) indices.getHandle(), (Long) values.getHandle(), false);
        return new PtNDArray(owner, handle);
    }

    /**
     * Converts {@code ndArray} via native {@code torchTo}. If {@code device} differs from the source
     * manager's device, the result goes to a new sub-manager for that device.
     */
    public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) {
        PtNDManager sourceManager = ndArray.m138getManager();
        PtNDManager targetManager =
                device.equals(sourceManager.getDevice())
                        ? sourceManager
                        : sourceManager.mo175newSubManager(device);
        long handle =
                PyTorchLibrary.LIB.torchTo(
                        (Long) ndArray.getHandle(),
                        dataType.ordinal(),
                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()});
        return new PtNDArray(targetManager, handle);
    }

    /** Returns the result of native {@code torchToSparse}, owned by the input's manager. */
    public static PtNDArray toSparse(PtNDArray ndArray) {
        PtNDManager owner = ndArray.m138getManager();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchToSparse((Long) ndArray.getHandle()));
    }

    /** Returns the result of native {@code torchToDense}, owned by the input's manager. */
    public static PtNDArray toDense(PtNDArray ndArray) {
        PtNDManager owner = ndArray.m138getManager();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchToDense((Long) ndArray.getHandle()));
    }

    /** Returns the result of native {@code torchExpand} to {@code shape}, owned by the input's manager. */
    public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) {
        PtNDManager owner = ndArray.m138getManager();
        long handle = PyTorchLibrary.LIB.torchExpand((Long) ndArray.getHandle(), shape.getShape());
        return new PtNDArray(owner, handle);
    }

    /**
     * Returns the result of native {@code torchSlice}, owned by the input's manager.
     *
     * <p>NOTE(review): names assume the native argument order (dim, start, stop, step) — confirm.
     */
    public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) {
        PtNDManager owner = ndArray.m138getManager();
        long handle = PyTorchLibrary.LIB.torchSlice((Long) ndArray.getHandle(), dim, start, stop, step);
        return new PtNDArray(owner, handle);
    }

    /**
     * Returns the result of native {@code torchIndex}, owned by {@code manager}.
     *
     * <p>NOTE(review): the three arrays are presumably per-axis min/max/step bounds — confirm.
     */
    public static PtNDArray index(
            PtNDArray ndArray,
            long[] minIndices,
            long[] maxIndices,
            long[] stepIndices,
            PtNDManager manager) {
        long handle =
                PyTorchLibrary.LIB.torchIndex(
                        (Long) ndArray.getHandle(), minIndices, maxIndices, stepIndices);
        return new PtNDArray(manager, handle);
    }

    /**
     * Applies an advanced index to {@code ndArray} via native {@code torchIndexAdvGet}.
     *
     * <p>If the index contains an {@link NDIndexPick} element, the native index is abandoned. The
     * call is then delegated to {@link #pick} using the equivalent {@link NDIndexFullPick}.
     *
     * @param ndArray source array; {@code null} is returned unchanged
     * @param index index to apply
     * @param manager owns the result and converts take-index arrays
     * @return the indexed array, or {@code null} when {@code ndArray} is {@code null}
     * @throws IllegalArgumentException if a pick index cannot be turned into a full pick
     */
    public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager manager) {
        if (ndArray == null) {
            return null;
        }
        long indexHandle = PyTorchLibrary.LIB.torchIndexInit(index.getIndices().size());
        try {
            if (!appendIndexElements(indexHandle, index, manager)) {
                NDIndexFullPick fullPick = toFullPick(index, ndArray);
                return pick(ndArray, manager.mo177from(fullPick.getIndices()), fullPick.getAxis());
            }
            long resultHandle =
                    PyTorchLibrary.LIB.torchIndexAdvGet((Long) ndArray.getHandle(), indexHandle);
            return new PtNDArray(manager, resultHandle);
        } finally {
            // Single release point for the native index on every path, including exceptions.
            PyTorchLibrary.LIB.torchDeleteIndex(indexHandle);
        }
    }

    /**
     * Writes {@code value} into {@code ndArray} at an advanced index via native
     * {@code torchIndexAdvPut}.
     *
     * @param ndArray destination array; the call is a no-op when {@code null}
     * @param index index selecting the destination region
     * @param value array whose handle is passed as the value
     * @throws IllegalArgumentException if a pick index cannot be turned into a full pick
     */
    public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray value) {
        if (ndArray == null) {
            return;
        }
        PtNDManager manager = ndArray.m138getManager();
        long indexHandle = PyTorchLibrary.LIB.torchIndexInit(index.getIndices().size());
        try {
            if (!appendIndexElements(indexHandle, index, manager)) {
                // NOTE(review): the picked result is computed and then discarded, so a pick-style
                // assignment leaves ndArray unchanged. Kept as-is for backward compatibility;
                // confirm whether this should write via scatter or throw instead.
                NDIndexFullPick fullPick = toFullPick(index, ndArray);
                pick(ndArray, manager.mo177from(fullPick.getIndices()), fullPick.getAxis());
                return;
            }
            PyTorchLibrary.LIB.torchIndexAdvPut(
                    (Long) ndArray.getHandle(), indexHandle, (Long) value.getHandle());
        } finally {
            PyTorchLibrary.LIB.torchDeleteIndex(indexHandle);
        }
    }

    /**
     * Appends each element of {@code index} to the native index {@code indexHandle}. An ellipsis
     * marker is inserted at {@code index.getEllipsisIndex()}, including after the last element.
     *
     * @return {@code false} as soon as an {@link NDIndexPick} element is seen (the trailing ellipsis
     *     is then not appended and the caller must fall back to {@link #pick}); {@code true}
     *     otherwise
     */
    private static boolean appendIndexElements(long indexHandle, NDIndex index, PtNDManager manager) {
        List<NDIndexElement> indices = index.getIndices();
        ListIterator<NDIndexElement> it = indices.listIterator();
        while (it.hasNext()) {
            if (it.nextIndex() == index.getEllipsisIndex()) {
                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(indexHandle, true);
            }
            NDIndexElement element = it.next();
            if (element instanceof NDIndexNull) {
                PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(indexHandle, false);
            } else if (element instanceof NDIndexSlice) {
                NDIndexSlice slice = (NDIndexSlice) element;
                Long min = slice.getMin();
                Long max = slice.getMax();
                Long step = slice.getStep();
                // Bit 1 set when min is absent, bit 0 set when max is absent (3 == both absent).
                int nullMask = (min == null ? 2 : 0) + (max == null ? 1 : 0);
                PyTorchLibrary.LIB.torchIndexAppendSlice(
                        indexHandle,
                        min == null ? -1L : min,
                        max == null ? -1L : max,
                        step == null ? 1L : step,
                        nullMask);
            } else if (element instanceof NDIndexAll) {
                PyTorchLibrary.LIB.torchIndexAppendSlice(indexHandle, -1L, -1L, 1L, 3);
            } else if (element instanceof NDIndexFixed) {
                PyTorchLibrary.LIB.torchIndexAppendFixed(
                        indexHandle, ((NDIndexFixed) element).getIndex());
            } else if (element instanceof NDIndexBooleans) {
                PtNDArray mask = (PtNDArray) ((NDIndexBooleans) element).getIndex();
                PyTorchLibrary.LIB.torchIndexAppendArray(indexHandle, (Long) mask.getHandle());
            } else if (element instanceof NDIndexTake) {
                // Converting through the manager (instead of a hard cast) also accepts arrays
                // created by another engine.
                PtNDArray take = manager.mo177from(((NDIndexTake) element).getIndex());
                if (take.getDataType() != DataType.INT64) {
                    take = take.m136toType(DataType.INT64, true);
                }
                PyTorchLibrary.LIB.torchIndexAppendArray(indexHandle, (Long) take.getHandle());
            } else if (element instanceof NDIndexPick) {
                return false;
            }
        }
        if (indices.size() == index.getEllipsisIndex()) {
            PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(indexHandle, true);
        }
        return true;
    }

    /** Converts {@code index} into its full-pick form for {@code ndArray}'s shape. */
    private static NDIndexFullPick toFullPick(NDIndex index, PtNDArray ndArray) {
        return NDIndexFullPick.fromIndex(index, ndArray.getShape())
                .orElseThrow(
                        () ->
                                new IllegalArgumentException(
                                        "Cannot convert pick index for shape " + ndArray.getShape()));
    }

    /**
     * Calls native {@code torchIndexPut} with the two array handles and the three index arrays.
     *
     * <p>NOTE(review): the index arrays are presumably per-axis min/max/step bounds — confirm.
     */
    public static void indexSet(
            PtNDArray ndArray,
            PtNDArray value,
            long[] minIndices,
            long[] maxIndices,
            long[] stepIndices) {
        long target = (Long) ndArray.getHandle();
        long source = (Long) value.getHandle();
        PyTorchLibrary.LIB.torchIndexPut(target, source, minIndices, maxIndices, stepIndices);
    }

    /** Calls native {@code torchSet} with the array handle and {@code data}. */
    public static void set(PtNDArray ndArray, ByteBuffer data) {
        long target = (Long) ndArray.getHandle();
        PyTorchLibrary.LIB.torchSet(target, data);
    }

    /**
     * Returns native {@code torchGather(ndArray, index, dim, false)}, owned by {@code ndArray}'s
     * manager. A non-INT64 {@code index} is first converted to INT64.
     */
    public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) {
        PtNDArray longIndex =
                index.getDataType() == DataType.INT64 ? index : index.m136toType(DataType.INT64, true);
        long handle =
                PyTorchLibrary.LIB.torchGather(
                        (Long) ndArray.getHandle(), (Long) longIndex.getHandle(), dim, false);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Returns native {@code torchTake(ndArray, index)}, owned by {@code manager}. A non-INT64
     * {@code index} is first converted to INT64.
     */
    public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager manager) {
        PtNDArray longIndex =
                index.getDataType() == DataType.INT64 ? index : index.m136toType(DataType.INT64, true);
        long handle =
                PyTorchLibrary.LIB.torchTake((Long) ndArray.getHandle(), (Long) longIndex.getHandle());
        return new PtNDArray(manager, handle);
    }

    /**
     * Returns native {@code torchPut(ndArray, index, value)}, owned by {@code ndArray}'s manager.
     * A non-INT64 {@code index} is first converted to INT64.
     */
    public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray value) {
        PtNDArray longIndex =
                index.getDataType() == DataType.INT64 ? index : index.m136toType(DataType.INT64, true);
        long handle =
                PyTorchLibrary.LIB.torchPut(
                        (Long) ndArray.getHandle(), (Long) longIndex.getHandle(), (Long) value.getHandle());
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Returns native {@code torchScatter(ndArray, index, value, axis)}, owned by {@code ndArray}'s
     * manager. A non-INT64 {@code index} is first converted to INT64.
     */
    public static PtNDArray scatter(PtNDArray ndArray, PtNDArray index, PtNDArray value, int axis) {
        PtNDArray longIndex =
                index.getDataType() == DataType.INT64 ? index : index.m136toType(DataType.INT64, true);
        long handle =
                PyTorchLibrary.LIB.torchScatter(
                        (Long) ndArray.getHandle(),
                        (Long) longIndex.getHandle(),
                        (Long) value.getHandle(),
                        axis);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Gathers {@code ndArray} along {@code dim} using {@code index} via native {@code torchGather}.
     *
     * <p>If {@code index} has fewer dimensions than {@code ndArray}, it is reshaped first. The code
     * finds the first offset at which the index shape appears as a contiguous run inside the array
     * shape. The index dims are placed at that offset and every other dim is set to 1. For example,
     * index {@code (2, 3)} against array {@code (2, 3, 4)} becomes {@code (2, 3, 1)}. The index is
     * converted to INT64 if needed.
     *
     * @param ndArray source array
     * @param index pick indices
     * @param dim gather dimension
     * @return the gathered array, owned by {@code ndArray}'s manager
     * @throws IllegalArgumentException if the index shape cannot be placed inside the array shape
     */
    public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
        Shape ndShape = ndArray.getShape();
        Shape indexShape = index.getShape();
        int ndDims = ndShape.dimension();
        int indexDims = indexShape.dimension();
        PtNDArray pickIndex = index;
        if (indexDims != ndDims) {
            long[] expanded = null;
            // Offsets are inclusive of the last one (ndDims - indexDims). The slice end is
            // offset + indexDims, so every candidate window has the index's rank.
            for (int offset = 0; offset <= ndDims - indexDims; offset++) {
                if (indexShape.equals(ndShape.slice(offset, offset + indexDims))) {
                    expanded = new long[ndDims];
                    Arrays.fill(expanded, 1L);
                    System.arraycopy(indexShape.getShape(), 0, expanded, offset, indexDims);
                    break;
                }
            }
            if (expanded == null) {
                throw new IllegalArgumentException(
                        "expand shape failed! Cannot expand from " + indexShape + " to " + ndShape);
            }
            pickIndex = pickIndex.m50reshape(new Shape(expanded));
        }
        if (pickIndex.getDataType() != DataType.INT64) {
            pickIndex = pickIndex.m136toType(DataType.INT64, true);
        }
        long handle =
                PyTorchLibrary.LIB.torchGather(
                        (Long) ndArray.getHandle(), (Long) pickIndex.getHandle(), dim, false);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Returns native {@code torchWhere(condition, x, y)}. Note the result is owned by {@code x}'s
     * manager, not {@code condition}'s.
     */
    public static PtNDArray where(PtNDArray condition, PtNDArray x, PtNDArray y) {
        PtNDManager owner = x.m138getManager();
        long handle =
                PyTorchLibrary.LIB.torchWhere(
                        (Long) condition.getHandle(), (Long) x.getHandle(), (Long) y.getHandle());
        return new PtNDArray(owner, handle);
    }

    /** Returns native {@code torchMaskedSelect(ndArray, mask)}, owned by {@code ndArray}'s manager. */
    public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray mask) {
        PtNDManager owner = ndArray.m138getManager();
        long handle =
                PyTorchLibrary.LIB.torchMaskedSelect((Long) ndArray.getHandle(), (Long) mask.getHandle());
        return new PtNDArray(owner, handle);
    }

    /**
     * Calls native {@code torchMaskedPut} with the three handles in order; no new array is created.
     * NOTE(review): argument roles (value, mask) are presumed from the native name — confirm.
     */
    public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray mask) {
        long target = (Long) ndArray.getHandle();
        long source = (Long) value.getHandle();
        long maskHandle = (Long) mask.getHandle();
        PyTorchLibrary.LIB.torchMaskedPut(target, source, maskHandle);
    }

    /**
     * Returns native {@code torchGetItem}, owned by {@code manager}. A single index uses the scalar
     * overload; any other length uses the array overload.
     */
    public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager manager) {
        boolean scalarIndex = indices.length == 1;
        long source = (Long) ndArray.getHandle();
        long handle;
        if (scalarIndex) {
            handle = PyTorchLibrary.LIB.torchGetItem(source, indices[0]);
        } else {
            handle = PyTorchLibrary.LIB.torchGetItem(source, indices);
        }
        return new PtNDArray(manager, handle);
    }

    /** Returns native {@code tensorClone}, owned by the input's manager. */
    public static PtNDArray clone(PtNDArray ndArray) {
        PtNDManager owner = ndArray.m138getManager();
        return new PtNDArray(owner, PyTorchLibrary.LIB.tensorClone((Long) ndArray.getHandle()));
    }

    /** Returns native {@code torchReshape} to {@code shape}, owned by the input's manager. */
    public static PtNDArray reshape(PtNDArray ndArray, long[] shape) {
        PtNDManager owner = ndArray.m138getManager();
        return new PtNDArray(owner, PyTorchLibrary.LIB.torchReshape((Long) ndArray.getHandle(), shape));
    }

    /**
     * Returns native {@code torchStack} over all array handles, owned by the first array's manager.
     *
     * @param arrays arrays to stack; must be non-empty
     * @param dim dimension forwarded to native code
     * @return the stacked array
     * @throws IllegalArgumentException if {@code arrays} is empty
     */
    public static PtNDArray stack(PtNDArray[] arrays, int dim) {
        if (arrays.length == 0) {
            // Fail with a clear message instead of ArrayIndexOutOfBoundsException on arrays[0].
            throw new IllegalArgumentException("stack requires at least one array");
        }
        long[] handles = Arrays.stream(arrays).mapToLong(a -> (Long) a.getHandle()).toArray();
        return new PtNDArray(arrays[0].m138getManager(), PyTorchLibrary.LIB.torchStack(handles, dim));
    }

    /**
     * Returns native {@code torchCat} over all array handles, owned by the first array's manager.
     *
     * @param arrays arrays to concatenate; must be non-empty
     * @param dim dimension forwarded to native code
     * @return the concatenated array
     * @throws IllegalArgumentException if {@code arrays} is empty
     */
    public static PtNDArray cat(PtNDArray[] arrays, long dim) {
        if (arrays.length == 0) {
            throw new IllegalArgumentException("cat requires at least one array");
        }
        long[] handles = Arrays.stream(arrays).mapToLong(a -> (Long) a.getHandle()).toArray();
        return new PtNDArray(arrays[0].m138getManager(), PyTorchLibrary.LIB.torchCat(handles, dim));
    }

    /** Returns native {@code torchRepeat(ndArray, repeats)}, owned by the input's manager. */
    public static PtNDArray tile(PtNDArray ndArray, long[] repeats) {
        long handle = PyTorchLibrary.LIB.torchRepeat((Long) ndArray.getHandle(), repeats);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchRepeatInterleave(ndArray, repeats, dim)}, owned by the input's manager. */
    public static PtNDArray repeat(PtNDArray ndArray, long repeats, long dim) {
        long handle = PyTorchLibrary.LIB.torchRepeatInterleave((Long) ndArray.getHandle(), repeats, dim);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchSoftmax(ndArray, dim, dtypeOrdinal)}, owned by the input's manager. */
    public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dataType) {
        long handle =
                PyTorchLibrary.LIB.torchSoftmax((Long) ndArray.getHandle(), dim, dataType.ordinal());
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchLogSoftmax(ndArray, dim, dtypeOrdinal)}, owned by the input's manager. */
    public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dataType) {
        long handle =
                PyTorchLibrary.LIB.torchLogSoftmax((Long) ndArray.getHandle(), dim, dataType.ordinal());
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchArgMax(ndArray)}, owned by the input's manager. */
    public static PtNDArray argMax(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchArgMax((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchArgMax(ndArray, dim, keepDim)}, owned by the input's manager. */
    public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchArgMax((Long) ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Wraps every handle returned by native {@code torchTopK} in an {@link NDList}, preserving the
     * native order. All arrays are owned by the input's manager.
     *
     * <p>NOTE(review): names assume torch.topk's (k, dim, largest, sorted) order — confirm.
     */
    public static NDList topK(PtNDArray ndArray, long k, long axis, boolean largest, boolean sorted) {
        long[] handles =
                PyTorchLibrary.LIB.torchTopK((Long) ndArray.getHandle(), k, axis, largest, sorted);
        PtNDManager owner = ndArray.m138getManager();
        NDList result = new NDList(handles.length);
        for (long handle : handles) {
            result.add(new PtNDArray(owner, handle));
        }
        return result;
    }

    /** Returns native {@code torchArgMin(ndArray)}, owned by the input's manager. */
    public static PtNDArray argMin(PtNDArray ndArray) {
        long handle = PyTorchLibrary.LIB.torchArgMin((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchArgMin(ndArray, dim, keepDim)}, owned by the input's manager. */
    public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = PyTorchLibrary.LIB.torchArgMin((Long) ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Returns native {@code torchArgSort}, owned by the input's manager. The boolean is forwarded
     * unchanged; its meaning is defined by the native binding.
     */
    public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean nativeFlag) {
        long handle = PyTorchLibrary.LIB.torchArgSort((Long) ndArray.getHandle(), dim, nativeFlag);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Returns native {@code torchSort}, owned by the input's manager. The boolean is forwarded
     * unchanged; its meaning is defined by the native binding.
     */
    public static PtNDArray sort(PtNDArray ndArray, long dim, boolean nativeFlag) {
        long handle = PyTorchLibrary.LIB.torchSort((Long) ndArray.getHandle(), dim, nativeFlag);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchPermute(ndArray, dims)}, owned by the input's manager. */
    public static PtNDArray permute(PtNDArray ndArray, long[] dims) {
        long handle = PyTorchLibrary.LIB.torchPermute((Long) ndArray.getHandle(), dims);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchFlip(ndArray, dims)}, owned by the input's manager. */
    public static PtNDArray flip(PtNDArray ndArray, long[] dims) {
        long handle = PyTorchLibrary.LIB.torchFlip((Long) ndArray.getHandle(), dims);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns native {@code torchTranspose(ndArray, dim0, dim1)}, owned by the input's manager. */
    public static PtNDArray transpose(PtNDArray ndArray, long dim0, long dim1) {
        long handle = PyTorchLibrary.LIB.torchTranspose((Long) ndArray.getHandle(), dim0, dim1);
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /** Returns the boolean produced by native {@code contentEqual} for the two handles. */
    public static boolean contentEqual(PtNDArray lhs, PtNDArray rhs) {
        long left = (Long) lhs.getHandle();
        long right = (Long) rhs.getHandle();
        return PyTorchLibrary.LIB.contentEqual(left, right);
    }

    /*
     * Binary arithmetic wrappers. Each non-"i" variant returns a new PtNDArray owned by lhs's
     * manager. Each "i" variant only calls the native function and creates no array (presumably
     * updating lhs in place — confirm against the native binding).
     */

    /** Returns native {@code torchAdd(lhs, rhs)}. */
    public static PtNDArray add(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchAdd((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchAddi(lhs, rhs)}. */
    public static void addi(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchAddi((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchSub(lhs, rhs)}. */
    public static PtNDArray sub(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchSub((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchSubi(lhs, rhs)}. */
    public static void subi(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchSubi((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchMul(lhs, rhs)}. */
    public static PtNDArray mul(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchMul((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchMuli(lhs, rhs)}. */
    public static void muli(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchMuli((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchTrueDivide(lhs, rhs)}. */
    public static PtNDArray div(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchTrueDivide((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchTrueDividei(lhs, rhs)}. */
    public static void divi(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchTrueDividei((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchRemainder(lhs, rhs)}. */
    public static PtNDArray remainder(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchRemainder((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchRemainderi(lhs, rhs)}. */
    public static void remainderi(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchRemainderi((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchPow(lhs, rhs)}. */
    public static PtNDArray pow(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchPow((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Calls native {@code torchPowi(lhs, rhs)}. */
    public static void powi(PtNDArray lhs, PtNDArray rhs) {
        PyTorchLibrary.LIB.torchPowi((Long) lhs.getHandle(), (Long) rhs.getHandle());
    }

    /** Returns native {@code torchSign(ndArray)}, owned by the input's manager. */
    public static PtNDArray sign(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSign((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Calls native {@code torchSigni(ndArray)}; creates no new array. */
    public static void signi(PtNDArray ndArray) {
        PyTorchLibrary.LIB.torchSigni((Long) ndArray.getHandle());
    }

    /** Returns native {@code torchLogicalAnd(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray logicalAnd(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchLogicalAnd((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchLogicalOr(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray logicalOr(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchLogicalOr((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchLogicalXor(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray logicalXor(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchLogicalXor((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchLogicalNot(ndArray)}, owned by the input's manager. */
    public static PtNDArray logicalNot(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchLogicalNot((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Returns native {@code torchMatmul(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray matmul(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchMatmul((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchBmm(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray bmm(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchBmm((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchXLogY(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray xlogy(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchXLogY((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /**
     * Returns native {@code torchDot} when {@code lhs} is 1-D, otherwise native
     * {@code torchMatmul}. Either result is owned by lhs's manager.
     */
    public static PtNDArray dot(PtNDArray lhs, PtNDArray rhs) {
        PtNDManager owner = lhs.m138getManager();
        long left = (Long) lhs.getHandle();
        long right = (Long) rhs.getHandle();
        long result;
        if (lhs.getShape().dimension() == 1) {
            result = PyTorchLibrary.LIB.torchDot(left, right);
        } else {
            result = PyTorchLibrary.LIB.torchMatmul(left, right);
        }
        return new PtNDArray(owner, result);
    }

    /** Returns native {@code torchMaximum(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray max(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchMaximum((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchMax(ndArray)}, owned by the input's manager. */
    public static PtNDArray max(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchMax((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Returns native {@code torchMax(ndArray, dim, keepDim)}, owned by the input's manager. */
    public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) {
        long result = PyTorchLibrary.LIB.torchMax((Long) ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Returns native {@code torchMinimum(lhs, rhs)}, owned by lhs's manager. */
    public static PtNDArray min(PtNDArray lhs, PtNDArray rhs) {
        long result = PyTorchLibrary.LIB.torchMinimum((Long) lhs.getHandle(), (Long) rhs.getHandle());
        return new PtNDArray(lhs.m138getManager(), result);
    }

    /** Returns native {@code torchMin(ndArray)}, owned by the input's manager. */
    public static PtNDArray min(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchMin((Long) ndArray.getHandle());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Returns native {@code torchMin(ndArray, dim, keepDim)}, owned by the input's manager. */
    public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) {
        long result = PyTorchLibrary.LIB.torchMin((Long) ndArray.getHandle(), dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchMedian} along {@code dim}.
     *
     * @return an {@link NDList} holding, in order, the two arrays whose handles the native call
     *     returns at indices 0 and 1
     */
    public static NDList median(PtNDArray ndArray, long dim, boolean keepDim) {
        long source = ((Long) ndArray.getHandle()).longValue();
        long[] handles = PyTorchLibrary.LIB.torchMedian(source, dim, keepDim);
        PtNDManager manager = ndArray.m138getManager();
        NDArray first = new PtNDArray(manager, handles[0]);
        NDArray second = new PtNDArray(manager, handles[1]);
        return new NDList(new NDArray[] {first, second});
    }

    /** Wraps the single-argument native {@code torchMean}. */
    public static PtNDArray mean(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return new PtNDArray(ndArray.m138getManager(), PyTorchLibrary.LIB.torchMean(handle));
    }

    /**
     * Wraps native {@code torchMean} along a dimension.
     *
     * @param dim dimension forwarded to the native call
     * @param keepDim flag forwarded unchanged to the native call
     */
    public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchMean(handle, dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchRot90}.
     *
     * @param times count forwarded to the native call
     * @param axes axes forwarded to the native call after widening each entry to {@code long}
     */
    public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long[] wideAxes = Arrays.stream(axes).asLongStream().toArray();
        long result = PyTorchLibrary.LIB.torchRot90(handle, times, wideAxes);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps the single-argument native {@code torchSum}. */
    public static PtNDArray sum(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return new PtNDArray(ndArray.m138getManager(), PyTorchLibrary.LIB.torchSum(handle));
    }

    /**
     * Wraps native {@code torchSum} over the given dimensions.
     *
     * @param dims dimensions forwarded to the native call
     * @param keepDim flag forwarded unchanged to the native call
     */
    public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchSum(handle, dims, keepDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchCumProd}.
     *
     * @param dim dimension forwarded to the native call
     * @param dataType optional output type; its ordinal is passed natively, or {@code -1} when
     *     {@code null}
     */
    public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) {
        int dtypeOrdinal = (dataType == null) ? -1 : dataType.ordinal();
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchCumProd(handle, dim, dtypeOrdinal);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps the single-argument native {@code torchProd}. */
    public static PtNDArray prod(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return new PtNDArray(ndArray.m138getManager(), PyTorchLibrary.LIB.torchProd(handle));
    }

    /**
     * Wraps native {@code torchProd} along a dimension.
     *
     * @param dim dimension forwarded to the native call
     * @param keepDim flag forwarded unchanged to the native call
     */
    public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchProd(handle, dim, keepDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchCumSum} along {@code dim}. */
    public static PtNDArray cumSum(PtNDArray ndArray, long dim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchCumSum(handle, dim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNNOneHot}.
     *
     * <p>The input is first converted to {@link DataType#INT64} (without forcing a copy). The
     * native result is then converted to {@code dataType}, also without forcing a copy.
     *
     * @param depth value forwarded to the native call
     * @param dataType type of the returned array
     */
    public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) {
        PtNDManager manager = ndArray.m138getManager();
        PtNDArray indices = ndArray.m136toType(DataType.INT64, false);
        long encodedHandle =
                PyTorchLibrary.LIB.torchNNOneHot(((Long) indices.getHandle()).longValue(), depth);
        PtNDArray encoded = new PtNDArray(manager, encodedHandle);
        return encoded.m136toType(dataType, false);
    }

    /**
     * Wraps native {@code torchSplit} with a single split size.
     *
     * @return one {@link PtNDArray} per handle returned natively, in native order
     */
    public static NDList split(PtNDArray ndArray, long size, long dim) {
        long[] pieceHandles =
                PyTorchLibrary.LIB.torchSplit(((Long) ndArray.getHandle()).longValue(), size, dim);
        NDList pieces = new NDList();
        for (long pieceHandle : pieceHandles) {
            pieces.add(new PtNDArray(ndArray.m138getManager(), pieceHandle));
        }
        return pieces;
    }

    /**
     * Wraps native {@code torchSplit} with explicit split sizes or indices.
     *
     * @return one {@link PtNDArray} per handle returned natively, in native order
     */
    public static NDList split(PtNDArray ndArray, long[] sections, long dim) {
        long[] pieceHandles =
                PyTorchLibrary.LIB.torchSplit(
                        ((Long) ndArray.getHandle()).longValue(), sections, dim);
        NDList pieces = new NDList();
        for (long pieceHandle : pieceHandles) {
            pieces.add(new PtNDArray(ndArray.m138getManager(), pieceHandle));
        }
        return pieces;
    }

    /** Wraps the single-argument native {@code torchSqueeze}. */
    public static PtNDArray squeeze(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return new PtNDArray(ndArray.m138getManager(), PyTorchLibrary.LIB.torchSqueeze(handle));
    }

    /** Wraps native {@code torchSqueeze} for a single dimension. */
    public static PtNDArray squeeze(PtNDArray ndArray, long dim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchSqueeze(handle, dim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchUnsqueeze} for a single dimension. */
    public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchUnsqueeze(handle, dim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchUnique}.
     *
     * <p>A {@code null} {@code dim} is sent to the native side as {@code -1}. Otherwise the
     * dimension is normalized into {@code [0, rank)} with {@link Math#floorMod}, so negative
     * indices count from the end.
     *
     * @param dim optional dimension, may be negative
     * @param sorted flag forwarded unchanged to the native call
     * @param returnInverse flag forwarded unchanged to the native call
     * @param returnCounts flag forwarded unchanged to the native call
     * @return one {@link PtNDArray} per handle returned natively, in native order
     */
    public static NDList unique(
            PtNDArray ndArray,
            Integer dim,
            boolean sorted,
            boolean returnInverse,
            boolean returnCounts) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long nativeDim;
        if (dim == null) {
            nativeDim = -1L;
        } else {
            // Plain int arithmetic; the previous Integer.valueOf(..).intValue() round-trip was
            // redundant boxing.
            nativeDim = Math.floorMod(dim.intValue(), ndArray.getShape().dimension());
        }
        long[] handles =
                PyTorchLibrary.LIB.torchUnique(
                        handle, nativeDim, sorted, returnInverse, returnCounts);
        NDList results = new NDList(handles.length);
        for (long resultHandle : handles) {
            results.add(new PtNDArray(ndArray.m138getManager(), resultHandle));
        }
        return results;
    }

    /** Wraps native {@code torchFlatten} between {@code startDim} and {@code endDim}. */
    public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchFlatten(handle, startDim, endDim);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchFft}; both numeric arguments are forwarded unchanged. */
    public static PtNDArray fft(PtNDArray ndArray, long length, long axis) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchFft(handle, length, axis);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchStft}.
     *
     * <p>All scalar arguments and the window handle are forwarded unchanged to the native call.
     *
     * @throws UnsupportedOperationException if the native call signals failure by returning
     *     {@code -1}
     */
    public static PtNDArray stft(
            PtNDArray ndArray,
            long nFft,
            long hopLength,
            PtNDArray window,
            boolean center,
            boolean normalize,
            boolean returnComplex) {
        long handle =
                PyTorchLibrary.LIB.torchStft(
                        ((Long) ndArray.getHandle()).longValue(),
                        nFft,
                        hopLength,
                        ((Long) window.getHandle()).longValue(),
                        center,
                        normalize,
                        returnComplex);
        if (handle == -1) {
            // Previously a copy-paste of real()'s message, which misnamed the failing operation.
            throw new UnsupportedOperationException("stft() is not supported.");
        }
        return new PtNDArray(ndArray.m138getManager(), handle);
    }

    /**
     * Wraps native {@code torchViewAsReal}.
     *
     * @throws UnsupportedOperationException when the native call returns {@code -1}
     */
    public static PtNDArray real(PtNDArray ndArray) {
        long view = PyTorchLibrary.LIB.torchViewAsReal(((Long) ndArray.getHandle()).longValue());
        if (view != -1) {
            return new PtNDArray(ndArray.m138getManager(), view);
        }
        throw new UnsupportedOperationException("real() is not supported.");
    }

    /**
     * Wraps native {@code torchViewAsComplex}.
     *
     * @throws UnsupportedOperationException when the native call returns {@code -1}
     */
    public static PtNDArray complex(PtNDArray ndArray) {
        long view =
                PyTorchLibrary.LIB.torchViewAsComplex(((Long) ndArray.getHandle()).longValue());
        if (view != -1) {
            return new PtNDArray(ndArray.m138getManager(), view);
        }
        throw new UnsupportedOperationException("complex() is not supported.");
    }

    /** Wraps native {@code torchAbs}; the result is attached to {@code ndArray}'s manager. */
    public static PtNDArray abs(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchAbs(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchSquare}. */
    public static PtNDArray square(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSquare(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchFloor}. */
    public static PtNDArray floor(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchFloor(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchCeil}. */
    public static PtNDArray ceil(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchCeil(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchRound}. */
    public static PtNDArray round(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchRound(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchTrunc}. */
    public static PtNDArray trunc(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchTrunc(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchClamp} with scalar bounds.
     *
     * <p>The bounds are materialized as temporary arrays on {@code ndArray}'s manager only to get
     * native handles for the call. They are closed as soon as the native call returns, rather
     * than staying attached to the manager until it is closed, which is what happened before.
     *
     * @param min lower bound, converted via the manager's {@code create(Number)}
     * @param max upper bound, converted via the manager's {@code create(Number)}
     * @return a new array attached to {@code ndArray}'s manager
     */
    public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) {
        PtNDManager manager = ndArray.m138getManager();
        try (PtNDArray minNd = (PtNDArray) manager.create(min);
                PtNDArray maxNd = (PtNDArray) manager.create(max)) {
            long result =
                    PyTorchLibrary.LIB.torchClamp(
                            ((Long) ndArray.getHandle()).longValue(),
                            ((Long) minNd.getHandle()).longValue(),
                            ((Long) maxNd.getHandle()).longValue());
            return new PtNDArray(manager, result);
        }
    }

    /** Wraps native {@code torchExp}; the result is attached to {@code ndArray}'s manager. */
    public static PtNDArray exp(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchExp(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchLog}. */
    public static PtNDArray log(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchLog(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchLog10}. */
    public static PtNDArray log10(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchLog10(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchLog2}. */
    public static PtNDArray log2(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchLog2(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchSin}. */
    public static PtNDArray sin(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSin(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchCos}. */
    public static PtNDArray cos(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchCos(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchTan}. */
    public static PtNDArray tan(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchTan(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchASin}. */
    public static PtNDArray asin(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchASin(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchAcos}. */
    public static PtNDArray acos(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchAcos(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchAtan}. */
    public static PtNDArray atan(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchAtan(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchSqrt}. */
    public static PtNDArray sqrt(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSqrt(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchSinh}. */
    public static PtNDArray sinh(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSinh(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchCosh}. */
    public static PtNDArray cosh(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchCosh(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchTanh}. */
    public static PtNDArray tanh(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchTanh(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchSigmoid}. */
    public static PtNDArray sigmoid(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchSigmoid(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchAll}. */
    public static PtNDArray all(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchAll(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchAny}. */
    public static PtNDArray any(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchAny(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNone}. */
    public static PtNDArray none(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNone(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchEq(a, b)}; the result is attached to {@code a}'s manager. */
    public static PtNDArray eq(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchEq(lhs, rhs));
    }

    /** Wraps native {@code torchNeq(a, b)}. */
    public static PtNDArray neq(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchNeq(lhs, rhs));
    }

    /** Wraps native {@code torchGt(a, b)}. */
    public static PtNDArray gt(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchGt(lhs, rhs));
    }

    /** Wraps native {@code torchGte(a, b)}. */
    public static PtNDArray gte(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchGte(lhs, rhs));
    }

    /** Wraps native {@code torchLt(a, b)}. */
    public static PtNDArray lt(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchLt(lhs, rhs));
    }

    /** Wraps native {@code torchLte(a, b)}. */
    public static PtNDArray lte(PtNDArray a, PtNDArray b) {
        long lhs = ((Long) a.getHandle()).longValue();
        long rhs = ((Long) b.getHandle()).longValue();
        return new PtNDArray(a.m138getManager(), PyTorchLibrary.LIB.torchLte(lhs, rhs));
    }

    /** Wraps native {@code torchNeg}, returning a new array. */
    public static PtNDArray neg(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNeg(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Calls native {@code torchNegi} on {@code ndArray}'s handle; no new array is created. */
    public static void negi(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        PyTorchLibrary.LIB.torchNegi(handle);
    }

    /** Wraps native {@code torchIsNaN}. */
    public static PtNDArray isNaN(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchIsNaN(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchIsInf}. */
    public static PtNDArray isInf(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchIsInf(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchRandint}.
     *
     * <p>Layout comes from {@code layoutMapper(SparseFormat.DENSE, device)}. The device is passed
     * natively as {@code {deviceType, deviceId}}, and the trailing flag is always {@code false}.
     */
    public static PtNDArray randint(
            PtNDManager manager, long low, long high, Shape shape, DataType dataType, Device device) {
        long[] dims = shape.getShape();
        int dtype = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle =
                PyTorchLibrary.LIB.torchRandint(
                        low, high, dims, dtype, layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Wraps native {@code torchRandPerm}, using a dense layout and a {@code false} final flag. */
    public static PtNDArray randperm(
            PtNDManager manager, long n, DataType dataType, Device device) {
        int dtype = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchRandPerm(n, dtype, layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Wraps native {@code torchNormal}, using a dense layout and a {@code false} final flag. */
    public static PtNDArray normal(
            PtNDManager manager,
            double mean,
            double std,
            Shape shape,
            DataType dataType,
            Device device) {
        long[] dims = shape.getShape();
        int dtype = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle =
                PyTorchLibrary.LIB.torchNormal(mean, std, dims, dtype, layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Wraps native {@code tensorUniform}, using a dense layout and a {@code false} final flag. */
    public static PtNDArray uniform(
            PtNDManager manager,
            double low,
            double high,
            Shape shape,
            DataType dataType,
            Device device) {
        long[] dims = shape.getShape();
        int dtype = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle =
                PyTorchLibrary.LIB.tensorUniform(low, high, dims, dtype, layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /**
     * Wraps native {@code torchEye}. Layout comes from {@code layoutMapper(sparseFormat, device)};
     * the final flag is {@code false}.
     */
    public static PtNDArray eye(
            PtNDManager manager,
            int rows,
            int cols,
            DataType dataType,
            Device device,
            SparseFormat sparseFormat) {
        int dtype = dataType.ordinal();
        int layout = layoutMapper(sparseFormat, device);
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchEye(rows, cols, dtype, layout, deviceSpec, false);
        return new PtNDArray(manager, handle);
    }

    /** Wraps native {@code torchHannWindow}; the flag is forwarded unchanged. */
    public static PtNDArray hannWindow(
            PtNDManager manager, long numPoints, boolean periodic, Device device) {
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle = PyTorchLibrary.LIB.torchHannWindow(numPoints, periodic, deviceSpec);
        return new PtNDArray(manager, handle);
    }

    /** Wraps native {@code torchErfinv}. */
    public static PtNDArray erfinv(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchErfinv(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchInverse}. */
    public static PtNDArray inverse(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchInverse(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNInterpolate}; all extra arguments are forwarded unchanged. */
    public static PtNDArray interpolate(
            PtNDArray ndArray, long[] size, int mode, boolean alignCorners) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNInterpolate(handle, size, mode, alignCorners);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNNLinear}.
     *
     * @param bias optional; a {@code null} bias is passed natively as handle {@code 0}
     */
    public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) {
        PtNDManager manager = input.m138getManager();
        long inputHandle = ((Long) input.getHandle()).longValue();
        long weightHandle = ((Long) weight.getHandle()).longValue();
        long biasHandle = 0L;
        if (bias != null) {
            biasHandle = ((Long) bias.getHandle()).longValue();
        }
        long result = PyTorchLibrary.LIB.torchNNLinear(inputHandle, weightHandle, biasHandle);
        return new PtNDArray(manager, result);
    }

    /** Wraps native {@code torchNNEmbedding}; the flag is forwarded unchanged. */
    public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) {
        long inputHandle = ((Long) input.getHandle()).longValue();
        long weightHandle = ((Long) weight.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNEmbedding(inputHandle, weightHandle, sparse);
        return new PtNDArray(input.m138getManager(), result);
    }

    /** Wraps native {@code torchNNRelu}. */
    public static PtNDArray relu(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNNRelu(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNSoftPlus}. */
    public static PtNDArray softPlus(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNNSoftPlus(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNSoftSign}. */
    public static PtNDArray softSign(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNNSoftSign(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNLeakyRelu}; {@code alpha} is forwarded unchanged. */
    public static PtNDArray leakyRelu(PtNDArray ndArray, double alpha) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNLeakyRelu(handle, alpha);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNElu}; {@code alpha} is forwarded unchanged. */
    public static PtNDArray elu(PtNDArray ndArray, double alpha) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNElu(handle, alpha);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNSelu}. */
    public static PtNDArray selu(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNNSelu(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNGelu}. */
    public static PtNDArray gelu(PtNDArray ndArray) {
        long result = PyTorchLibrary.LIB.torchNNGelu(((Long) ndArray.getHandle()).longValue());
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNNConvNd}.
     *
     * @param bias optional; a {@code null} bias is passed natively as handle {@code 0}
     */
    public static PtNDArray convolution(
            PtNDArray input,
            PtNDArray weight,
            PtNDArray bias,
            Shape stride,
            Shape padding,
            Shape dilation,
            int groups) {
        PtNDManager manager = input.m138getManager();
        long inputHandle = ((Long) input.getHandle()).longValue();
        long weightHandle = ((Long) weight.getHandle()).longValue();
        long biasHandle = (bias == null) ? 0L : ((Long) bias.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.torchNNConvNd(
                        inputHandle,
                        weightHandle,
                        biasHandle,
                        stride.getShape(),
                        padding.getShape(),
                        dilation.getShape(),
                        groups);
        return new PtNDArray(manager, result);
    }

    /**
     * Wraps native {@code torchNNBatchNorm}; the five array handles and the scalar arguments are
     * forwarded in the order received.
     */
    public static PtNDArray batchNorm(
            PtNDArray ptNDArray,
            PtNDArray ptNDArray2,
            PtNDArray ptNDArray3,
            PtNDArray ptNDArray4,
            PtNDArray ptNDArray5,
            boolean z,
            double d,
            double d2) {
        PtNDManager manager = ptNDArray.m138getManager();
        long h1 = ((Long) ptNDArray.getHandle()).longValue();
        long h2 = ((Long) ptNDArray2.getHandle()).longValue();
        long h3 = ((Long) ptNDArray3.getHandle()).longValue();
        long h4 = ((Long) ptNDArray4.getHandle()).longValue();
        long h5 = ((Long) ptNDArray5.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNBatchNorm(h1, h2, h3, h4, h5, z, d, d2);
        return new PtNDArray(manager, result);
    }

    /** Wraps native {@code torchNNLayerNorm}. */
    public static PtNDArray layerNorm(
            PtNDArray input, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) {
        PtNDManager manager = input.m138getManager();
        long inputHandle = ((Long) input.getHandle()).longValue();
        long[] dims = normalizedShape.getShape();
        long gammaHandle = ((Long) gamma.getHandle()).longValue();
        long betaHandle = ((Long) beta.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.torchNNLayerNorm(inputHandle, dims, gammaHandle, betaHandle, eps);
        return new PtNDArray(manager, result);
    }

    /** Wraps native {@code torchNNNormalize}; scalar arguments are forwarded unchanged. */
    public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNNormalize(handle, p, dim, eps);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /** Wraps native {@code torchNNDropout}; scalar arguments are forwarded unchanged. */
    public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNDropout(handle, prob, training);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNNRnn}.
     *
     * <p>Every element of {@code params} is cast to {@link PtNDArray} for its handle.
     * {@code activation} is passed natively by ordinal.
     *
     * @return one {@link PtNDArray} per handle returned natively, attached to {@code input}'s
     *     manager
     */
    public static NDList rnn(
            PtNDArray input,
            PtNDArray hx,
            NDList params,
            boolean hasBiases,
            int numLayers,
            RNN.Activation activation,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        PtNDManager manager = input.m138getManager();
        long[] paramHandles = new long[params.size()];
        for (int k = 0; k < paramHandles.length; k++) {
            paramHandles[k] = ((Long) ((PtNDArray) params.get(k)).getHandle()).longValue();
        }
        long[] outHandles =
                PyTorchLibrary.LIB.torchNNRnn(
                        ((Long) input.getHandle()).longValue(),
                        ((Long) hx.getHandle()).longValue(),
                        paramHandles,
                        hasBiases,
                        numLayers,
                        activation.ordinal(),
                        dropRate,
                        training,
                        bidirectional,
                        batchFirst);
        NDList outputs = new NDList();
        for (long outHandle : outHandles) {
            outputs.add(new PtNDArray(manager, outHandle));
        }
        return outputs;
    }

    /**
     * Wraps native {@code torchNNGru}. Every element of {@code params} is cast to {@link
     * PtNDArray} for its handle.
     */
    public static NDList gru(
            PtNDArray input,
            PtNDArray hx,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        PtNDManager manager = input.m138getManager();
        long[] paramHandles = new long[params.size()];
        for (int k = 0; k < paramHandles.length; k++) {
            paramHandles[k] = ((Long) ((PtNDArray) params.get(k)).getHandle()).longValue();
        }
        long[] outHandles =
                PyTorchLibrary.LIB.torchNNGru(
                        ((Long) input.getHandle()).longValue(),
                        ((Long) hx.getHandle()).longValue(),
                        paramHandles,
                        hasBiases,
                        numLayers,
                        dropRate,
                        training,
                        bidirectional,
                        batchFirst);
        NDList outputs = new NDList();
        for (long outHandle : outHandles) {
            outputs.add(new PtNDArray(manager, outHandle));
        }
        return outputs;
    }

    /**
     * Wraps native {@code torchNNLstm}. Elements of both {@code hx} and {@code params} are cast
     * to {@link PtNDArray} for their handles.
     */
    public static NDList lstm(
            PtNDArray input,
            NDList hx,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        PtNDManager manager = input.m138getManager();
        long[] hxHandles = new long[hx.size()];
        for (int k = 0; k < hxHandles.length; k++) {
            hxHandles[k] = ((Long) ((PtNDArray) hx.get(k)).getHandle()).longValue();
        }
        long[] paramHandles = new long[params.size()];
        for (int k = 0; k < paramHandles.length; k++) {
            paramHandles[k] = ((Long) ((PtNDArray) params.get(k)).getHandle()).longValue();
        }
        long[] outHandles =
                PyTorchLibrary.LIB.torchNNLstm(
                        ((Long) input.getHandle()).longValue(),
                        hxHandles,
                        paramHandles,
                        hasBiases,
                        numLayers,
                        dropRate,
                        training,
                        bidirectional,
                        batchFirst);
        NDList outputs = new NDList();
        for (long outHandle : outHandles) {
            outputs.add(new PtNDArray(manager, outHandle));
        }
        return outputs;
    }

    /** Wraps native {@code torchNNAvgPool}; shapes are passed as their {@code long[]} form. */
    public static PtNDArray avgPool(
            PtNDArray input,
            Shape kernelSize,
            Shape stride,
            Shape padding,
            boolean ceilMode,
            boolean countIncludePad) {
        long handle = ((Long) input.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.torchNNAvgPool(
                        handle,
                        kernelSize.getShape(),
                        stride.getShape(),
                        padding.getShape(),
                        ceilMode,
                        countIncludePad);
        return new PtNDArray(input.m138getManager(), result);
    }

    /** Wraps native {@code torchNNMaxPool}; shapes are passed as their {@code long[]} form. */
    public static PtNDArray maxPool(
            PtNDArray input, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) {
        long handle = ((Long) input.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.torchNNMaxPool(
                        handle,
                        kernelSize.getShape(),
                        stride.getShape(),
                        padding.getShape(),
                        ceilMode);
        return new PtNDArray(input.m138getManager(), result);
    }

    /** Wraps native {@code torchNNAdaptiveMaxPool}. */
    public static PtNDArray adaptiveMaxPool(PtNDArray input, Shape outputSize) {
        long handle = ((Long) input.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNAdaptiveMaxPool(handle, outputSize.getShape());
        return new PtNDArray(input.m138getManager(), result);
    }

    /** Wraps native {@code torchNNAdaptiveAvgPool}. */
    public static PtNDArray adaptiveAvgPool(PtNDArray input, Shape outputSize) {
        long handle = ((Long) input.getHandle()).longValue();
        long result = PyTorchLibrary.LIB.torchNNAdaptiveAvgPool(handle, outputSize.getShape());
        return new PtNDArray(input.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNNLpPool}.
     *
     * @throws UnsupportedOperationException when {@code input} has rank 5 (three spatial dims
     *     after the first two)
     */
    public static PtNDArray lpPool(
            PtNDArray input, double normType, Shape kernelSize, Shape stride, boolean ceilMode) {
        int spatialDims = input.getShape().dimension() - 2;
        if (spatialDims == 3) {
            throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine");
        }
        long handle = ((Long) input.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.torchNNLpPool(
                        handle, normType, kernelSize.getShape(), stride.getShape(), ceilMode);
        return new PtNDArray(input.m138getManager(), result);
    }

    /** Maps the ordinal returned by native {@code torchDType} onto {@link DataType#values()}. */
    public static DataType getDataType(PtNDArray ndArray) {
        int ordinal = PyTorchLibrary.LIB.torchDType(((Long) ndArray.getHandle()).longValue());
        DataType[] all = DataType.values();
        return all[ordinal];
    }

    /**
     * Builds a {@link Device} from the {@code {deviceType, deviceId}} pair returned by native
     * {@code torchDevice}.
     */
    public static Device getDevice(PtNDArray ndArray) {
        int[] deviceSpec = PyTorchLibrary.LIB.torchDevice(((Long) ndArray.getHandle()).longValue());
        int nativeType = deviceSpec[0];
        int deviceId = deviceSpec[1];
        return Device.of(PtDeviceType.fromDeviceType(nativeType), deviceId);
    }

    /**
     * Maps the native layout code to a {@link SparseFormat}: 0 is dense, 1 is COO, and 2 (MKLDNN,
     * logged at debug level) is reported as dense.
     *
     * @throws UnsupportedOperationException for any other layout code
     */
    public static SparseFormat getSparseFormat(PtNDArray ndArray) {
        int layout = PyTorchLibrary.LIB.torchLayout(((Long) ndArray.getHandle()).longValue());
        switch (layout) {
            case 0:
                return SparseFormat.DENSE;
            case 1:
                return SparseFormat.COO;
            case 2:
                logger.debug("MKLDNN layout is used!");
                return SparseFormat.DENSE;
            default:
                throw new UnsupportedOperationException("Unsupported data format");
        }
    }

    /** Wraps the sizes returned by native {@code torchSizes} in a {@link Shape}. */
    public static Shape getShape(PtNDArray ndArray) {
        long[] sizes = PyTorchLibrary.LIB.torchSizes(((Long) ndArray.getHandle()).longValue());
        return new Shape(sizes);
    }

    /**
     * Reads the array's bytes through native {@code torchDataPtr}, in native byte order.
     *
     * <p>If the array is not on {@link Device#cpu()}, a CPU copy is made first and read instead.
     * Before this change that copy was never closed and stayed attached to the manager. Since
     * {@code torchDataPtr} returns a Java heap {@code byte[]}, the temporary copy is no longer
     * needed once the call returns, so it is now closed.
     *
     * @return a heap {@link ByteBuffer} wrapping the returned bytes
     */
    public static ByteBuffer getByteBuffer(PtNDArray ptNDArray) {
        PtNDArray cpuArray = ptNDArray;
        if (!ptNDArray.getDevice().equals(Device.cpu())) {
            cpuArray = ptNDArray.m137toDevice(Device.cpu(), false);
        }
        try {
            byte[] bytes =
                    PyTorchLibrary.LIB.torchDataPtr(((Long) cpuArray.getHandle()).longValue());
            return ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder());
        } finally {
            // Only release what this method created; never close the caller's array.
            if (cpuArray != ptNDArray) {
                cpuArray.close();
            }
        }
    }

    /** Releases a native tensor handle via {@code torchDeleteTensor}. */
    public static void deleteNDArray(long handle) {
        PyTorchLibrary.LIB.torchDeleteTensor(handle);
    }

    /** Returns the flag reported by native {@code torchRequiresGrad}. */
    public static boolean requiresGrad(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return PyTorchLibrary.LIB.torchRequiresGrad(handle);
    }

    /** Returns the string reported by native {@code torchGradFnName}. */
    public static String getGradientFunctionNames(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return PyTorchLibrary.LIB.torchGradFnName(handle);
    }

    /** Calls native {@code torchAttachGrad} with the given flag. */
    public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        PyTorchLibrary.LIB.torchAttachGrad(handle, requiresGrad);
    }

    /** Wraps native {@code torchDetachGrad}, returning a new array. */
    public static PtNDArray detachGradient(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long detached = PyTorchLibrary.LIB.torchDetachGrad(handle);
        return new PtNDArray(ndArray.m138getManager(), detached);
    }

    /**
     * Wraps native {@code torchGrad}.
     *
     * @return {@code null} when the native call returns a null (0) handle
     */
    public static PtNDArray getGradient(PtNDArray ndArray) {
        long gradHandle = PyTorchLibrary.LIB.torchGrad(((Long) ndArray.getHandle()).longValue());
        return gradHandle == NULL_PTR ? null : new PtNDArray(ndArray.m138getManager(), gradHandle);
    }

    /** Calls native {@code torchBackward}; both flags are forwarded unchanged. */
    public static void backward(
            PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long gradHandle = ((Long) gradNd.getHandle()).longValue();
        PyTorchLibrary.LIB.torchBackward(handle, gradHandle, keepGraph, createGraph);
    }

    /** Releases a native module handle via {@code torchDeleteModule}. */
    public static void deleteModule(long handle) {
        PyTorchLibrary.LIB.torchDeleteModule(handle);
    }

    /** Forwards the flag to native {@code setGraphExecutorOptimize}. */
    public static void setGraphExecutorOptimize(boolean enabled) {
        PyTorchLibrary.LIB.setGraphExecutorOptimize(enabled);
    }

    /**
     * Loads a module from a file path via native {@code moduleLoad} onto the manager's device.
     *
     * <p>{@code mapLocation} is forced to {@code false} when the device type is {@code "mps"}. The
     * original code did this with {@code z = NULL_PTR}, assigning an {@code int} to a {@code
     * boolean}, which does not compile.
     *
     * @param mapLocation forwarded natively, except on MPS devices (see above)
     * @param extraFileKeys forwarded natively and logged at debug level
     * @param extraFileValues forwarded natively
     * @param trainParam forwarded natively unchanged
     */
    public static PtSymbolBlock loadModule(
            PtNDManager manager,
            Path path,
            boolean mapLocation,
            String[] extraFileKeys,
            String[] extraFileValues,
            boolean trainParam) {
        Device device = manager.getDevice();
        boolean effectiveMapLocation = mapLocation && !"mps".equals(device.getDeviceType());
        logger.debug("mapLocation: {}", Boolean.valueOf(effectiveMapLocation));
        logger.debug("extraFileKeys: {}", Arrays.toString(extraFileKeys));
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        long handle =
                PyTorchLibrary.LIB.moduleLoad(
                        path.toString(),
                        deviceSpec,
                        effectiveMapLocation,
                        extraFileKeys,
                        extraFileValues,
                        trainParam);
        return new PtSymbolBlock(manager, handle);
    }

    /**
     * Loads a module from a stream onto the manager's device. This delegates to {@link
     * #loadModuleHandle} and wraps the resulting handle in a {@link PtSymbolBlock}.
     *
     * @throws IOException propagated from {@link #loadModuleHandle}
     */
    public static PtSymbolBlock loadModule(
            PtNDManager manager, InputStream is, boolean mapLocation, boolean hasSize)
            throws IOException {
        Device device = manager.getDevice();
        long handle = loadModuleHandle(is, device, mapLocation, hasSize);
        return new PtSymbolBlock(manager, handle);
    }

    /**
     * Loads a module from a stream via native {@code moduleLoad} and returns its raw handle.
     *
     * <p>When {@code hasSize} is true, a {@code long} is first read from the stream with {@link
     * DataInputStream#readLong()} and passed to the native loader; otherwise {@code -1} is passed.
     * {@code mapLocation} is forced to {@code false} for {@code "mps"} devices. The original code
     * did this with {@code z = NULL_PTR}, assigning an {@code int} to a {@code boolean}, which
     * does not compile.
     *
     * @throws IOException if reading the size prefix fails
     */
    public static long loadModuleHandle(
            InputStream is, Device device, boolean mapLocation, boolean hasSize)
            throws IOException {
        long size = -1;
        if (hasSize) {
            // Deliberately not closed: closing the wrapper would close the caller's stream,
            // which the native loader still reads from.
            size = new DataInputStream(is).readLong();
        }
        boolean effectiveMapLocation = mapLocation && !"mps".equals(device.getDeviceType());
        logger.debug("mapLocation: {}", Boolean.valueOf(effectiveMapLocation));
        // Allocate the 4 MiB transfer buffer only after the size prefix was read successfully.
        byte[] buffer = new byte[BYTE_LENGTH];
        int[] deviceSpec = {PtDeviceType.toDeviceType(device), device.getDeviceId()};
        return PyTorchLibrary.LIB.moduleLoad(is, deviceSpec, effectiveMapLocation, buffer, size);
    }

    /**
     * Writes a module to {@code os} via native {@code moduleWrite}, using a fresh scratch buffer
     * of {@code BYTE_LENGTH} bytes. The flag is forwarded unchanged.
     */
    public static void writeModule(PtSymbolBlock block, OutputStream os, boolean writeSize) {
        long blockHandle = block.getHandle().longValue();
        byte[] scratch = new byte[BYTE_LENGTH];
        PyTorchLibrary.LIB.moduleWrite(blockHandle, os, scratch, writeSize);
    }

    /**
     * Wraps each parameter handle from native {@code moduleGetParams} in a {@link PtNDArray}.
     * Each array is named with the entry at the same index from {@code moduleGetParamNames}.
     */
    public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) {
        long[] paramHandles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle().longValue());
        String[] paramNames = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle().longValue());
        NDList params = new NDList(paramHandles.length);
        int index = 0;
        for (long paramHandle : paramHandles) {
            PtNDArray param = new PtNDArray(manager, paramHandle);
            param.setName(paramNames[index]);
            params.add(param);
            index++;
        }
        return params;
    }

    /** Returns the names reported by native {@code moduleGetMethodNames}. */
    public static String[] getMethodNames(PtSymbolBlock block) {
        long blockHandle = block.getHandle().longValue();
        return PyTorchLibrary.LIB.moduleGetMethodNames(blockHandle);
    }

    /** Calls native {@code moduleEval} on the block. */
    public static void enableInferenceMode(PtSymbolBlock block) {
        long blockHandle = block.getHandle().longValue();
        PyTorchLibrary.LIB.moduleEval(blockHandle);
    }

    /** Calls native {@code moduleTrain} on the block. */
    public static void enableTrainingMode(PtSymbolBlock block) {
        long blockHandle = block.getHandle().longValue();
        PyTorchLibrary.LIB.moduleTrain(blockHandle);
    }

    /** Calls native {@code zeroGrad} on the array's handle. */
    public static void zeroGrad(PtNDArray weight) {
        long handle = ((Long) weight.getHandle()).longValue();
        PyTorchLibrary.LIB.zeroGrad(handle);
    }

    /**
     * Calls native {@code adamUpdate}; the four array handles and all scalar arguments are
     * forwarded in the order received.
     */
    public static void adamUpdate(
            PtNDArray ptNDArray,
            PtNDArray ptNDArray2,
            PtNDArray ptNDArray3,
            PtNDArray ptNDArray4,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7,
            float f8,
            boolean z) {
        long h1 = ((Long) ptNDArray.getHandle()).longValue();
        long h2 = ((Long) ptNDArray2.getHandle()).longValue();
        long h3 = ((Long) ptNDArray3.getHandle()).longValue();
        long h4 = ((Long) ptNDArray4.getHandle()).longValue();
        PyTorchLibrary.LIB.adamUpdate(h1, h2, h3, h4, f, f2, f3, f4, f5, f6, f7, f8, z);
    }

    /**
     * Calls native {@code sgdUpdate}. A {@code null} third array is passed natively as handle
     * {@code 0}.
     */
    public static void sgdUpdate(
            PtNDArray ptNDArray,
            PtNDArray ptNDArray2,
            PtNDArray ptNDArray3,
            float f,
            float f2,
            float f3,
            float f4,
            float f5) {
        long h1 = ((Long) ptNDArray.getHandle()).longValue();
        long h2 = ((Long) ptNDArray2.getHandle()).longValue();
        long h3 = 0L;
        if (ptNDArray3 != null) {
            h3 = ((Long) ptNDArray3.getHandle()).longValue();
        }
        PyTorchLibrary.LIB.sgdUpdate(h1, h2, h3, f, f2, f3, f4, f5);
    }

    /** Returns the raw layout code from native {@code torchLayout}. */
    public static int getLayout(PtNDArray ndArray) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        return PyTorchLibrary.LIB.torchLayout(handle);
    }

    /**
     * Wraps native {@code torchNorm}.
     *
     * @param ord value forwarded to the native call
     * @param axes axes forwarded after widening each entry to {@code long}
     * @param keepDims flag forwarded unchanged to the native call
     */
    public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) {
        long handle = ((Long) ndArray.getHandle()).longValue();
        long[] wideAxes = Arrays.stream(axes).asLongStream().toArray();
        long result = PyTorchLibrary.LIB.torchNorm(handle, ord, wideAxes, keepDims);
        return new PtNDArray(ndArray.m138getManager(), result);
    }

    /**
     * Wraps native {@code torchNonZeros}. A scalar input is first reshaped to {@code {-1}} so the
     * native call receives a one-dimensional array.
     */
    public static PtNDArray nonZeros(PtNDArray ndArray) {
        PtNDArray target =
                ndArray.isScalar() ? (PtNDArray) ndArray.reshape(new long[] {-1}) : ndArray;
        long result = PyTorchLibrary.LIB.torchNonZeros(((Long) target.getHandle()).longValue());
        return new PtNDArray(target.m138getManager(), result);
    }
}
