/*
 * Decompiled with CFR 0.152.
 */
package deepwater.datasets;

import deepwater.datasets.ImageDataSet;
import deepwater.datasets.Pair;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

/**
 * Loads the MNIST handwritten-digit dataset from gzip-compressed IDX files.
 *
 * <p>Each loaded sample is a {@code Pair} of the digit label and a 28x28 grey-scale image
 * flattened to {@code float[784]}, with pixel values scaled from {@code 0..255} to {@code [0, 1]}.
 */
public class MNISTImageDataset
extends ImageDataSet {
    /**
     * Download URLs of the four MNIST files, keyed by role.
     *
     * <p>{@code "in_images"} is a legacy (misnamed) key kept for backward compatibility; it maps to
     * the same training-labels URL as {@code "train_labels"}.
     */
    public static final Map<String, String> Resources = MNISTImageDataset.fillResources();
    // NOTE(review): neither field is assigned anywhere in this class; see loadImages().
    private String labelFileName;
    private String imageFileName;
    private static final int MAGIC_OFFSET = 0;
    private static final int OFFSET_SIZE = 4;
    private static final int LABEL_MAGIC = 2049;
    private static final int IMAGE_MAGIC = 2051;
    private static final int NUMBER_ITEMS_OFFSET = 4;
    private static final int ITEMS_SIZE = 4;
    private static final int NUMBER_OF_ROWS_OFFSET = 8;
    private static final int ROWS_SIZE = 4;
    public static final int ROWS = 28;
    private static final int NUMBER_OF_COLUMNS_OFFSET = 12;
    private static final int COLUMNS_SIZE = 4;
    public static final int COLUMNS = 28;
    private static final int IMAGE_OFFSET = 16;
    private static final int IMAGE_SIZE = 784;
    /** Byte offset of the first label in the label file (after magic + item count). */
    private static final int LABEL_OFFSET = 8;
    /** Multiplier mapping an unsigned pixel byte (0..255) into [0, 1]; equals 1/255. */
    private static final double PIXEL_SCALE = 0.00392156862745098;
    /** Size of the chunk used when decompressing a file into memory. */
    private static final int COPY_BUFFER_SIZE = 16384;

    /** Builds the immutable role-to-URL map exposed as {@link #Resources}. */
    private static Map<String, String> fillResources() {
        HashMap<String, String> resources = new HashMap<String, String>();
        resources.put("train_images", "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz");
        resources.put("train_labels", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz");
        // Legacy key: historically the training-labels URL was (mis)registered under "in_images".
        resources.put("in_images", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz");
        resources.put("test_images", "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz");
        resources.put("test_labels", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz");
        return Collections.unmodifiableMap(resources);
    }

    /**
     * Creates the dataset descriptor.
     *
     * <p>NOTE(review): the arguments {@code (28, 28, 1, 10)} are presumably (width, height,
     * channels, classes) for {@code ImageDataSet} — confirm against that class.
     */
    public MNISTImageDataset() {
        super(ROWS, COLUMNS, 1, 10);
    }

    /**
     * Loads images and labels from this instance's configured file names.
     *
     * <p>NOTE(review): {@code imageFileName} and {@code labelFileName} are never assigned in this
     * class, so unless they are set elsewhere (e.g. reflectively) this call fails with a
     * {@code NullPointerException} when opening the file — confirm intended usage.
     *
     * @return one (label, pixels) pair per sample
     * @throws IOException if a file cannot be read or is malformed
     */
    public List<Pair<Integer, float[]>> loadImages() throws IOException {
        return this.loadImages(this.imageFileName, this.labelFileName);
    }

    /**
     * Loads MNIST samples from a gzip-compressed IDX image file and its matching label file.
     *
     * <p>Only {@code filenames[0]} (images) and {@code filenames[1]} (labels) are read.
     *
     * @param filenames image file name followed by label file name
     * @return one pair per sample: the label (0..255 as stored, 0..9 for valid MNIST) and a
     *     {@code float[784]} of row-major pixels scaled to [0, 1]
     * @throws IllegalArgumentException if fewer than two file names are supplied
     * @throws IOException if a file cannot be read, has a wrong magic number, is truncated,
     *     the image and label counts differ, or the images are not 28x28
     */
    @Override
    public List<Pair<Integer, float[]>> loadImages(String ... filenames) throws IOException {
        if (filenames == null || filenames.length < 2) {
            throw new IllegalArgumentException("expected image and label");
        }
        assert (filenames.length % 2 == 0) : "expected image and label";

        // Labels are read first to keep the original file-open order.
        byte[] labelBytes = MNISTImageDataset.readGzipFully(filenames[1]);
        byte[] imageBytes = MNISTImageDataset.readGzipFully(filenames[0]);
        MNISTImageDataset.requireLength(labelBytes, LABEL_OFFSET, "label file header");
        MNISTImageDataset.requireLength(imageBytes, IMAGE_OFFSET, "image file header");

        int magic = MNISTImageDataset.readInt(labelBytes, MAGIC_OFFSET);
        if (magic != LABEL_MAGIC) {
            throw new IOException("Bad magic number in label file got " + magic + " instead of " + LABEL_MAGIC);
        }
        if (MNISTImageDataset.readInt(imageBytes, MAGIC_OFFSET) != IMAGE_MAGIC) {
            throw new IOException("Bad magic number in image file!");
        }

        int numberOfLabels = MNISTImageDataset.readInt(labelBytes, NUMBER_ITEMS_OFFSET);
        int numberOfImages = MNISTImageDataset.readInt(imageBytes, NUMBER_ITEMS_OFFSET);
        if (numberOfImages != numberOfLabels) {
            throw new IOException("The number of labels and images do not match!");
        }
        if (numberOfLabels < 0) {
            throw new IOException("Negative item count in header: " + numberOfLabels);
        }

        int numRows = MNISTImageDataset.readInt(imageBytes, NUMBER_OF_ROWS_OFFSET);
        int numCols = MNISTImageDataset.readInt(imageBytes, NUMBER_OF_COLUMNS_OFFSET);
        // Either dimension being wrong invalidates the fixed 784-byte stride below.
        if (numRows != ROWS || numCols != COLUMNS) {
            throw new IOException("Bad image. Rows and columns do not equal 28x28");
        }

        // Reject truncated payloads instead of silently reading zero-padded data.
        // Long arithmetic prevents overflow from a hostile item count.
        MNISTImageDataset.requireLength(labelBytes, LABEL_OFFSET + (long) numberOfLabels, "label data");
        MNISTImageDataset.requireLength(imageBytes, IMAGE_OFFSET + (long) numberOfImages * IMAGE_SIZE, "image data");

        List<Pair<Integer, float[]>> images = new ArrayList<Pair<Integer, float[]>>(numberOfLabels);
        for (int i = 0; i < numberOfLabels; ++i) {
            // IDX stores labels and pixels as unsigned bytes.
            int label = labelBytes[LABEL_OFFSET + i] & 0xFF;
            int base = IMAGE_OFFSET + i * IMAGE_SIZE;
            float[] pixels = new float[IMAGE_SIZE];
            for (int j = 0; j < IMAGE_SIZE; ++j) {
                pixels[j] = (float) ((imageBytes[base + j] & 0xFF) * PIXEL_SCALE);
            }
            images.add(new Pair<Integer, float[]>(Integer.valueOf(label), pixels));
        }
        return images;
    }

    /** Decompresses a whole gzip file into memory, closing both streams on every path. */
    private static byte[] readGzipFully(String fileName) throws IOException {
        try (FileInputStream file = new FileInputStream(fileName);
             InputStream in = new GZIPInputStream(file)) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[COPY_BUFFER_SIZE];
            int read;
            while ((read = in.read(buffer, 0, buffer.length)) != -1) {
                out.write(buffer, 0, read);
            }
            return out.toByteArray();
        }
    }

    /** Reads a big-endian 32-bit integer at {@code offset} (ByteBuffer's default byte order). */
    private static int readInt(byte[] data, int offset) {
        return ByteBuffer.wrap(data).getInt(offset);
    }

    /** Throws an IOException if {@code data} holds fewer than {@code needed} bytes. */
    private static void requireLength(byte[] data, long needed, String what) throws IOException {
        if (data.length < needed) {
            throw new IOException("Truncated " + what + ": expected at least " + needed
                    + " bytes but got " + data.length);
        }
    }
}

