/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.loader;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.EqualizeHistTransform;
import org.datavec.image.transform.ImageTransform;
import org.eclipse.deeplearning4j.resources.DataSetResource;
import org.eclipse.deeplearning4j.resources.ResourceDataSets;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads the CIFAR-10 binary batch files and converts them into ND4J {@link DataSet}s.
 *
 * <p>Training records come from every {@code *.bin} file found (recursively) under the data
 * directory except {@code test_batch.bin}; they are chained in path order into one stream.
 * Test records come from {@code test_batch.bin} directly inside the data directory. Each raw
 * record is one label byte followed by three {@code HEIGHT x WIDTH} colour planes.
 *
 * <p>With {@code useSpecialPreProcessCifar} enabled for training data, {@link #load()} also
 * writes converted batches to serialized files on first use, and {@link #next(int, int)} then
 * serves examples from those files.
 *
 * <p>Instances keep mutable stream, cursor and running-statistics state without any
 * synchronization. NOTE(review): the class is {@link Serializable}, but the {@link InputStream}
 * fields opened by the constructor are not, so serializing a constructed instance presumably
 * fails — confirm before relying on serialization.
 */
public class CifarLoader
extends NativeImageLoader
implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(CifarLoader.class);
    public static final int NUM_TRAIN_IMAGES = 50000;
    public static final int NUM_TEST_IMAGES = 10000;
    public static final int NUM_LABELS = 10;
    public static final int HEIGHT = 32;
    public static final int WIDTH = 32;
    public static final int CHANNELS = 3;
    public static final boolean DEFAULT_USE_SPECIAL_PREPROC = false;
    public static final boolean DEFAULT_SHUFFLE = true;
    /** Pixels in one colour plane of a raw record. */
    private static final int PLANE_SIZE = HEIGHT * WIDTH;
    /** Bytes per raw record: one label byte plus one plane per channel (= 3073). */
    private static final int BYTEFILELEN = 1 + PLANE_SIZE * CHANNELS;
    private static final String[] TRAINFILENAMES = new String[]{"data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin"};
    private static final String TESTFILENAME = "test_batch.bin";
    private static final String labelFileName = "batches.meta.txt";
    /** Records converted into each serialized file, and records per file when reading them back. */
    private static final int numToConvertDS = 10000;
    /**
     * Colour-conversion code passed to the transforms in {@link #convertCifar(Mat)}.
     * NOTE(review): presumably OpenCV's COLOR_BGR2YCrCb, inlined as a literal — confirm.
     */
    private static final int YCRCB_CONVERSION_CODE = 36;
    protected final File fullDir;
    protected final File meanVarPath;
    protected final String trainFilesSerialized;
    protected final String testFilesSerialized;
    /** Stream currently read by {@link #convertDataSet(int)}; one of the two streams below. */
    protected InputStream inputStream;
    protected InputStream trainInputStream;
    protected InputStream testInputStream;
    protected List<String> labels = new ArrayList<String>();
    public static Map<String, String> cifarDataMap = new HashMap<String, String>();
    protected boolean train;
    protected boolean useSpecialPreProcessCifar;
    protected long seed;
    protected boolean shuffle = true;
    /** Incremented by {@link #convertCifar(Mat)}; used as the divisor in {@link #normalizeCifar(File)}. */
    protected int numExamples = 0;
    // Running U/V channel statistics accumulated by convertDataSet and finalized by normalizeCifar.
    protected double uMean = 0.0;
    protected double uStd = 0.0;
    protected double vMean = 0.0;
    protected double vStd = 0.0;
    protected boolean meanStdStored = false;
    /** Cursor into {@link #loadDS} for the serialized-file path of {@link #next(int, int)}. */
    protected int loadDSIndex = 0;
    protected DataSet loadDS = new DataSet();
    /** 1-based index of the serialized training file currently in {@link #loadDS}; 0 = none yet. */
    protected int fileNum = 0;
    private static final DataSetResource cifar = ResourceDataSets.cifar10();

    private static File getDefaultDirectory() {
        return cifar.localCacheDirectory();
    }

    /** Creates a loader for the training data in the default cache directory. */
    public CifarLoader() {
        this(true);
    }

    /**
     * Creates a loader in the default cache directory.
     *
     * @param train {@code true} to read the training batches, {@code false} for the test batch
     */
    public CifarLoader(boolean train) {
        this(train, null);
    }

    /**
     * Creates a loader with default image size, preprocessing and shuffling.
     *
     * @param train    {@code true} to read the training batches, {@code false} for the test batch
     * @param fullPath data directory, or {@code null} for the default cache directory
     */
    public CifarLoader(boolean train, File fullPath) {
        this(HEIGHT, WIDTH, CHANNELS, null, train, DEFAULT_USE_SPECIAL_PREPROC, fullPath, System.currentTimeMillis(), DEFAULT_SHUFFLE);
    }

    public CifarLoader(int height, int width, int channels, boolean train, boolean useSpecialPreProcessCifar) {
        this(height, width, channels, null, train, useSpecialPreProcessCifar);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, DEFAULT_SHUFFLE);
    }

    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar, boolean shuffle) {
        this(height, width, channels, imgTransform, train, useSpecialPreProcessCifar, null, System.currentTimeMillis(), shuffle);
    }

    /**
     * Creates a loader and immediately opens the data via {@link #load()}.
     *
     * @param height                    image height passed to {@link NativeImageLoader}
     * @param width                     image width passed to {@link NativeImageLoader}
     * @param channels                  channel count passed to {@link NativeImageLoader}
     * @param imgTransform              optional transform passed to {@link NativeImageLoader}
     * @param train                     {@code true} for the training batches, {@code false} for the test batch
     * @param useSpecialPreProcessCifar accumulate U/V statistics and use serialized per-file data sets
     * @param fullDir                   data directory, or {@code null} for the default cache directory
     * @param seed                      seed for shuffling and for the transforms in {@link #convertCifar(Mat)}
     * @param shuffle                   whether returned batches are shuffled
     */
    public CifarLoader(int height, int width, int channels, ImageTransform imgTransform, boolean train, boolean useSpecialPreProcessCifar, File fullDir, long seed, boolean shuffle) {
        super((long)height, (long)width, (long)channels, imgTransform);
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.train = train;
        this.useSpecialPreProcessCifar = useSpecialPreProcessCifar;
        this.seed = seed;
        this.shuffle = shuffle;
        this.fullDir = fullDir == null ? CifarLoader.getDefaultDirectory() : fullDir;
        this.meanVarPath = new File(this.fullDir, "meanVarPath.txt");
        this.trainFilesSerialized = FilenameUtils.concat(this.fullDir.toString(), "cifar_train_serialized");
        this.testFilesSerialized = FilenameUtils.concat(this.fullDir.toString(), "cifar_test_serialized.ser");
        // NOTE(review): load() is overridable and is called from the constructor, so subclass
        // fields are not yet initialized when an override runs.
        this.load();
    }

    /** Not supported: CIFAR data is read from the binary batch files, not from single images. */
    @Override
    public INDArray asRowVector(File f) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Not supported: CIFAR data is read from the binary batch files, not from single images. */
    @Override
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Not supported: CIFAR data is read from the binary batch files, not from single images. */
    @Override
    public INDArray asMatrix(File f) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Not supported: CIFAR data is read from the binary batch files, not from single images. */
    @Override
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Appends every line of the label-name file in {@link #fullDir} to {@link #labels}; logs on failure. */
    private void defineLabels() {
        File path = new File(this.fullDir, labelFileName);
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                this.labels.add(line);
            }
        }
        catch (IOException e) {
            log.error("Failed to read CIFAR label names from {}", path, e);
        }
    }

    /**
     * Opens the raw train/test streams and reads the label names. When special preprocessing is
     * enabled for training data and the serialized files are missing, it also writes them.
     *
     * <p>Downloads the data set first when {@link #fullDir} does not exist.
     *
     * @throws RuntimeException wrapping any failure to open the raw {@code .bin} files
     */
    protected void load() {
        if (!this.cifarRawFilesExist() && !this.fullDir.exists()) {
            this.fullDir.mkdirs();
            log.info("Downloading CIFAR data set");
            cifar.download(true, 3, 10000, 100000);
        }
        try {
            this.trainInputStream = this.openTrainInputStream();
            this.testInputStream = new FileInputStream(new File(this.fullDir, TESTFILENAME));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (this.labels.isEmpty()) {
            this.defineLabels();
        }
        if (this.useSpecialPreProcessCifar && this.train && !this.cifarProcessedFilesExists()) {
            // Each training file is converted from the chained train stream in turn.
            this.inputStream = this.trainInputStream;
            for (int i = this.fileNum + 1; i <= TRAINFILENAMES.length; ++i) {
                this.convertDataSet(numToConvertDS).save(new File(this.trainFilesSerialized + i + ".ser"));
            }
            this.inputStream = this.testInputStream;
            this.convertDataSet(numToConvertDS).save(new File(this.testFilesSerialized));
        }
        this.setInputStream();
    }

    /**
     * Chains every raw training batch into one stream, in path order for a deterministic record
     * order. Training batches are all {@code *.bin} files under {@link #fullDir} except the test
     * batch.
     *
     * @throws IOException if no training batch exists or a file cannot be opened
     */
    private InputStream openTrainInputStream() throws IOException {
        Collection<File> binFiles = FileUtils.listFiles(this.fullDir, new String[]{"bin"}, true);
        List<File> trainFiles = new ArrayList<>();
        for (File file : binFiles) {
            if (!TESTFILENAME.equals(file.getName())) {
                trainFiles.add(file);
            }
        }
        if (trainFiles.isEmpty()) {
            throw new IOException("No CIFAR training batch files (*.bin) found under " + this.fullDir);
        }
        trainFiles.sort((a, b) -> a.getPath().compareTo(b.getPath()));
        InputStream stream = new FileInputStream(trainFiles.get(0));
        try {
            for (int i = 1; i < trainFiles.size(); ++i) {
                stream = new SequenceInputStream(stream, new FileInputStream(trainFiles.get(i)));
            }
        }
        catch (IOException e) {
            // Release the files opened so far before propagating.
            CifarLoader.closeCifarStreamQuietly(stream);
            throw e;
        }
        return stream;
    }

    /** Returns {@code true} when the test batch and all five training batches are present in {@link #fullDir}. */
    private boolean cifarRawFilesExist() {
        if (!new File(this.fullDir, TESTFILENAME).exists()) {
            return false;
        }
        for (String name : TRAINFILENAMES) {
            if (!new File(this.fullDir, name).exists()) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether the first serialized file for the current mode (train or test) exists. */
    private boolean cifarProcessedFilesExists() {
        File marker = this.train ? new File(this.trainFilesSerialized + "1.ser") : new File(this.testFilesSerialized);
        return marker.exists();
    }

    /**
     * Applies a colour-space conversion followed by histogram equalisation to {@code orgImage}
     * and counts it in {@link #numExamples}. Both transforms are built with a {@link Random}
     * seeded from {@link #seed}.
     *
     * @param orgImage the image to convert
     * @return the transformed image
     */
    public Mat convertCifar(Mat orgImage) {
        ++this.numExamples;
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        ColorConversionTransform yuvTransform = new ColorConversionTransform(new Random(this.seed), YCRCB_CONVERSION_CODE);
        EqualizeHistTransform histEqualization = new EqualizeHistTransform(new Random(this.seed), YCRCB_CONVERSION_CODE);
        ImageWritable writable = new ImageWritable(converter.convert(orgImage));
        writable = (ImageWritable)yuvTransform.transform(writable);
        writable = (ImageWritable)histEqualization.transform(writable);
        return converter.convert(writable.getFrame());
    }

    /**
     * Normalizes a serialized {@link DataSet} file in place on disk. Channel 0 is divided by 255,
     * and channels 1 and 2 are standardized with the U/V mean and std.
     *
     * <p>On the first training call the accumulated statistics are finalized and written to
     * {@link #meanVarPath}. NOTE(review): {@link #numExamples} is only incremented by
     * {@link #convertCifar(Mat)}; if that was never called, the mean divisions yield NaN/Infinity.
     * NOTE(review): the means are averaged over numExamples but the summed variances are not
     * before {@code sqrt} — confirm this is intended.
     *
     * @param fileName serialized data set to load, normalize and overwrite
     */
    public void normalizeCifar(File fileName) {
        DataSet result = new DataSet();
        result.load(fileName);
        if (!this.meanStdStored && this.train) {
            this.uMean = Math.abs(this.uMean / (double)this.numExamples);
            this.uStd = Math.sqrt(this.uStd);
            this.vMean = Math.abs(this.vMean / (double)this.numExamples);
            this.vStd = Math.sqrt(this.vStd);
            try {
                FileUtils.write(this.meanVarPath, this.uMean + "," + this.uStd + "," + this.vMean + "," + this.vStd, "UTF-8");
            }
            catch (IOException e) {
                log.error("Failed to write CIFAR mean/std statistics to {}", this.meanVarPath, e);
            }
            this.meanStdStored = true;
        } else if (this.uMean == 0.0 && this.meanStdStored) {
            try {
                String[] values = FileUtils.readFileToString(this.meanVarPath, "UTF-8").split(",");
                this.uMean = Double.parseDouble(values[0]);
                this.uStd = Double.parseDouble(values[1]);
                this.vMean = Double.parseDouble(values[2]);
                this.vStd = Double.parseDouble(values[3]);
            }
            catch (IOException e) {
                log.error("Failed to read CIFAR mean/std statistics from {}", this.meanVarPath, e);
            }
        }
        for (int i = 0; i < result.numExamples(); ++i) {
            // The in-place ops below rely on result.get(i) exposing views of result's arrays (verify).
            INDArray features = result.get(i).getFeatures();
            features.tensorAlongDimension(0, new int[]{0, 2, 3}).divi(255);
            features.tensorAlongDimension(1, new int[]{0, 2, 3}).subi(this.uMean).divi(this.uStd);
            features.tensorAlongDimension(2, new int[]{0, 2, 3}).subi(this.vMean).divi(this.vStd);
        }
        result.save(fileName);
    }

    /**
     * Decodes one raw record into a one-hot label and an interleaved 3-channel image.
     *
     * <p>Byte 0 is the label index. Bytes {@code 1..} hold three consecutive planes of
     * {@code HEIGHT * WIDTH} bytes. Pixels are written in the order plane 3, plane 2, plane 1
     * (B, G, R if the planes are R, G, B as in CIFAR-10).
     *
     * @param byteFeature a record of at least {@code 3073} bytes
     * @return pair of (one-hot label over {@link #NUM_LABELS} classes, image)
     * @throws IllegalArgumentException if {@code byteFeature} is shorter than one record
     */
    public Pair<INDArray, Mat> convertMat(byte[] byteFeature) {
        if (byteFeature.length < BYTEFILELEN) {
            throw new IllegalArgumentException("CIFAR record must be " + BYTEFILELEN + " bytes, got " + byteFeature.length);
        }
        INDArray label = FeatureUtil.toOutcomeVector((long)byteFeature[0], (long)NUM_LABELS);
        Mat image = new Mat(HEIGHT, WIDTH, opencv_core.CV_8UC(CHANNELS));
        ByteBuffer imageData = (ByteBuffer)image.createBuffer();
        for (int pixel = 0; pixel < PLANE_SIZE; ++pixel) {
            int target = CHANNELS * pixel;
            imageData.put(target, byteFeature[1 + 2 * PLANE_SIZE + pixel]);
            imageData.put(target + 1, byteFeature[1 + PLANE_SIZE + pixel]);
            imageData.put(target + 2, byteFeature[1 + pixel]);
        }
        return new Pair<>(label, image);
    }

    /**
     * Reads up to {@code num} records from {@link #inputStream} and merges them into one
     * {@link DataSet}, shuffled with {@link #seed} when shuffling is enabled and {@code num > 1}.
     *
     * <p>No {@code /255} scaling is applied here. NOTE(review): the previous code called
     * {@code setFeatures(getFeatures().div(255))} on the per-example objects produced by iterating
     * the merged result. Those are separate {@link DataSet} instances, so the scaling never
     * reached the returned data. That no-op was removed to keep the effective output unchanged.
     * Decide whether scaling is wanted; {@link #normalizeCifar(File)} divides channel 0 by 255
     * itself. With special preprocessing, per-example U/V statistics are accumulated.
     *
     * @param num maximum number of records to read
     * @return the merged data set, or an empty {@link DataSet} when nothing could be read
     * @throws IllegalStateException if special preprocessing is enabled and the images do not have 3 channels
     */
    public DataSet convertDataSet(int num) {
        List<DataSet> dataSets = new ArrayList<>();
        byte[] record = new byte[BYTEFILELEN];
        try {
            for (int count = 0; count != num && CifarLoader.readCifarRecord(this.inputStream, record); ++count) {
                Pair<INDArray, Mat> converted = this.convertMat(record);
                try {
                    dataSets.add(new DataSet(this.asMatrix(converted.getSecond()), converted.getFirst()));
                }
                catch (Exception e) {
                    log.error("Failed to convert CIFAR record to a DataSet", e);
                    break;
                }
            }
        }
        catch (IOException e) {
            log.error("Failed to read CIFAR record", e);
        }
        if (dataSets.isEmpty()) {
            return new DataSet();
        }
        DataSet result = DataSet.merge(dataSets);
        if (this.useSpecialPreProcessCifar) {
            this.accumulateChannelStatistics(result);
        }
        if (this.shuffle && num > 1) {
            result.shuffle(this.seed);
        }
        return result;
    }

    /** Adds each example's channel-1 (U) and channel-2 (V) mean and variance to the running totals. */
    private void accumulateChannelStatistics(DataSet merged) {
        for (DataSet example : merged) {
            try {
                INDArray uChannel = example.getFeatures().tensorAlongDimension(1, new int[]{0, 2, 3});
                INDArray vChannel = example.getFeatures().tensorAlongDimension(2, new int[]{0, 2, 3});
                double uTempMean = uChannel.meanNumber().doubleValue();
                this.uStd += this.varManual(uChannel, uTempMean);
                this.uMean += uTempMean;
                double vTempMean = vChannel.meanNumber().doubleValue();
                this.vStd += this.varManual(vChannel, vTempMean);
                this.vMean += vTempMean;
            }
            catch (IllegalArgumentException e) {
                throw new IllegalStateException("The number of channels must be 3 to special preProcess Cifar with.", e);
            }
        }
    }

    /**
     * Reads one full record into {@code record}, looping over short reads.
     *
     * @return {@code true} if a complete record was read; {@code false} at end of stream
     *         (a truncated trailing record is logged and discarded)
     */
    private static boolean readCifarRecord(InputStream in, byte[] record) throws IOException {
        int filled = 0;
        while (filled < record.length) {
            int read = in.read(record, filled, record.length - filled);
            if (read == -1) {
                break;
            }
            filled += read;
        }
        if (filled == record.length) {
            return true;
        }
        if (filled > 0) {
            log.warn("Discarding truncated CIFAR record of {} bytes (expected {})", filled, record.length);
        }
        return false;
    }

    /**
     * Returns the population variance of {@code x} around the given mean.
     *
     * @param x    values to measure
     * @param mean precomputed mean of {@code x}
     * @return sum of squared deviations divided by the number of elements
     */
    public double varManual(INDArray x, double mean) {
        INDArray squaredDeviation = x.sub(mean);
        squaredDeviation.muli(squaredDeviation);
        double sumOfSquares = Nd4j.getExecutioner().execAndReturn((ReduceOp)new Sum(squaredDeviation, new int[0])).getFinalResult().doubleValue();
        return sumOfSquares / (double)x.length();
    }

    /** Returns the next batch of at most {@code batchSize} examples; equivalent to {@code next(batchSize, 0)}. */
    public DataSet next(int batchSize) {
        return this.next(batchSize, 0);
    }

    /**
     * Returns the next batch of at most {@code batchSize} examples.
     *
     * <p>With special preprocessing enabled and serialized files present, examples are taken
     * from the loaded serialized file. A file is (re)loaded in three cases:
     * <ul>
     *   <li>{@code exampleNum == 0};</li>
     *   <li>nothing has been loaded yet;</li>
     *   <li>for training data, when {@code exampleNum} reaches {@code fileNum * 10000} and
     *       another training file remains.</li>
     * </ul>
     * Otherwise the batch is converted directly from the raw stream. NOTE(review): starting a
     * new pass with {@code exampleNum == 0} without {@link #reset()} advances past the last
     * training file.
     *
     * @param batchSize  maximum number of examples to return
     * @param exampleNum running example count, compared against {@code fileNum * 10000}
     * @return the batch, or an empty {@link DataSet} when no examples remain
     */
    public DataSet next(int batchSize, int exampleNum) {
        if (!(this.cifarProcessedFilesExists() && this.useSpecialPreProcessCifar)) {
            return this.convertDataSet(batchSize);
        }
        boolean advanceTrainFile = this.train
                && this.fileNum < TRAINFILENAMES.length
                && exampleNum >= this.fileNum * numToConvertDS;
        if (exampleNum == 0 || this.fileNum == 0 || advanceTrainFile) {
            ++this.fileNum;
            // Load exactly one source: the current training file, or the test file.
            File serialized = this.train
                    ? new File(this.trainFilesSerialized + this.fileNum + ".ser")
                    : new File(this.testFilesSerialized);
            this.loadDS.load(serialized);
            if (this.shuffle && batchSize > 1) {
                this.loadDS.shuffle(this.seed);
            }
            this.loadDSIndex = 0;
        }
        List<DataSet> batch = new ArrayList<>();
        int available = this.loadDS.numExamples();
        while (batch.size() < batchSize && this.loadDSIndex < available) {
            batch.add(this.loadDS.get(this.loadDSIndex));
            ++this.loadDSIndex;
        }
        if (batch.isEmpty()) {
            return new DataSet();
        }
        return batch.size() > 1 ? DataSet.merge(batch) : batch.get(0);
    }

    public InputStream getInputStream() {
        return this.inputStream;
    }

    /** Points {@link #inputStream} at the training or test stream, depending on {@link #train}. */
    public void setInputStream() {
        this.inputStream = this.train ? this.trainInputStream : this.testInputStream;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    /** Resets the counters, closes the current raw streams and reopens the data via {@link #load()}. */
    public void reset() {
        this.numExamples = 0;
        this.fileNum = 0;
        // load() opens fresh streams; release the old file handles first.
        CifarLoader.closeCifarStreamQuietly(this.trainInputStream);
        CifarLoader.closeCifarStreamQuietly(this.testInputStream);
        this.load();
    }

    /** Closes {@code stream} if non-null, logging (not throwing) any failure. */
    private static void closeCifarStreamQuietly(InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        }
        catch (IOException e) {
            log.warn("Failed to close CIFAR input stream", e);
        }
    }
}

