package org.deeplearning4j.optimize.listeners;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;

/**
 * Training listener that records {@code model.score()} every {@code frequency} iterations, keeping
 * (iteration, score) pairs in growable primitive arrays, and can export them as delimited text.
 */
public class CollectScoresIterationListener extends BaseTrainingListener {
    private final int frequency;
    private int iterationCount;
    ScoreStat scoreVsIter;

    /**
     * Primitive-array storage for (iteration index, score) pairs.
     *
     * <p>Data lives in "buckets": a {@code long[]} in {@link #getIndexes()} and a {@code double[]}
     * at the same position in {@link #getScores()}. Only the last bucket is written to. It starts
     * with {@link #BUCKET_LENGTH} slots and grows by {@link #BUCKET_LENGTH} whenever it is full. If
     * growing would exceed the largest array we allocate, a fresh bucket is appended instead.
     * {@code position} is the number of filled slots in the last bucket; every earlier bucket is
     * completely full.
     */
    public static class ScoreStat {
        public static final int BUCKET_LENGTH = 10000;

        /** Largest array length we allocate; some JVMs cannot allocate right up to MAX_VALUE. */
        private static final long MAX_BUCKET_CAPACITY = Integer.MAX_VALUE - 8;

        /** Number of filled slots in the last bucket. */
        private int position = 0;
        /** Capacity of the last bucket, in units of BUCKET_LENGTH. */
        private int bucketNumber = 1;
        private final List<long[]> indexes = new ArrayList<>(1);
        private final List<double[]> scores = new ArrayList<>(1);

        public ScoreStat() {
            indexes.add(new long[BUCKET_LENGTH]);
            scores.add(new double[BUCKET_LENGTH]);
        }

        /** Returns the live (mutable) list of index buckets; not a copy. */
        public List<long[]> getIndexes() {
            return indexes;
        }

        /** Returns the live (mutable) list of score buckets; not a copy. */
        public List<double[]> getScores() {
            return scores;
        }

        /**
         * Returns a copy of the first {@code position} entries of the first index bucket.
         * NOTE(review): only meaningful while a single bucket exists; after a rollover
         * {@code position} refers to the last bucket, not the first.
         */
        public long[] getEffectiveIndexes() {
            return Arrays.copyOfRange(indexes.get(0), 0, position);
        }

        /**
         * Returns a copy of the first {@code position} entries of the first score bucket.
         * NOTE(review): only meaningful while a single bucket exists (see getEffectiveIndexes).
         */
        public double[] getEffectiveScores() {
            return Arrays.copyOfRange(scores.get(0), 0, position);
        }

        /**
         * Ensures the last bucket has room for one more entry.
         *
         * <p>If the last bucket is full, it is replaced in place by a copy that is
         * {@link #BUCKET_LENGTH} slots longer. If that would exceed {@link #MAX_BUCKET_CAPACITY},
         * a new empty bucket is appended instead. Invariant afterwards: the last bucket's length is
         * {@code BUCKET_LENGTH * bucketNumber} and {@code position} is below it.
         */
        private void reallocateGuard() {
            // long arithmetic: BUCKET_LENGTH * bucketNumber overflows int near the array-size limit
            long capacity = (long) BUCKET_LENGTH * bucketNumber;
            if (position < capacity) {
                return;
            }
            long grownCapacity = capacity + BUCKET_LENGTH;
            if (grownCapacity > MAX_BUCKET_CAPACITY) {
                // Cannot grow further: start a fresh bucket and reset the size counter so the
                // capacity invariant holds for the new, BUCKET_LENGTH-sized bucket.
                position = 0;
                bucketNumber = 1;
                indexes.add(new long[BUCKET_LENGTH]);
                scores.add(new double[BUCKET_LENGTH]);
                return;
            }
            // Replace the last bucket in place so bucket order is preserved.
            int last = scores.size() - 1;
            indexes.set(last, Arrays.copyOf(indexes.get(last), (int) grownCapacity));
            scores.set(last, Arrays.copyOf(scores.get(last), (int) grownCapacity));
            bucketNumber++;
        }

