/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.util;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
import org.apache.ignite.IgniteException;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;

public class MnistUtils {
    public static Stream<DenseVector> mnistAsStream(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException {
        FileInputStream isImages = new FileInputStream(imagesPath);
        FileInputStream isLabels = new FileInputStream(labelsPath);
        MnistUtils.read4Bytes(isImages);
        int numOfImages = MnistUtils.read4Bytes(isImages);
        int imgHeight = MnistUtils.read4Bytes(isImages);
        int imgWidth = MnistUtils.read4Bytes(isImages);
        MnistUtils.read4Bytes(isLabels);
        MnistUtils.read4Bytes(isLabels);
        int numOfPixels = imgHeight * imgWidth;
        double[][] vecs = new double[numOfImages][numOfPixels + 1];
        for (int imgNum = 0; imgNum < numOfImages; ++imgNum) {
            vecs[imgNum][numOfPixels] = isLabels.read();
            for (int p = 0; p < numOfPixels; ++p) {
                int c = 128 - isImages.read();
                vecs[imgNum][p] = (double)c / 128.0;
            }
        }
        List lst = Arrays.asList(vecs);
        Collections.shuffle(lst, rnd);
        isImages.close();
        isLabels.close();
        return lst.subList(0, cnt).stream().map(DenseVector::new);
    }

    public static List<MnistLabeledImage> mnistAsList(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException {
        return MnistUtils.mnistAsList(new FileInputStream(imagesPath), new FileInputStream(labelsPath), rnd, cnt);
    }

    public static List<MnistLabeledImage> mnistAsListFromResource(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException {
        return MnistUtils.mnistAsList(MnistUtils.class.getClassLoader().getResourceAsStream(imagesPath), MnistUtils.class.getClassLoader().getResourceAsStream(labelsPath), rnd, cnt);
    }

    private static List<MnistLabeledImage> mnistAsList(InputStream imageStream, InputStream lbStream, Random rnd, int cnt) throws IOException {
        ArrayList<MnistLabeledImage> res = new ArrayList<MnistLabeledImage>();
        MnistUtils.read4Bytes(imageStream);
        int numOfImages = MnistUtils.read4Bytes(imageStream);
        int imgHeight = MnistUtils.read4Bytes(imageStream);
        int imgWidth = MnistUtils.read4Bytes(imageStream);
        MnistUtils.read4Bytes(lbStream);
        MnistUtils.read4Bytes(lbStream);
        int numOfPixels = imgHeight * imgWidth;
        for (int imgNum = 0; imgNum < numOfImages; ++imgNum) {
            double[] pixels = new double[numOfPixels];
            for (int p = 0; p < numOfPixels; ++p) {
                pixels[p] = (float)(1.0 * (double)(imageStream.read() & 0xFF) / 255.0);
            }
            res.add(new MnistLabeledImage(pixels, lbStream.read()));
        }
        Collections.shuffle(res, rnd);
        return res.subList(0, cnt);
    }

    public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int cnt) throws IOException {
        try (FileWriter fos = new FileWriter(outPath);){
            MnistUtils.mnistAsStream(imagesPath, labelsPath, rnd, cnt).forEach(vec -> {
                try {
                    fos.write((int)vec.get(vec.size() - 1) + " ");
                    for (int i = 0; i < vec.size() - 1; ++i) {
                        double val = vec.get(i);
                        if (val == 0.0) continue;
                        fos.write(i + 1 + ":" + val + " ");
                    }
                    fos.write("\n");
                }
                catch (IOException e) {
                    throw new IgniteException("Error while converting to LIBSVM.");
                }
            });
        }
    }

    private static int read4Bytes(InputStream is) throws IOException {
        return is.read() << 24 | is.read() << 16 | is.read() << 8 | is.read();
    }

    public static class MnistLabeledImage
    extends MnistImage {
        private final int lb;

        public MnistLabeledImage(double[] pixels, int lb) {
            super(pixels);
            this.lb = lb;
        }

        public int getLabel() {
            return this.lb;
        }
    }

    public static class MnistImage {
        private final double[] pixels;

        public MnistImage(double[] pixels) {
            this.pixels = pixels;
        }

        public double[] getPixels() {
            return this.pixels;
        }
    }
}

