package org.deeplearning4j.zoo.util;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.resources.Downloader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/zoo/util/BaseLabels.class */
/**
 * Base implementation of {@link Labels} backed by a list of label strings, one per class index.
 *
 * <p>The label list is read line by line from the file returned by {@link #getResourceFile()},
 * which downloads the file from {@link #getURL()} when no local copy with a matching MD5 exists.
 */
public abstract class BaseLabels implements Labels {
    /**
     * Value passed as the last argument of {@code Downloader.download}
     * (presumably the maximum number of download attempts; confirm against the Downloader docs).
     */
    private static final int DOWNLOAD_ATTEMPTS = 3;

    /** Label strings; the list index is the class index. */
    protected ArrayList<String> labels;

    /**
     * Initializes {@link #labels} from {@link #getLabels()}.
     *
     * <p>NOTE(review): this calls an overridable method from the constructor, and the default
     * {@link #getLabels()} returns {@code null}; subclasses using this constructor are expected
     * to override it.
     *
     * @throws IOException if {@link #getLabels()} fails
     */
    public BaseLabels() throws IOException {
        this.labels = getLabels();
    }

    /**
     * Initializes {@link #labels} from {@link #getLabels(String)}.
     *
     * @param str forwarded unchanged to {@link #getLabels(String)}
     * @throws IOException if the label file cannot be read
     */
    public BaseLabels(String str) throws IOException {
        this.labels = getLabels(str);
    }

    /**
     * Hook for subclasses that provide their labels without a resource file.
     *
     * @return {@code null} in this base implementation
     * @throws IOException declared for the benefit of overriding implementations
     */
    protected ArrayList<String> getLabels() throws IOException {
        return null;
    }

    /**
     * Reads the label file returned by {@link #getResourceFile()}, one label per line.
     *
     * <p>The text is decoded with the platform default charset (no charset is passed to the
     * {@link Scanner}).
     *
     * @param str not used by this implementation; the file location comes from
     *            {@link #getResourceFile()}
     * @return a new list holding every line of the file, in file order
     * @throws IOException if the file cannot be opened or closed
     */
    public ArrayList<String> getLabels(String str) throws IOException {
        ArrayList<String> result = new ArrayList<>();
        // try-with-resources closes the scanner and the stream on every path, including failures.
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(getResourceFile()));
                Scanner scanner = new Scanner(in)) {
            while (scanner.hasNextLine()) {
                result.add(scanner.nextLine());
            }
        }
        return result;
    }

    /**
     * Returns the label stored at the given class index.
     *
     * @param i class index; must satisfy {@code 0 <= i < labels.size()}
     * @return the label at index {@code i}
     */
    @Override // org.deeplearning4j.zoo.util.Labels
    public String getLabel(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.labels.size(), "Invalid index: %s. Must be in range 0 <= n < %s", i, this.labels.size());
        return this.labels.get(i);
    }

    /**
     * Builds, for each row of the input, a list of the first {@code i} classes after the columns
     * have been sorted on the prediction values.
     *
     * @param iNDArray predictions; a rank-1 array is reshaped to {@code [1, length]}, and
     *                 {@code size(1)} must equal the number of labels
     * @param i        number of {@link ClassPrediction} entries to produce per row
     * @return one list per row, each holding {@code i} predictions in sorted order
     */
    @Override // org.deeplearning4j.zoo.util.Labels
    public List<List<ClassPrediction>> decodePredictions(INDArray iNDArray, int i) {
        INDArray predictions = iNDArray;
        if (predictions.rank() == 1) {
            // 1d edge case: treat the vector as a single row of class scores.
            predictions = predictions.reshape(1L, predictions.length());
        }
        Preconditions.checkState(predictions.size(1) == ((long) this.labels.size()), "Invalid input array: expected array with size(1) equal to numLabels (%s), got array with shape %s", Integer.valueOf(this.labels.size()), predictions.shape());

        long rows = predictions.size(0);
        long cols = predictions.size(1);
        if (predictions.isColumnVectorOrScalar()) {
            // NOTE(review): size(1) is queried on the result of ravel(); confirm this is valid
            // for the array rank that ravel() produces.
            predictions = predictions.ravel();
            rows = (int) predictions.size(0);
            cols = (int) predictions.size(1);
        }

        List<List<ClassPrediction>> decoded = new ArrayList<>();
        for (int rowIdx = 0; rowIdx < rows; rowIdx++) {
            INDArray scores = predictions.getRow(rowIdx, true);
            // Stack the class indices (row 0) on top of the scores (row 1), then sort the columns
            // on row 1 so that each index travels with its score.
            INDArray classIndices = Nd4j.linspace(scores.dataType(), 0L, cols, 1L).reshape(1L, cols);
            INDArray sorted = Nd4j.sortColumns(Nd4j.vstack(classIndices, scores), 1, false);

            List<ClassPrediction> top = new ArrayList<>();
            for (int k = 0; k < i; k++) {
                int classIdx = sorted.getInt(0, k);
                double score = sorted.getDouble(1L, k);
                top.add(new ClassPrediction(classIdx, getLabel(classIdx), score));
            }
            decoded.add(top);
        }
        return decoded;
    }

    /** @return the URL the label file is downloaded from */
    protected abstract URL getURL();

    /** @return the resource name, used for the local directory lookup and passed to the downloader */
    protected abstract String resourceName();

    /** @return the MD5 value the local label file is checked against */
    protected abstract String resourceMD5();

    /**
     * Returns the local label file, downloading it when it is missing or fails the MD5 check.
     *
     * <p>The local file name is the part of the URL after its last {@code '/'}.
     *
     * @return the local label file
     * @throws RuntimeException if the download fails with an {@link IOException}
     */
    public File getResourceFile() {
        URL url = getURL();
        String urlString = url.toString();
        String fileName = urlString.substring(urlString.lastIndexOf('/') + 1);
        File localFile = new File(DL4JResources.getDirectory(ResourceType.RESOURCE, resourceName()), fileName);
        String expectedMd5 = resourceMD5();

        if (localFile.exists()) {
            try {
                if (Downloader.checkMD5OfFile(expectedMd5, localFile)) {
                    return localFile;
                }
            } catch (IOException ignored) {
                // Deliberately ignored: a file that cannot be checksummed is handled like a
                // mismatch, i.e. it is deleted and downloaded again below.
            }
            localFile.delete();
        }

        try {
            Downloader.download(resourceName(), url, localFile, expectedMd5, DOWNLOAD_ATTEMPTS);
            return localFile;
        } catch (IOException e) {
            throw new RuntimeException("Error downloading labels", e);
        }
    }
}
