package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.jni.JniUtils;
import java.util.Iterator;
import java.util.Stack;

/* loaded from: input_file:ai/djl/pytorch/engine/PtNDArrayIndexer.class */
/**
 * {@link NDArrayIndexer} implementation that forwards indexing reads and writes to the
 * native routines exposed by {@link JniUtils}.
 *
 * <p>Every array argument is cast to {@link PtNDArray} before it is handed to the native
 * layer, so passing an array of another implementation fails with a
 * {@link ClassCastException}.
 */
public class PtNDArrayIndexer extends NDArrayIndexer {

    /**
     * Reads from {@code nDArray} by delegating to {@code JniUtils.pick} with the indices and
     * axis carried by {@code nDIndexFullPick}.
     *
     * @param nDArray the array to read from
     * @param nDIndexFullPick supplies the index array and the axis
     * @return the array returned by the native pick call
     */
    public NDArray get(NDArray nDArray, NDIndexFullPick nDIndexFullPick) {
        return JniUtils.pick(
                (PtNDArray) nDArray,
                (PtNDArray) nDIndexFullPick.getIndices(),
                nDIndexFullPick.getAxis());
    }

    /**
     * Reads the region described by {@code nDIndexFullSlice} (min / max / step) and then
     * squeezes the axes listed by {@code getToSqueeze()}.
     *
     * <p>The un-squeezed intermediate is closed before returning, whether or not the squeeze
     * succeeds.
     *
     * @param nDArray the array to read from
     * @param nDIndexFullSlice the slice description
     * @return the squeezed result
     */
    public NDArray get(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice) {
        try (PtNDArray sliced =
                JniUtils.index(
                        (PtNDArray) nDArray,
                        nDIndexFullSlice.getMin(),
                        nDIndexFullSlice.getMax(),
                        nDIndexFullSlice.getStep())) {
            // NOTE(review): m43squeeze looks like a decompiler-assigned name for PtNDArray's
            // squeeze method - confirm against PtNDArray before renaming.
            return sliced.m43squeeze(nDIndexFullSlice.getToSqueeze());
        }
    }

    /**
     * Writes {@code nDArray2} into the region of {@code nDArray} described by
     * {@code nDIndexFullSlice}.
     *
     * <p>The value is first moved to the target's device, reshaped, and broadcast to the slice
     * shape. Every intermediate array produced along the way is closed before this method
     * returns, including when one of the steps or the native write throws. The caller's
     * {@code nDArray2} is never closed here.
     *
     * @param nDArray the array being written to
     * @param nDIndexFullSlice the slice description (shape, min, max, step)
     * @param nDArray2 the value to write; remains owned by the caller
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, NDArray nDArray2) {
        // Holds the caller's value at the bottom and each derived array above it, so that
        // peek() is always the most recently prepared form.
        Stack<NDArray> prepared = new Stack<>();
        prepared.add(nDArray2);
        try {
            prepared.add(prepared.peek().toDevice(nDArray.getDevice(), false));

            // Shrink the slice shape until its element count no longer exceeds the value's.
            // Presumably slice(1) drops the leading dimension - confirm against Shape.
            Shape reshapeTarget = nDIndexFullSlice.getShape();
            while (reshapeTarget.size() > nDArray2.size()) {
                reshapeTarget = reshapeTarget.slice(1);
            }
            prepared.add(prepared.peek().reshape(reshapeTarget));
            prepared.add(prepared.peek().broadcast(nDIndexFullSlice.getShape()));

            JniUtils.indexSet(
                    (PtNDArray) nDArray,
                    (PtNDArray) prepared.peek(),
                    nDIndexFullSlice.getMin(),
                    nDIndexFullSlice.getMax(),
                    nDIndexFullSlice.getStep());
        } finally {
            // Release the temporaries oldest-first. The identity check skips the caller's
            // value, including the case where a preparation step handed that same object back.
            for (NDArray candidate : prepared) {
                if (candidate != nDArray2) {
                    candidate.close();
                }
            }
        }
    }

    /**
     * Writes {@code nDArray2} into {@code nDArray} through the boolean index held by
     * {@code nDIndexBooleans}, by delegating to {@code JniUtils.booleanMaskSet}.
     *
     * <p>The array returned by {@code nDIndexBooleans.getIndex()} is closed when this method
     * finishes, on both the normal and the exceptional path.
     *
     * @param nDArray the array being written to
     * @param nDIndexBooleans supplies the boolean index array
     * @param nDArray2 the value to write
     */
    public void set(NDArray nDArray, NDIndexBooleans nDIndexBooleans, NDArray nDArray2) {
        try (NDArray mask = nDIndexBooleans.getIndex()) {
            JniUtils.booleanMaskSet((PtNDArray) nDArray, (PtNDArray) nDArray2, (PtNDArray) mask);
        }
    }

    /**
     * Writes a single number into the region of {@code nDArray} described by
     * {@code nDIndexFullSlice}.
     *
     * <p>The number is wrapped in a temporary array created from the target's manager and
     * written through {@link #set(NDArray, NDIndexFullSlice, NDArray)}. The temporary is
     * closed as soon as the write completes or fails; the delegated overload never closes the
     * value it is given, so there is no double close.
     *
     * @param nDArray the array being written to
     * @param nDIndexFullSlice the slice description
     * @param number the value to write
     */
    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, Number number) {
        try (NDArray scalar = nDArray.getManager().create(number)) {
            set(nDArray, nDIndexFullSlice, scalar);
        }
    }
}
