package ai.djl.pytorch.jni;

import ai.djl.util.Platform;
import ai.djl.util.Utils;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/pytorch/jni/LibUtils.class */
/**
 * Utilities for locating, extracting or downloading, and loading the PyTorch native libraries and
 * the DJL PyTorch JNI library ({@code djl_torch}).
 *
 * <p>The static {@code version} starts out as {@code jni_version} from {@code
 * /jnilib/pytorch.properties}. It is replaced when {@code PYTORCH_VERSION} is set or when a PyTorch
 * build is downloaded.
 */
public final class LibUtils {
    private static final String LIB_NAME = "djl_torch";
    private static final String NATIVE_LIB_NAME = "torch";
    /** PyTorch version in use (see class comment for how it is chosen). */
    private static String version;
    /** Path of the JNI library returned by {@link #getLibName(String, String)}; null until loaded. */
    private static String libtorchPath;
    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
    /** Group 1 is the release part of a version, without {@code -SNAPSHOT} or a build suffix. */
    private static final Pattern VERSION_PATTERN = Pattern.compile("(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)(-SNAPSHOT)?(-\\d+)?");

    private LibUtils() {
    }

    /**
     * Loads the DJL PyTorch JNI library and, on Windows, its DLL dependencies first.
     *
     * <p>On Android (detected via {@code java.vendor.url}) the JNI library is loaded by name with
     * {@link System#loadLibrary(String)}. Elsewhere its path is resolved with {@link
     * #getLibName(String, String)}, recorded for {@link #getLibtorchPath()}, and loaded.
     *
     * @throws IllegalStateException if {@code /jnilib/pytorch.properties} is missing or unreadable,
     *     or no native library can be found
     */
    public static void loadLibrary() {
        Properties properties = new Properties();
        try (InputStream is = LibUtils.class.getResourceAsStream("/jnilib/pytorch.properties")) {
            // Properties.load(null) would throw a bare NPE; report the real problem instead.
            if (is == null) {
                throw new IllegalStateException("Cannot find pytorch property file");
            }
            properties.load(is);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find pytorch property file", e);
        }
        String djlVersion = properties.getProperty("version");
        String jniVersion = properties.getProperty("jni_version");
        version = jniVersion;

        if ("http://www.android.com/".equals(System.getProperty("java.vendor.url"))) {
            System.loadLibrary(LIB_NAME);
            return;
        }

        libtorchPath = getLibName(djlVersion, jniVersion);
        logger.debug("Loading pytorch library from: {}", libtorchPath);
        if (System.getProperty("os.name").startsWith("Win")) {
            loadWinDependencies(libtorchPath);
        }
        loadNativeLibrary(libtorchPath);
    }

    /**
     * Resolves the absolute path of the DJL PyTorch JNI library.
     *
     * <p>A library found through {@code PYTORCH_LIBRARY_PATH} or {@code java.library.path} wins.
     * Otherwise the PyTorch native directory is located (bundled or downloaded), and the JNI
     * library is copied from the classpath or downloaded into it.
     *
     * @param djlVersion DJL version ({@code version} property); used to name and download the JNI
     *     library
     * @param jniVersion {@code jni_version} property; if the selected PyTorch version does not
     *     start with it, a matching JNI library is downloaded instead of the bundled one
     * @return absolute path of the JNI library
     * @throws IllegalStateException if no PyTorch native library can be found
     */
    public static String getLibName(String djlVersion, String jniVersion) {
        String libName = findOverrideLibrary();
        if (libName != null) {
            return libName;
        }
        AtomicBoolean fallback = new AtomicBoolean(false);
        String nativeDir = findNativeLibrary(fallback);
        if (nativeDir == null) {
            throw new IllegalStateException("Native library not found");
        }
        return copyJniLibraryFromClasspath(Paths.get(nativeDir), djlVersion, jniVersion, fallback.get());
    }

    /**
     * Returns the PyTorch version currently in use.
     *
     * @return the version, or {@code null} before {@link #loadLibrary()} has run
     */
    public static String getVersion() {
        return version;
    }

    /**
     * Returns the JNI library path resolved by {@link #loadLibrary()}.
     *
     * @return the path, or {@code null} if not loaded yet or loaded by name on Android
     */
    public static String getLibtorchPath() {
        return libtorchPath;
    }

