package ai.djl.pytorch.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.types.DataType;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.translate.Translator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/pytorch/engine/PtModel.class */
/**
 * PyTorch engine implementation of {@code BaseModel}.
 *
 * <p>Each instance owns a sub-manager of the PyTorch system manager; closing the model closes that
 * manager.
 */
public class PtModel extends BaseModel {

    /**
     * Creates a new, empty model.
     *
     * <p>NOTE: the decompiler reports this constructor was originally package-private; it is kept
     * public here so existing callers are unaffected.
     *
     * @param name the model name
     * @param device the device for the model's manager; {@code Device.defaultIfNull} is applied,
     *     so {@code null} selects the default device
     */
    public PtModel(String name, Device device) {
        super(name);
        // NOTE(review): `mo173newSubManager` appears to be a decompiler-renamed `newSubManager`;
        // kept verbatim so it matches the (decompiled) PtNDManager declaration.
        this.manager = PtNDManager.getSystemManager().mo173newSubManager(Device.defaultIfNull(device));
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Loads the model from a directory.
     *
     * <p>If a block has already been set, only its parameters are read from the file returned by
     * {@code paramPathResolver}. Otherwise a {@code .pt} model file is located in the directory and
     * loaded as the block via {@code JniUtils.loadModule}.
     *
     * @param modelPath the model directory; stored as an absolute path
     * @param prefix the model file name (with or without the {@code .pt} suffix); {@code null}
     *     means use the model name
     * @param options loading options, passed through to parameter resolution/reading
     * @throws IOException if the parameter file cannot be found or read
     * @throws FileNotFoundException if no {@code .pt} file can be found
     * @throws MalformedModelException if the parameter file is malformed
     */
    public void load(Path modelPath, String prefix, Map<String, Object> options)
            throws IOException, MalformedModelException {
        this.modelDir = modelPath.toAbsolutePath();
        String fileName = prefix == null ? this.modelName : prefix;

        if (this.block != null) {
            Path paramFile = paramPathResolver(fileName, options);
            if (paramFile == null) {
                throw new IOException(
                        "Parameter file not found in: "
                                + this.modelDir
                                + ". If you only specified model path, make sure path name matches"
                                + " your saved model file name.");
            }
            readParameters(paramFile, options);
            return;
        }

        Path modelFile = findModelFile(fileName);
        if (modelFile == null) {
            // Fall back to a model file named after the model directory itself.
            modelFile = findModelFile(this.modelDir.toFile().getName());
            if (modelFile == null) {
                throw new FileNotFoundException(".pt file not found in: " + this.modelDir);
            }
        }
        this.block = JniUtils.loadModule(this.manager, modelFile, this.manager.getDevice());
    }

    /**
     * Resolves a model file inside the model directory.
     *
     * <p>{@code name} is tried as-is first; if that is not a regular file and {@code name} does not
     * already end in {@code .pt}, {@code name + ".pt"} is tried.
     *
     * @param name the candidate file name
     * @return the first candidate that is a regular file, or {@code null} if none is
     */
    private Path findModelFile(String name) {
        Path candidate = this.modelDir.resolve(name);
        // isRegularFile returns false for missing paths, so no separate existence check is needed.
        if (Files.isRegularFile(candidate)) {
            return candidate;
        }
        if (name.endsWith(".pt")) {
            return null;
        }
        candidate = this.modelDir.resolve(name + ".pt");
        return Files.isRegularFile(candidate) ? candidate : null;
    }

    /**
     * Creates a trainer for this model, applying the config's initializer to the block.
     *
     * @param trainingConfig the training configuration
     * @return a new trainer
     * @throws IllegalStateException if no block has been set
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        if (this.block == null) {
            throw new IllegalStateException(
                    "You must set a block for the model before creating a new trainer");
        }
        Initializer initializer = trainingConfig.getInitializer();
        this.block.setInitializer(initializer);
        return new Trainer(this, trainingConfig);
    }

    /**
     * Creates a predictor for this model.
     *
     * @param translator the input/output translator
     * @param <I> the input type
     * @param <O> the output type
     * @return a new predictor
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        return new Predictor<>(this, translator, false);
    }

    /**
     * Lists the regular files under the model directory, excluding {@code .pt} model files.
     *
     * @return file paths relative to the model directory, in {@code Files.walk} order
     * @throws AssertionError if the directory cannot be walked
     */
    public String[] getArtifactNames() {
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> paths = Files.walk(this.modelDir)) {
            return paths.filter(Files::isRegularFile)
                    .filter(p -> !p.toFile().getName().endsWith(".pt"))
                    .map(p -> this.modelDir.relativize(p).toString())
                    .toArray(String[]::new);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    /**
     * Casting is not supported by this implementation.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Releases native resources by closing this model's manager. */
    public void close() {
        this.manager.close();
    }
}
