/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.cmdline;

import java.io.OutputStream;
import java.io.PrintStream;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import opennlp.tools.util.Span;
import opennlp.tools.util.eval.FMeasure;
import opennlp.tools.util.eval.Mean;

/**
 * Base class for listeners that accumulate evaluation results in a {@link Stats} instance and
 * render them as a fine-grained, plain-text report on a {@link PrintStream}.
 *
 * <p>The individual report sections are written by the protected {@code print*} methods;
 * assembling them into a complete report is left to {@link #writeReport()}.
 */
public abstract class FineGrainedReportListener {
    /** Alphabet used by {@link #generateAlphaLabel(int)} to build confusion-matrix column labels. */
    private static final char[] alpha = new char[]{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'};

    /** Maximum number of rows printed by the token ranking tables. */
    private static final int MAX_RANKED_TOKENS = 20;

    /** {@link MessageFormat} pattern rendering a ratio as a percentage with at most two decimals. */
    private static final String PERCENT_PATTERN = "{0,number,#.##%}";

    /** {@link MessageFormat} pattern rendering a score with at most three decimals. */
    private static final String SCORE_PATTERN = "{0,number,#.###}";

    /** Destination of every report section. */
    private final PrintStream printStream;

    /** Statistics accumulated for the report. */
    private final Stats stats = new Stats();

    /**
     * Creates a listener that writes its report to the given stream.
     *
     * @param printStream the stream the report is written to
     */
    public FineGrainedReportListener(PrintStream printStream) {
        this.printStream = printStream;
    }

    /**
     * Creates a listener that writes its report to a {@link PrintStream} wrapping the given stream.
     *
     * @param outputStream the stream the report is written to
     */
    public FineGrainedReportListener(OutputStream outputStream) {
        this.printStream = new PrintStream(outputStream);
    }

    /**
     * Builds a three-character, right-aligned column label in bijective base 26:
     * 0 gives "  a", 25 gives "  z", 26 gives " aa", 27 gives " ab", and so on.
     * Indices that would need more than three letters keep only the last three.
     */
    private static String generateAlphaLabel(int index) {
        char[] labelChars = new char[3];
        int remaining = index;
        for (int i = labelChars.length - 1; i >= 0; --i) {
            if (remaining >= 0) {
                labelChars[i] = alpha[remaining % alpha.length];
                // Bijective numeration: subtract one after every digit.
                remaining = remaining / alpha.length - 1;
            } else {
                labelChars[i] = ' ';
            }
        }
        return new String(labelChars);
    }

    /**
     * Returns the category prefix of a label: the text before its first '-', or the whole label
     * when it contains no '-'.
     */
    private static String labelCategory(String label) {
        int dash = label.indexOf('-');
        return dash < 0 ? label : label.substring(0, dash);
    }

    /**
     * Orders two labels of which at least one has no statistics.
     *
     * <p>Labels with statistics sort first. Two labels without statistics fall back to
     * lexicographic order, which keeps comparators antisymmetric (required by {@link TreeSet}).
     */
    private static int compareMissing(boolean missing1, boolean missing2, String o1, String o2) {
        if (missing1 && missing2) {
            return o1.compareTo(o2);
        }
        return missing1 ? 1 : -1;
    }

    /**
     * Returns the length of the longest of the first {@code limit} strings, but never less than
     * {@code minimum}.
     */
    private static int maxLength(Iterable<String> values, int limit, int minimum) {
        int longest = minimum;
        int seen = 0;
        for (String value : values) {
            if (seen++ == limit) {
                break;
            }
            longest = Math.max(longest, value.length());
        }
        return longest;
    }

    /** Writes the complete report; the layout is defined by the subclass. */
    public abstract void writeReport();

    /**
     * Returns the statistics collected by this listener.
     *
     * @return the (mutable) statistics instance owned by this listener
     */
    protected Stats getStats() {
        return this.stats;
    }

    // ---- Private accessors delegating to the statistics instance ----

    private long getNumberOfSentences() {
        return this.stats.getNumberOfSentences();
    }

    private double getAverageSentenceSize() {
        return this.stats.getAverageSentenceSize();
    }

    private int getMinSentenceSize() {
        return this.stats.getMinSentenceSize();
    }

    private int getMaxSentenceSize() {
        return this.stats.getMaxSentenceSize();
    }

    private int getNumberOfTags() {
        return this.stats.getNumberOfTags();
    }

    private double getAccuracy() {
        return this.stats.getAccuracy();
    }

    private double getTokenAccuracy(String token) {
        return this.stats.getTokenAccuracy(token);
    }

    private SortedSet<String> getTokensOrderedByFrequency() {
        return this.stats.getTokensOrderedByFrequency();
    }

    private int getTokenFrequency(String token) {
        return this.stats.getTokenFrequency(token);
    }

    private int getTokenErrors(String token) {
        return this.stats.getTokenErrors(token);
    }

    private SortedSet<String> getTokensOrderedByNumberOfErrors() {
        return this.stats.getTokensOrderedByNumberOfErrors();
    }

    private SortedSet<String> getTagsOrderedByErrors() {
        return this.stats.getTagsOrderedByErrors();
    }

    private int getTagFrequency(String tag) {
        return this.stats.getTagFrequency(tag);
    }

    private int getTagErrors(String tag) {
        return this.stats.getTagErrors(tag);
    }

    private double getTagPrecision(String tag) {
        return this.stats.getTagPrecision(tag);
    }

    private double getTagRecall(String tag) {
        return this.stats.getTagRecall(tag);
    }

    private double getTagFMeasure(String tag) {
        return this.stats.getTagFMeasure(tag);
    }

    private SortedSet<String> getConfusionMatrixTagset() {
        return this.stats.getConfusionMatrixTagset();
    }

    private SortedSet<String> getConfusionMatrixTagset(String token) {
        return this.stats.getConfusionMatrixTagset(token);
    }

    private double[][] getConfusionMatrix() {
        return this.stats.getConfusionMatrix();
    }

    private double[][] getConfusionMatrix(String token) {
        return this.stats.getConfusionMatrix(token);
    }

    /**
     * Renders a confusion matrix as text.
     *
     * <p>{@code data} has one row per tag in {@code tagset} (same order); every row holds one count
     * column per tag followed by a final accuracy column. Zero counts are shown as ".", the
     * diagonal is wrapped in angle brackets, and columns are labelled with
     * {@link #generateAlphaLabel(int)}.
     *
     * @param tagset row/column labels, in matrix order
     * @param data   the matrix produced by {@code Stats.createConfusionMatrix}
     * @param filter if {@code true}, leading rows/columns are skipped up to and including the last
     *               row whose accuracy is exactly 100%
     * @return the rendered matrix, or an empty string when there are no rows
     */
    private String matrixToString(SortedSet<String> tagset, double[][] data, boolean filter) {
        if (data.length == 0) {
            // Nothing was evaluated; data[0] below would not exist.
            return "";
        }
        String[] tags = tagset.toArray(new String[tagset.size()]);
        int accuracyColumn = data[0].length - 1;
        String[][] cells = new String[data.length][data[0].length];
        // Width of the widest count cell (every rendered cell is at least one character).
        int cellWidth = 1;
        int initialIndex = 0;
        for (int row = 0; row < data.length; ++row) {
            for (int col = 0; col < accuracyColumn; ++col) {
                cells[row][col] = data[row][col] > 0.0 ? Integer.toString((int) data[row][col]) : ".";
                cellWidth = Math.max(cellWidth, cells[row][col].length());
            }
            cells[row][accuracyColumn] = MessageFormat.format(PERCENT_PATTERN, data[row][accuracyColumn]);
            if (filter && data[row][accuracyColumn] == 1.0) {
                initialIndex = row + 1;
            }
        }

        String cellFormat = "%" + (cellWidth + 2) + "s ";
        String diagonalFormat = " %" + (cellWidth + 2) + "s";
        StringBuilder sb = new StringBuilder();
        for (int col = initialIndex; col < tags.length; ++col) {
            sb.append(String.format(cellFormat, generateAlphaLabel(col - initialIndex).trim()));
        }
        sb.append("| Accuracy | <-- classified as\n");
        for (int row = initialIndex; row < data.length; ++row) {
            for (int col = initialIndex; col < accuracyColumn; ++col) {
                if (row == col) {
                    sb.append(String.format(diagonalFormat, "<" + cells[row][col] + ">"));
                } else {
                    sb.append(String.format(cellFormat, cells[row][col]));
                }
            }
            sb.append(String.format("|   %-6s |   %3s = ", cells[row][accuracyColumn], generateAlphaLabel(row - initialIndex)))
                    .append(tags[row]);
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Prints corpus-level figures: sentence counts and sizes, tag count and overall accuracy. */
    protected void printGeneralStatistics() {
        this.printHeader("Evaluation summary");
        this.printStream.append(String.format("%21s: %6s", "Number of sentences", Long.toString(this.getNumberOfSentences()))).append("\n");
        this.printStream.append(String.format("%21s: %6s", "Min sentence size", this.getMinSentenceSize())).append("\n");
        this.printStream.append(String.format("%21s: %6s", "Max sentence size", this.getMaxSentenceSize())).append("\n");
        this.printStream.append(String.format("%21s: %6s", "Average sentence size", MessageFormat.format("{0,number,#.##}", this.getAverageSentenceSize()))).append("\n");
        this.printStream.append(String.format("%21s: %6s", "Tags count", this.getNumberOfTags())).append("\n");
        this.printStream.append(String.format("%21s: %6s", "Accuracy", MessageFormat.format(PERCENT_PATTERN, this.getAccuracy()))).append("\n");
        this.printFooter("Evaluation Corpus Statistics");
    }

    /** Prints a table of the {@value #MAX_RANKED_TOKENS} most frequent tokens. */
    protected void printTokenOcurrenciesRank() {
        this.printHeader("Most frequent tokens");
        SortedSet<String> toks = this.getTokensOrderedByFrequency();
        int maxTokSize = maxLength(toks, MAX_RANKED_TOKENS, 5);
        // 19 = width of the fixed parts of the row format below.
        int tableSize = maxTokSize + 19;
        String format = "| %3s | %6s | %" + maxTokSize + "s |";
        this.printLine(tableSize);
        this.printStream.append(String.format(format, "Pos", "Count", "Token")).append("\n");
        this.printLine(tableSize);
        int position = 0;
        for (String tok : toks) {
            if (position == MAX_RANKED_TOKENS) {
                break;
            }
            ++position;
            this.printStream.append(String.format(format, position, this.getTokenFrequency(tok), tok)).append("\n");
        }
        this.printLine(tableSize);
        this.printFooter("Most frequent tokens");
    }

    /** Prints a table of the {@value #MAX_RANKED_TOKENS} tokens with the most errors. */
    protected void printTokenErrorRank() {
        this.printHeader("Tokens with the highest number of errors");
        this.printStream.append("\n");
        SortedSet<String> toks = this.getTokensOrderedByNumberOfErrors();
        int maxTokenSize = maxLength(toks, MAX_RANKED_TOKENS, 5);
        // 31 = width of the fixed parts of the row format below.
        int tableSize = 31 + maxTokenSize;
        String format = "| %" + maxTokenSize + "s | %6s | %5s | %7s |\n";
        this.printLine(tableSize);
        this.printStream.append(String.format(format, "Token", "Errors", "Count", "% Err"));
        this.printLine(tableSize);
        int printed = 0;
        for (String tok : toks) {
            if (printed++ == MAX_RANKED_TOKENS) {
                break;
            }
            int ocurrencies = this.getTokenFrequency(tok);
            int errors = this.getTokenErrors(tok);
            String rate = MessageFormat.format(PERCENT_PATTERN, (double) errors / (double) ocurrencies);
            this.printStream.append(String.format(format, tok, errors, ocurrencies, rate));
        }
        this.printLine(tableSize);
        this.printFooter("Tokens with the highest number of errors");
    }

    /** Prints per-tag error counts, error rate, precision, recall and F-measure. */
    protected void printTagsErrorRank() {
        this.printHeader("Detailed Accuracy By Tag");
        SortedSet<String> tags = this.getTagsOrderedByErrors();
        this.printStream.append("\n");
        int maxTagSize = maxLength(tags, Integer.MAX_VALUE, 3);
        // 65 = width of the fixed parts of the row format below.
        int tableSize = 65 + maxTagSize;
        String headerFormat = "| %" + maxTagSize + "s | %6s | %6s | %7s | %9s | %6s | %9s |\n";
        String format = "| %" + maxTagSize + "s | %6s | %6s | %-7s | %-9s | %-6s | %-9s |\n";
        this.printLine(tableSize);
        this.printStream.append(String.format(headerFormat, "Tag", "Errors", "Count", "% Err", "Precision", "Recall", "F-Measure"));
        this.printLine(tableSize);
        for (String tag : tags) {
            int ocurrencies = this.getTagFrequency(tag);
            int errors = this.getTagErrors(tag);
            String rate = MessageFormat.format(SCORE_PATTERN, (double) errors / (double) ocurrencies);
            this.printStream.append(String.format(format, tag, errors, ocurrencies, rate,
                    formatScore(this.getTagPrecision(tag)),
                    formatScore(this.getTagRecall(tag)),
                    formatScore(this.getTagFMeasure(tag))));
        }
        this.printLine(tableSize);
        this.printFooter("Tags with the highest number of errors");
    }

    /** Formats a score with {@link #SCORE_PATTERN}, printing non-positive values (and NaN) as 0. */
    private static String formatScore(double score) {
        return MessageFormat.format(SCORE_PATTERN, score > 0.0 ? score : 0.0);
    }

    /** Prints the tags with perfect accuracy followed by the filtered corpus confusion matrix. */
    protected void printGeneralConfusionTable() {
        this.printHeader("Confusion matrix");
        SortedSet<String> labels = this.getConfusionMatrixTagset();
        double[][] confusionMatrix = this.getConfusionMatrix();
        this.printStream.append("\nTags with 100% accuracy: ");
        // The matrix has labels.size() count columns followed by the accuracy column.
        int accuracyColumn = labels.size();
        int row = 0;
        for (String label : labels) {
            if (confusionMatrix[row][accuracyColumn] == 1.0) {
                this.printStream.append(label).append(" (").append(Integer.toString((int) confusionMatrix[row][row])).append(") ");
            }
            ++row;
        }
        this.printStream.append("\n\n");
        this.printStream.append(this.matrixToString(labels, confusionMatrix, true));
        this.printFooter("Confusion matrix");
    }

    /** Prints an unfiltered confusion matrix for every token that was not always tagged correctly. */
    protected void printDetailedConfusionMatrix() {
        this.printHeader("Confusion matrix for tokens");
        this.printStream.append("  sorted by number of errors\n");
        for (String t : this.getTokensOrderedByNumberOfErrors()) {
            double acc = this.getTokenAccuracy(t);
            if (!(acc < 1.0)) {
                continue;
            }
            this.printStream.append("\n[").append(t).append("]\n").append(String.format("%12s: %-8s", "Accuracy", MessageFormat.format(PERCENT_PATTERN, acc))).append("\n");
            this.printStream.append(String.format("%12s: %-8s", "Ocurrencies", Integer.toString(this.getTokenFrequency(t)))).append("\n");
            this.printStream.append(String.format("%12s: %-8s", "Errors", Integer.toString(this.getTokenErrors(t)))).append("\n");
            SortedSet<String> labels = this.getConfusionMatrixTagset(t);
            double[][] confusionMatrix = this.getConfusionMatrix(t);
            this.printStream.append(this.matrixToString(labels, confusionMatrix, false));
        }
        this.printFooter("Confusion matrix for tokens");
    }

    private void printHeader(String text) {
        this.printStream.append("=== ").append(text).append(" ===\n");
    }

    private void printFooter(String text) {
        this.printStream.append("\n<-end> ").append(text).append("\n\n");
    }

    /** Prints a horizontal rule of {@code size} dashes followed by a newline. */
    private void printLine(int size) {
        StringBuilder rule = new StringBuilder(Math.max(size, 0) + 1);
        for (int i = 0; i < size; ++i) {
            rule.append('-');
        }
        this.printStream.append(rule).append("\n");
    }

    /**
     * Returns the comparator used to order confusion-matrix labels.
     *
     * @param confusionMatrix rows of the matrix, keyed by reference tag
     * @return a comparator ordering labels by descending accuracy, then by name
     */
    public Comparator<String> getMatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) {
        return new MatrixLabelComparator(confusionMatrix);
    }

    /**
     * Returns the comparator used to order tags by their number of errors.
     *
     * @param map counters keyed by label
     * @return a comparator ordering labels by descending counter value, then by name
     */
    public Comparator<String> getLabelComparator(Map<String, Counter> map) {
        return new SimpleLabelComparator(map);
    }

    /**
     * Accumulates evaluation samples: overall and per-token accuracy, per-tag counts and
     * F-measures, sentence-length figures and confusion matrices (corpus-wide and per token).
     */
    public class Stats {
        // Overall accuracy: 1.0 is added per correct prediction, 0.0 per error.
        private final Mean accuracy = new Mean();
        // Sentence lengths; its count is the number of sentences seen.
        private final Mean averageSentenceLength = new Mean();
        private final Map<String, Mean> tokAccuracies = new HashMap<>();
        private final Map<String, Counter> tokOcurrencies = new HashMap<>();
        private final Map<String, Counter> tokErrors = new HashMap<>();
        // Keyed by reference tag.
        private final Map<String, Counter> tagOcurrencies = new HashMap<>();
        private final Map<String, Counter> tagErrors = new HashMap<>();
        // Keyed by every tag seen as reference or prediction.
        private final Map<String, FMeasure> tagFMeasure = new HashMap<>();
        // Reference tag -> row of predicted-tag counts.
        private final Map<String, ConfusionMatrixLine> generalConfusionMatrix = new HashMap<>();
        // Token -> (reference tag -> row of predicted-tag counts).
        private final Map<String, Map<String, ConfusionMatrixLine>> tokenConfusionMatrix = new HashMap<>();
        private int minimalSentenceLength = Integer.MAX_VALUE;
        private int maximumSentenceLength = Integer.MIN_VALUE;

        /**
         * Records one tagged sentence.
         *
         * @param toks  the tokens
         * @param refs  the reference tag of each token
         * @param preds the predicted tag of each token
         * @throws IllegalArgumentException if the three arrays differ in length
         */
        public void add(String[] toks, String[] refs, String[] preds) {
            // Validate before touching any state so a bad sample cannot leave partial updates.
            if (refs.length != toks.length || preds.length != toks.length) {
                throw new IllegalArgumentException("toks, refs and preds must have the same length, got "
                        + toks.length + ", " + refs.length + " and " + preds.length);
            }
            this.recordSentenceLength(toks.length);
            this.updateTagFMeasure(refs, preds);
            for (int i = 0; i < toks.length; ++i) {
                this.commit(toks[i], refs[i], preds[i]);
            }
        }

        /**
         * Records one sample that carries a single reference/predicted label (e.g. a document
         * category). All such samples share the empty string as their token.
         *
         * @param length the sample length used for the sentence-size statistics
         * @param ref    the reference label
         * @param pred   the predicted label
         */
        public void add(int length, String ref, String pred) {
            this.recordSentenceLength(length);
            this.updateTagFMeasure(new String[]{ref}, new String[]{pred});
            this.commit("", ref, pred);
        }

        /** Same as {@link #add(int, String, String)} with {@code text.length} as the length. */
        public void add(String[] text, String ref, String pred) {
            this.add(text.length, ref, pred);
        }

        /** Same as {@link #add(int, String, String)} with {@code text.length()} as the length. */
        public void add(CharSequence text, String ref, String pred) {
            this.add(text.length(), ref, pred);
        }

        /** Updates the sentence count, mean, minimum and maximum length. */
        private void recordSentenceLength(int length) {
            this.averageSentenceLength.add(length);
            this.minimalSentenceLength = Math.min(this.minimalSentenceLength, length);
            this.maximumSentenceLength = Math.max(this.maximumSentenceLength, length);
        }

        /** Records a single token/reference/prediction triple in every counter and matrix. */
        private void commit(String tok, String ref, String pred) {
            // First-seen tokens and reference tags get zeroed error counters, so error lookups for
            // any counted token or tag never miss.
            if (!this.tokAccuracies.containsKey(tok)) {
                this.tokAccuracies.put(tok, new Mean());
                this.tokOcurrencies.put(tok, new Counter());
                this.tokErrors.put(tok, new Counter());
            }
            this.tokOcurrencies.get(tok).increment();
            if (!this.tagOcurrencies.containsKey(ref)) {
                this.tagOcurrencies.put(ref, new Counter());
                this.tagErrors.put(ref, new Counter());
            }
            this.tagOcurrencies.get(ref).increment();
            if (ref.equals(pred)) {
                this.tokAccuracies.get(tok).add(1.0);
                this.accuracy.add(1.0);
            } else {
                this.tokAccuracies.get(tok).add(0.0);
                this.tokErrors.get(tok).increment();
                this.tagErrors.get(ref).increment();
                this.accuracy.add(0.0);
            }
            this.generalConfusionMatrix.computeIfAbsent(ref, k -> new ConfusionMatrixLine(k)).increment(pred);
            this.tokenConfusionMatrix.computeIfAbsent(tok, k -> new HashMap<>())
                    .computeIfAbsent(ref, k -> new ConfusionMatrixLine(k))
                    .increment(pred);
        }

        /**
         * Updates the F-measure of every tag occurring in {@code refs} or {@code preds}, treating
         * each position tagged with it as a one-token span.
         */
        private void updateTagFMeasure(String[] refs, String[] preds) {
            HashSet<String> tags = new HashSet<>(Arrays.asList(refs));
            tags.addAll(Arrays.asList(preds));
            for (String tag : tags) {
                ArrayList<Span> reference = new ArrayList<>();
                ArrayList<Span> prediction = new ArrayList<>();
                for (int i = 0; i < refs.length; ++i) {
                    if (refs[i].equals(tag)) {
                        reference.add(new Span(i, i + 1));
                    }
                    if (preds[i].equals(tag)) {
                        prediction.add(new Span(i, i + 1));
                    }
                }
                this.tagFMeasure.computeIfAbsent(tag, k -> new FMeasure())
                        .updateScores(reference.toArray(new Span[reference.size()]), prediction.toArray(new Span[prediction.size()]));
            }
        }

        private double getAccuracy() {
            return this.accuracy.mean();
        }

        private int getNumberOfTags() {
            return this.tagOcurrencies.size();
        }

        private long getNumberOfSentences() {
            return this.averageSentenceLength.count();
        }

        private double getAverageSentenceSize() {
            return this.averageSentenceLength.mean();
        }

        private int getMinSentenceSize() {
            // Without samples the sentinel would be reported as Integer.MAX_VALUE.
            return this.getNumberOfSentences() == 0 ? 0 : this.minimalSentenceLength;
        }

        private int getMaxSentenceSize() {
            // Without samples the sentinel would be reported as Integer.MIN_VALUE.
            return this.getNumberOfSentences() == 0 ? 0 : this.maximumSentenceLength;
        }

        private double getTokenAccuracy(String token) {
            return this.tokAccuracies.get(token).mean();
        }

        private int getTokenErrors(String token) {
            return this.tokErrors.get(token).value();
        }

        private int getTokenFrequency(String token) {
            return this.tokOcurrencies.get(token).value();
        }

        private SortedSet<String> getTokensOrderedByFrequency() {
            TreeSet<String> toks = new TreeSet<>(new SimpleLabelComparator(this.tokOcurrencies));
            toks.addAll(this.tokOcurrencies.keySet());
            return Collections.unmodifiableSortedSet(toks);
        }

        private SortedSet<String> getTokensOrderedByNumberOfErrors() {
            TreeSet<String> toks = new TreeSet<>(new SimpleLabelComparator(this.tokErrors));
            toks.addAll(this.tokErrors.keySet());
            return Collections.unmodifiableSortedSet(toks);
        }

        private int getTagFrequency(String tag) {
            return this.tagOcurrencies.get(tag).value();
        }

        private int getTagErrors(String tag) {
            return this.tagErrors.get(tag).value();
        }

        private double getTagFMeasure(String tag) {
            return this.tagFMeasure.get(tag).getFMeasure();
        }

        private double getTagRecall(String tag) {
            return this.tagFMeasure.get(tag).getRecallScore();
        }

        private double getTagPrecision(String tag) {
            return this.tagFMeasure.get(tag).getPrecisionScore();
        }

        private SortedSet<String> getTagsOrderedByErrors() {
            TreeSet<String> tags = new TreeSet<>(FineGrainedReportListener.this.getLabelComparator(this.tagErrors));
            tags.addAll(this.tagErrors.keySet());
            return Collections.unmodifiableSortedSet(tags);
        }

        private SortedSet<String> getConfusionMatrixTagset() {
            return this.getConfusionMatrixTagset(this.generalConfusionMatrix);
        }

        private double[][] getConfusionMatrix() {
            return this.createConfusionMatrix(this.getConfusionMatrixTagset(), this.generalConfusionMatrix);
        }

        private SortedSet<String> getConfusionMatrixTagset(String token) {
            return this.getConfusionMatrixTagset(this.tokenConfusionMatrix.get(token));
        }

        private double[][] getConfusionMatrix(String token) {
            return this.createConfusionMatrix(this.getConfusionMatrixTagset(token), this.tokenConfusionMatrix.get(token));
        }

        /**
         * Builds a {@code size x (size + 1)} matrix: cell [r][c] counts how often tag r was
         * predicted as tag c, and the last column holds the accuracy of row r. Tags that only
         * occur as predictions get an all-zero row.
         */
        private double[][] createConfusionMatrix(SortedSet<String> tagset, Map<String, ConfusionMatrixLine> data) {
            int size = tagset.size();
            double[][] matrix = new double[size][size + 1];
            int row = 0;
            for (String ref : tagset) {
                ConfusionMatrixLine refLine = data.get(ref);
                if (refLine != null) {
                    int column = 0;
                    for (String pred : tagset) {
                        matrix[row][column++] = refLine.getValue(pred);
                    }
                    matrix[row][size] = refLine.getAccuracy();
                }
                ++row;
            }
            return matrix;
        }

        /** Returns all reference and predicted tags of {@code data}, ordered for display. */
        private SortedSet<String> getConfusionMatrixTagset(Map<String, ConfusionMatrixLine> data) {
            TreeSet<String> tags = new TreeSet<>(FineGrainedReportListener.this.getMatrixLabelComparator(data));
            tags.addAll(data.keySet());
            for (ConfusionMatrixLine refLine : data.values()) {
                tags.addAll(refLine.line.keySet());
            }
            return Collections.unmodifiableSortedSet(tags);
        }
    }

    /** A simple mutable integer counter. */
    public static class Counter {
        private int c = 0;

        private void increment() {
            ++this.c;
        }

        /**
         * Returns the current count.
         *
         * @return the number of increments so far
         */
        public int value() {
            return this.c;
        }
    }

    /** One row of a confusion matrix: predicted-tag counts for a single reference tag. */
    public static class ConfusionMatrixLine {
        private final Map<String, Counter> line = new HashMap<>();
        private final String ref;
        private int total = 0;
        private int correct = 0;

        private ConfusionMatrixLine(String ref) {
            this.ref = ref;
        }

        /** Counts one prediction of {@code column} for this row's reference tag. */
        private void increment(String column) {
            ++this.total;
            if (column.equals(this.ref)) {
                ++this.correct;
            }
            this.line.computeIfAbsent(column, k -> new Counter()).increment();
        }

        /**
         * Returns the fraction of predictions in this row that matched the reference tag.
         *
         * <p>Computed on every call (a single division), so the value never goes stale when more
         * predictions are counted after it was first read.
         *
         * @return {@code correct / total}, or 0 when nothing has been counted
         */
        public double getAccuracy() {
            if (this.total == 0) {
                return 0.0;
            }
            return (double) this.correct / (double) this.total;
        }

        /**
         * Returns how often {@code column} was predicted for this row's reference tag.
         *
         * @param column the predicted tag
         * @return the count, or 0 if it was never predicted
         */
        public int getValue(String column) {
            Counter c = this.line.get(column);
            return c == null ? 0 : c.value();
        }
    }

    /**
     * Orders labels grouped by category (the text before the first '-'): categories by descending
     * total count, labels within a category by descending count, ties by name. Labels without a
     * counter sort last.
     */
    public static class GroupedLabelComparator
    implements Comparator<String> {
        private final Map<String, Integer> categoryCounter = new HashMap<>();
        private final Map<String, Counter> labelCounter;

        public GroupedLabelComparator(Map<String, Counter> map) {
            this.labelCounter = map;
            for (Map.Entry<String, Counter> entry : map.entrySet()) {
                this.categoryCounter.merge(labelCategory(entry.getKey()), entry.getValue().value(), Integer::sum);
            }
        }

        @Override
        public int compare(String o1, String o2) {
            if (o1.equals(o2)) {
                return 0;
            }
            String c1 = labelCategory(o1);
            String c2 = labelCategory(o2);
            if (c1.equals(c2)) {
                Counter t1 = this.labelCounter.get(o1);
                Counter t2 = this.labelCounter.get(o2);
                if (t1 == null || t2 == null) {
                    return compareMissing(t1 == null, t2 == null, o1, o2);
                }
                int byCount = Integer.compare(t2.value(), t1.value());
                return byCount != 0 ? byCount : o1.compareTo(o2);
            }
            Integer t1 = this.categoryCounter.get(c1);
            Integer t2 = this.categoryCounter.get(c2);
            if (t1 == null || t2 == null) {
                return compareMissing(t1 == null, t2 == null, o1, o2);
            }
            int byCategory = Integer.compare(t2, t1);
            return byCategory != 0 ? byCategory : o1.compareTo(o2);
        }
    }

    /**
     * Orders labels by descending counter value, ties by name. Labels missing from the map count
     * as zero.
     */
    public static class SimpleLabelComparator
    implements Comparator<String> {
        private final Map<String, Counter> map;

        public SimpleLabelComparator(Map<String, Counter> map) {
            this.map = map;
        }

        @Override
        public int compare(String o1, String o2) {
            if (o1.equals(o2)) {
                return 0;
            }
            Counter c1 = this.map.get(o1);
            Counter c2 = this.map.get(o2);
            int e1 = c1 == null ? 0 : c1.value();
            int e2 = c2 == null ? 0 : c2.value();
            int byCount = Integer.compare(e2, e1);
            return byCount != 0 ? byCount : o1.compareTo(o2);
        }
    }

    /**
     * Orders confusion-matrix labels grouped by category (the text before the first '-'):
     * categories by descending summed accuracy, labels within a category by descending accuracy,
     * ties by name. Labels without a matrix row sort last.
     */
    public static class GroupedMatrixLabelComparator
    implements Comparator<String> {
        private final Map<String, Double> categoryAccuracy = new HashMap<>();
        private final Map<String, ConfusionMatrixLine> confusionMatrix;

        public GroupedMatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) {
            this.confusionMatrix = confusionMatrix;
            for (Map.Entry<String, ConfusionMatrixLine> entry : confusionMatrix.entrySet()) {
                this.categoryAccuracy.merge(labelCategory(entry.getKey()), entry.getValue().getAccuracy(), Double::sum);
            }
        }

        @Override
        public int compare(String o1, String o2) {
            if (o1.equals(o2)) {
                return 0;
            }
            String c1 = labelCategory(o1);
            String c2 = labelCategory(o2);
            if (c1.equals(c2)) {
                ConfusionMatrixLine t1 = this.confusionMatrix.get(o1);
                ConfusionMatrixLine t2 = this.confusionMatrix.get(o2);
                if (t1 == null || t2 == null) {
                    return compareMissing(t1 == null, t2 == null, o1, o2);
                }
                int byAccuracy = Double.compare(t2.getAccuracy(), t1.getAccuracy());
                return byAccuracy != 0 ? byAccuracy : o1.compareTo(o2);
            }
            Double t1 = this.categoryAccuracy.get(c1);
            Double t2 = this.categoryAccuracy.get(c2);
            if (t1 == null || t2 == null) {
                return compareMissing(t1 == null, t2 == null, o1, o2);
            }
            int byCategory = Double.compare(t2, t1);
            return byCategory != 0 ? byCategory : o1.compareTo(o2);
        }
    }

    /**
     * Orders confusion-matrix labels by descending row accuracy, ties by name. Labels without a
     * matrix row (tags that were only predicted) sort last.
     */
    public static class MatrixLabelComparator
    implements Comparator<String> {
        private final Map<String, ConfusionMatrixLine> confusionMatrix;

        public MatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) {
            this.confusionMatrix = confusionMatrix;
        }

        @Override
        public int compare(String o1, String o2) {
            if (o1.equals(o2)) {
                return 0;
            }
            ConfusionMatrixLine t1 = this.confusionMatrix.get(o1);
            ConfusionMatrixLine t2 = this.confusionMatrix.get(o2);
            if (t1 == null || t2 == null) {
                return compareMissing(t1 == null, t2 == null, o1, o2);
            }
            int byAccuracy = Double.compare(t2.getAccuracy(), t1.getAccuracy());
            return byAccuracy != 0 ? byAccuracy : o1.compareTo(o2);
        }
    }
}