    /**
     * Loads the DLLs in the JNI library's directory before the JNI library itself.
     *
     * <p>Regular files that are not in the explicitly ordered set, not the JNI library and not
     * cuDNN are loaded first, in directory-walk order. Then fbgemm, torch_cpu, the CUDA/cuDNN
     * libraries (when {@code c10_cuda.dll} is present), and finally {@code torch.dll} are loaded
     * in a fixed order.
     *
     * @param libName path of the JNI library
     * @throws IllegalArgumentException if the path has no parent or the directory cannot be walked
     */
    private static void loadWinDependencies(String libName) {
        Path libDir = Paths.get(libName).getParent();
        if (libDir == null) {
            throw new IllegalArgumentException("Invalid library path!");
        }
        // Loaded explicitly further down, in a fixed order, so they are skipped by the bulk walk.
        HashSet<String> orderedLibs = new HashSet<>(Arrays.asList(
                "c10_cuda.dll", "torch.dll", "torch_cpu.dll", "torch_cuda.dll",
                "torch_cuda_cpp.dll", "torch_cuda_cu.dll", "fbgemm.dll"));
        // Files.walk holds directory handles open; it must be closed even if a load fails.
        try (Stream<Path> files = Files.walk(libDir)) {
            files.filter(p -> {
                        String name = p.getFileName().toString();
                        return !orderedLibs.contains(name)
                                && Files.isRegularFile(p)
                                && !name.endsWith("djl_torch.dll")
                                && !name.startsWith("cudnn");
                    })
                    .map(p -> p.toAbsolutePath().toString())
                    // Go through the same hook as every other load so native_helper is honored.
                    .forEach(LibUtils::loadNativeLibrary);
        } catch (IOException e) {
            throw new IllegalArgumentException("Folder not exist! " + libDir, e);
        }

        loadFromDir(libDir, "fbgemm.dll", "torch_cpu.dll");
        if (Files.exists(libDir.resolve("c10_cuda.dll"))) {
            if (Files.exists(libDir.resolve("cudnn64_8.dll"))) {
                // cuDNN 8 split DLLs.
                loadFromDir(libDir,
                        "cudnn64_8.dll",
                        "cudnn_ops_infer64_8.dll",
                        "cudnn_ops_train64_8.dll",
                        "cudnn_cnn_infer64_8.dll",
                        "cudnn_cnn_train64_8.dll",
                        "cudnn_adv_infer64_8.dll",
                        "cudnn_adv_train64_8.dll");
            } else if (Files.exists(libDir.resolve("cudnn64_7.dll"))) {
                loadFromDir(libDir, "cudnn64_7.dll");
            }
            loadFromDir(libDir, "c10_cuda.dll");
            if (Files.exists(libDir.resolve("torch_cuda_cpp.dll"))) {
                loadFromDir(libDir, "torch_cuda_cpp.dll", "torch_cuda_cu.dll");
            }
            loadFromDir(libDir, "torch_cuda.dll");
        }
        loadFromDir(libDir, "torch.dll");
    }

    /** Loads each named library from {@code dir}, in the order given. */
    private static void loadFromDir(Path dir, String... names) {
        for (String name : names) {
            loadNativeLibrary(dir.resolve(name).toAbsolutePath().toString());
        }
    }

    /**
     * Looks for a user-supplied JNI library.
     *
     * <p>{@code PYTORCH_LIBRARY_PATH} is searched first, then {@code java.library.path}.
     *
     * @return absolute path of the library, or {@code null} if none was found
     */
    private static String findOverrideLibrary() {
        String envPath = System.getenv("PYTORCH_LIBRARY_PATH");
        if (envPath != null) {
            String found = findLibraryInPath(envPath);
            if (found != null) {
                return found;
            }
        }
        String javaLibPath = System.getProperty("java.library.path");
        return javaLibPath == null ? null : findLibraryInPath(javaLibPath);
    }

    /**
     * Searches a {@link File#pathSeparator}-separated list for the JNI library.
     *
     * <p>Each entry may be the library file itself (any file whose name ends with the
     * platform-mapped library name) or a directory that contains it.
     *
     * @param libPath path list to search
     * @return absolute path of the first match, or {@code null}
     */
    private static String findLibraryInPath(String libPath) {
        String libName = System.mapLibraryName(LIB_NAME);
        for (String entry : libPath.split(File.pathSeparator)) {
            File file = new File(entry);
            if (!file.exists()) {
                continue;
            }
            if (file.isFile() && file.getName().endsWith(libName)) {
                return file.getAbsolutePath();
            }
            File candidate = new File(entry, libName);
            if (candidate.isFile()) {
                return candidate.getAbsolutePath();
            }
        }
        return null;
    }