        /**
         * Appends one (index, score) pair, growing storage if needed.
         *
         * @param j iteration index to record
         * @param d score value to record
         */
        public void addScore(long j, double d) {
            reallocateGuard();
            int last = scores.size() - 1;
            scores.get(last)[position] = d;
            indexes.get(last)[position] = j;
            position++;
        }
    }

    /** Creates a listener that records the score on every iteration. */
    public CollectScoresIterationListener() {
        this(1);
    }

    /**
     * Creates a listener that records the score every {@code frequency} iterations.
     *
     * @param frequency sampling interval in iterations; values {@code <= 0} are treated as 1
     */
    public CollectScoresIterationListener(int frequency) {
        this.iterationCount = 0;
        this.scoreVsIter = new ScoreStat();
        this.frequency = frequency <= 0 ? 1 : frequency;
    }

    /**
     * Counts calls and, on every {@code frequency}-th call, records {@code model.score()}
     * against the supplied iteration number.
     */
    @Override // org.deeplearning4j.optimize.api.BaseTrainingListener, org.deeplearning4j.optimize.api.TrainingListener
    public void iterationDone(Model model, int iteration, int epoch) {
        iterationCount++;
        if (iterationCount % frequency == 0) {
            double score = model.score();
            // addScore() performs the storage-growth check itself.
            scoreVsIter.addScore(iteration, score);
        }
    }

    /** Returns the underlying score storage (live object, not a copy). */
    public ScoreStat getScoreVsIter() {
        return scoreVsIter;
    }

    /**
     * Writes the recorded scores to {@code outputStream} as tab-delimited UTF-8 text.
     *
     * @see #exportScores(OutputStream, String)
     */
    public void exportScores(OutputStream outputStream) throws IOException {
        exportScores(outputStream, "\t");
    }

    /**
     * Writes the recorded scores as UTF-8 text.
     *
     * <p>The output is a header {@code Iteration<delimiter>Score} followed by one
     * {@code \n<iteration><delimiter><score>} line per recorded point, with no trailing newline.
     * The stream is flushed but not closed; the caller owns it.
     *
     * @param outputStream destination; not closed by this method
     * @param delimiter column separator
     * @throws IOException if writing to the stream fails
     */
    public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
        // Stream line by line instead of materialising the whole export as one String and byte[].
        Writer writer = new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8));
        StringBuilder line = new StringBuilder();
        line.append("Iteration").append(delimiter).append("Score");
        writer.append(line);

        List<long[]> allIndexes = scoreVsIter.indexes;
        List<double[]> allScores = scoreVsIter.scores;
        int bucketCount = allIndexes.size();
        for (int b = 0; b < bucketCount; b++) {
            long[] bucketIndexes = allIndexes.get(b);
            double[] bucketScores = allScores.get(b);
            // Earlier buckets are full; only the last one is filled up to `position`.
            int filled = b < bucketCount - 1 ? bucketIndexes.length : scoreVsIter.position;
            for (int k = 0; k < filled; k++) {
                line.setLength(0);
                line.append('\n').append(bucketIndexes[k]).append(delimiter).append(bucketScores[k]);
                writer.append(line);
            }
        }
        // Flush only: closing the writer would close the caller's stream.
        writer.flush();
    }

    /**
     * Writes the recorded scores to {@code file} as tab-delimited UTF-8 text, overwriting it.
     *
     * @see #exportScores(File, String)
     */
    public void exportScores(File file) throws IOException {
        exportScores(file, "\t");
    }

    /**
     * Writes the recorded scores to {@code file} (created or overwritten), closing it afterwards.
     *
     * @param file destination file
     * @param delimiter column separator
     * @throws IOException if the file cannot be opened or written
     */
    public void exportScores(File file, String delimiter) throws IOException {
        try (OutputStream out = new FileOutputStream(file)) {
            exportScores(out, delimiter);
        }
    }
}