    /**
     * Puts the JNI library into {@code nativeDir} and returns its path.
     *
     * <p>An existing copy is reused. If the selected PyTorch version does not start with {@code
     * jniVersion}, a matching build is downloaded. Otherwise the bundled build is extracted from
     * {@code /jnilib/<classifier>/<flavor>/} on the classpath.
     *
     * @param nativeDir PyTorch native library directory
     * @param djlVersion DJL version; used as the file-name prefix
     * @param jniVersion PyTorch version the bundled JNI library is for
     * @param fallback {@code true} to force the {@code cpu} flavor
     * @return absolute path of the JNI library
     * @throws IllegalStateException if the bundled library is missing or cannot be copied
     */
    private static String copyJniLibraryFromClasspath(
            Path nativeDir, String djlVersion, String jniVersion, boolean fallback) {
        String name = System.mapLibraryName(LIB_NAME);
        Platform platform = Platform.fromSystem();
        String classifier = platform.getClassifier();
        String flavor = fallback ? "cpu" : platform.getFlavor();
        // A libstdc++.so.6 next to libtorch selects the "-precxx11" JNI build
        // (presumably the pre-C++11-ABI variant).
        if (Files.exists(nativeDir.resolve("libstdc++.so.6"))) {
            flavor += "-precxx11";
            logger.info("Using precxx11 jnilib.");
        }
        Path jniPath = nativeDir.resolve(djlVersion + '-' + flavor + '-' + name);
        if (Files.exists(jniPath)) {
            return jniPath.toAbsolutePath().toString();
        }
        if (!version.startsWith(jniVersion)) {
            downloadJniLib(nativeDir, jniPath, djlVersion, classifier, flavor, name);
            return jniPath.toAbsolutePath().toString();
        }

        String resource = "/jnilib/" + classifier + '/' + flavor + '/' + name;
        logger.info("Extracting {} to cache ...", resource);
        Path tmp = null;
        try (InputStream is = LibUtils.class.getResourceAsStream(resource)) {
            if (is == null) {
                throw new IllegalStateException("PyTorch jni not found: " + resource);
            }
            // Copy to a temp file first, then move, so a partial copy never appears at jniPath.
            tmp = Files.createTempFile(nativeDir, "jni", "tmp");
            Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, jniPath);
            return jniPath.toAbsolutePath().toString();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot copy jni files", e);
        } finally {
            // Remove the temp file if the move did not happen (e.g. the copy failed).
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Locates the PyTorch native library directory.
     *
     * <p>If {@code PYTORCH_VERSION} (environment variable, then system property) is set, that
     * version is downloaded. Otherwise, {@code native/lib/pytorch.properties} resources on the
     * classpath are checked:
     * <ul>
     *   <li>A platform matching the system is extracted.</li>
     *   <li>Otherwise a placeholder platform is downloaded.</li>
     *   <li>If there are no such resources, the current {@code version} is downloaded.</li>
     * </ul>
     *
     * @param fallback set to {@code true} when the selected build is CPU-only
     * @return the native library directory, or {@code null} if the classpath could not be searched
     * @throws IllegalStateException if the bundled jars do not match this system
     */
    private static synchronized String findNativeLibrary(AtomicBoolean fallback) {
        String requested = System.getenv("PYTORCH_VERSION");
        if (requested == null) {
            requested = System.getProperty("PYTORCH_VERSION");
        }
        if (requested != null) {
            version = requested;
            return downloadPyTorch(Platform.fromSystem(requested), fallback);
        }

        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        if (cl == null) {
            // Some threads (e.g. created from native code) have no context class loader.
            cl = LibUtils.class.getClassLoader();
        }
        Enumeration<URL> urls;
        try {
            urls = cl.getResources("native/lib/pytorch.properties");
        } catch (IOException e) {
            logger.warn("Failed to search the classpath for PyTorch native library jars", e);
            return null;
        }
        if (!urls.hasMoreElements()) {
            return downloadPyTorch(Platform.fromSystem(version), fallback);
        }

        Platform systemPlatform = Platform.fromSystem();
        Platform matching = null;
        Platform placeholder = null;
        while (urls.hasMoreElements()) {
            try {
                Platform candidate = Platform.fromUrl(urls.nextElement());
                if (candidate.isPlaceholder()) {
                    placeholder = candidate;
                } else if (candidate.matches(systemPlatform, false)) {
                    matching = candidate;
                    break;
                }
            } catch (IOException e) {
                throw new IllegalStateException("Failed to read PyTorch native library jar properties", e);
            }
        }

        if (matching != null) {
            if ("cpu".equals(matching.getFlavor())) {
                fallback.set(true);
            }
            return copyNativeLibraryFromClasspath(matching);
        }
        if (placeholder != null) {
            return downloadPyTorch(placeholder, fallback);
        }
        throw new IllegalStateException("Your PyTorch native library jar does not match your operating system. Make sure the Maven Dependency Classifier matches your system type.");
    }

    /**
     * Extracts a bundled PyTorch native library set to the cache.
     *
     * <p>NOTE(review): this body could not be decompiled and always throws. The real
     * implementation must be restored from the original source before this path can work.
     */
    /* JADX WARN: Removed duplicated region for block: B:33:0x0160 A[EXC_TOP_SPLITTER, SYNTHETIC] */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    private static java.lang.String copyNativeLibraryFromClasspath(ai.djl.util.Platform r7) {
        /*
            Method dump skipped, instructions count: 437
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: ai.djl.pytorch.jni.LibUtils.copyNativeLibraryFromClasspath(ai.djl.util.Platform):java.lang.String");
    }

    /**
     * Loads one native library by absolute path.
     *
     * <p>If the {@code ai.djl.pytorch.native_helper} system property names a class, its static
     * {@code load(String)} method does the loading. Otherwise {@link System#load(String)} is used.
     * Only one of the two is used: a library already loaded by another class loader cannot be
     * {@code System.load}ed again.
     *
     * @param path absolute path of the library
     * @throws IllegalArgumentException if the helper class or its {@code load} method cannot be
     *     invoked
     */
    private static void loadNativeLibrary(String path) {
        String nativeHelper = System.getProperty("ai.djl.pytorch.native_helper");
        if (nativeHelper == null || nativeHelper.isEmpty()) {
            System.load(path);
            return;
        }
        try {
            Class.forName(nativeHelper).getDeclaredMethod("load", String.class).invoke(null, path);
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException("Invalid native_helper: " + nativeHelper, e);
        }
    }

    /**
     * Downloads the PyTorch native libraries for {@code platform} into the engine cache.
     *
     * <p>Sets {@code version} to the platform's version. Honors {@code PYTORCH_PRECXX11} (system
     * property or environment variable). If the requested CUDA flavor is not published, falls back
     * to {@code cpu} and sets {@code fallback}.
     *
     * @param platform platform to download for
     * @param fallback set to {@code true} when falling back to the CPU build
     * @return the cache directory holding the libraries
     * @throws IllegalArgumentException if the version string is malformed
     * @throws IllegalStateException if the download fails or no flavor matches
     */
    private static String downloadPyTorch(Platform platform, AtomicBoolean fallback) {
        version = platform.getVersion();
        String flavor = platform.getFlavor();
        String classifier = platform.getClassifier();
        String os = platform.getOsPrefix();
        if (Boolean.getBoolean("PYTORCH_PRECXX11") || Boolean.parseBoolean(System.getenv("PYTORCH_PRECXX11"))) {
            flavor += "-precxx11";
        }
        String libName = System.mapLibraryName(NATIVE_LIB_NAME);
        Path cacheDir = Utils.getEngineCacheDir("pytorch");
        Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
        logger.debug("Using cache dir: {}", dir);
        if (Files.exists(dir.resolve(libName))) {
            return dir.toAbsolutePath().toString();
        }

        Matcher matcher = VERSION_PATTERN.matcher(version);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Unexpected version: " + version);
        }
        String link = "https://publish.djl.ai/pytorch-" + matcher.group(1);
        Path tmp = null;
        try (InputStream is = new URL(link + "/files.txt").openStream()) {
            Files.createDirectories(cacheDir);
            List<String> lines = Utils.readLines(is);
            if (!lines.contains(flavor + '/' + os + "/native/lib/" + libName + ".gz")) {
                if (!flavor.startsWith("cu")) {
                    throw new IOException("No matching flavor for " + os + " found: " + flavor);
                }
                logger.warn("No matching cuda flavor for {} found: {}.", os, flavor);
                flavor = "cpu";
                fallback.set(true);
                dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
                if (Files.exists(dir.resolve(libName))) {
                    return dir.toAbsolutePath().toString();
                }
            }

            // Download into a temp dir and move it into place, so a partial download is never cached.
            tmp = Files.createTempDirectory(cacheDir, "tmp");
            String prefix = flavor + '/' + os + '/';
            for (String line : lines) {
                if (line.startsWith(prefix)) {
                    downloadGzFile(link, line, tmp);
                }
            }
            Utils.moveQuietly(tmp, dir);
            return dir.toAbsolutePath().toString();
        } catch (IOException e) {
            throw new IllegalStateException("Failed to download PyTorch native library", e);
        } finally {
            // Remove the temp dir if it was not moved into place (e.g. a download failed).
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Downloads {@code link/entry}, gunzips it, and writes it into {@code dir}.
     *
     * <p>The target file name is the last path segment of {@code entry}, with its final three
     * characters (the {@code .gz} suffix) removed and then URL-decoded.
     *
     * @throws IOException on download/write failure, or if the decoded name escapes {@code dir}
     */
    private static void downloadGzFile(String link, String entry, Path dir) throws IOException {
        URL url = new URL(link + '/' + entry);
        String fileName = URLDecoder.decode(
                entry.substring(entry.lastIndexOf('/') + 1, entry.length() - 3), "UTF-8");
        Path target = dir.resolve(fileName).normalize();
        // A decoded name such as "..%2F.." must not write outside the download directory.
        if (!dir.normalize().equals(target.getParent())) {
            throw new IOException("Invalid file name in files.txt: " + entry);
        }
        logger.info("Downloading {} ...", url);
        // Open the raw stream separately so it is closed even if the GZIP header is invalid.
        try (InputStream raw = url.openStream();
                InputStream gz = new GZIPInputStream(raw)) {
            Files.copy(gz, target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /**
     * Downloads a JNI library built for the current PyTorch {@code version} into {@code jniPath}.
     *
     * @param nativeDir directory used for the temporary download file
     * @param jniPath final location of the JNI library
     * @param djlVersion DJL version; its release part selects the JNI build
     * @param classifier platform classifier
     * @param flavor build flavor; a {@code -precxx11} suffix selects the {@code precxx11/} tree
     * @param name platform-mapped JNI library file name
     * @throws IllegalArgumentException if either version string is malformed
     * @throws IllegalStateException if the download fails
     */
    private static void downloadJniLib(
            Path nativeDir, Path jniPath, String djlVersion, String classifier, String flavor, String name) {
        Matcher djlMatcher = VERSION_PATTERN.matcher(djlVersion);
        if (!djlMatcher.matches()) {
            throw new IllegalArgumentException("Unexpected djl version: " + djlVersion);
        }
        String djlRelease = djlMatcher.group(1);
        Matcher torchMatcher = VERSION_PATTERN.matcher(version);
        if (!torchMatcher.matches()) {
            throw new IllegalArgumentException("Unexpected pytorch version: " + version);
        }

        StringBuilder link = new StringBuilder(100);
        link.append("https://publish.djl.ai/pytorch-").append(torchMatcher.group(1)).append("/jnilib/");
        String precxx11 = "-precxx11";
        // Only a true suffix may be stripped; chopping the last 9 chars otherwise corrupts the flavor.
        if (flavor.endsWith(precxx11)) {
            flavor = flavor.substring(0, flavor.length() - precxx11.length());
            link.append("precxx11/");
        }
        link.append(djlRelease).append('/').append(classifier).append('/').append(flavor).append('/').append(name);
        logger.info("Downloading jni {} to cache ...", link);

        Path tmp = null;
        try (InputStream is = new URL(link.toString()).openStream()) {
            tmp = Files.createTempFile(nativeDir, "jni", "tmp");
            Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, jniPath);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot download jni files", e);
        } finally {
            // Remove the temp file if the move did not happen (e.g. the download failed).
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }
}
