/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.util;

import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A confusion matrix over labels of type {@code U}.
 *
 * <p>For every (guess, gold) label pair it records how many times that pair was observed. It can
 * render the counts, the per-row/per-column marginals and per-label one-vs-rest
 * precision/recall/specificity/F1 as a plain-text table (see {@link #printTable()}).
 *
 * <p>Counts live in a {@link ConcurrentHashMap}. All writes go through the {@code synchronized}
 * {@link #add(Object, Object, int)}, so concurrent increments of the same cell are not lost.
 *
 * @param <U> the label type; labels are combined into {@link Pair} hash keys and compared with
 *     {@code equals}
 */
public class ConfusionMatrix<U> {
    /** Prefix of the generated placeholder names ("C1", "C2", ...) used when real labels are hidden. */
    private static final String CLASS_PREFIX = "C";
    /** {@link DecimalFormat} pattern used for precision/recall/specificity/F1 values. */
    private static final String FORMAT = "#.#####";

    /** Formatter for the per-label statistics printed by {@link Contingency#toString()}. */
    protected DecimalFormat format;
    /** Width the first (row-header) column is left-padded to. */
    private int leftPadSize = 16;
    /** Width every count column is left-padded to. */
    private int delimPadSize = 8;
    /** If true, print the labels themselves as headers instead of "C1", "C2", ... placeholders. */
    private boolean useRealLabels = false;
    /** Count of observations per (guess, gold) pair; absent pairs have count 0. */
    private final ConcurrentHashMap<Pair<U, U>, Integer> confTable =
            new ConcurrentHashMap<Pair<U, U>, Integer>();

    /** Creates an empty matrix whose statistics are formatted with the default locale. */
    public ConfusionMatrix() {
        this.format = new DecimalFormat(FORMAT);
    }

    /**
     * Creates an empty matrix whose statistics are formatted with the given locale's symbols.
     *
     * @param locale locale supplying the decimal separator etc. for printed statistics
     */
    public ConfusionMatrix(Locale locale) {
        this.format = new DecimalFormat(FORMAT, new DecimalFormatSymbols(locale));
    }

    /** Returns the same text as {@link #printTable()}. */
    @Override
    public String toString() {
        return this.printTable();
    }

    /** Sets the width the row-header column is padded to. */
    public void setLeftPadSize(int newPadSize) {
        this.leftPadSize = newPadSize;
    }

    /** Sets the width each count column is padded to. */
    public void setDelimPadSize(int newPadSize) {
        this.delimPadSize = newPadSize;
    }

    /** Chooses between printing real labels (true) and "C1", "C2", ... placeholders (false). */
    public void setUseRealLabels(boolean useRealLabels) {
        this.useRealLabels = useRealLabels;
    }

    /** Records a single observation of {@code guess} being predicted for gold label {@code gold}. */
    public void add(U guess, U gold) {
        this.add(guess, gold, 1);
    }

    /**
     * Adds {@code increment} to the count of the (guess, gold) cell.
     *
     * <p>Synchronized so that the read-modify-write of a cell cannot interleave with another
     * {@code add} on the same matrix.
     *
     * @param guess the predicted label
     * @param gold the true label
     * @param increment amount to add to the cell (may be negative)
     */
    public synchronized void add(U guess, U gold, int increment) {
        Pair<U, U> cell = new Pair<U, U>(guess, gold);
        Integer current = this.confTable.get(cell);
        this.confTable.put(cell, current == null ? increment : current + increment);
    }

    /**
     * Returns the count recorded for the (guess, gold) cell.
     *
     * @return the cell count, or 0 if the pair was never added
     */
    public Integer get(U guess, U gold) {
        Integer count = this.confTable.get(new Pair<U, U>(guess, gold));
        return count == null ? 0 : count;
    }

    /**
     * Returns every label that occurs as a guess or a gold label in some recorded cell.
     *
     * @return a new mutable set; empty if nothing has been added
     */
    public Set<U> uniqueLabels() {
        Set<U> labels = new HashSet<U>();
        for (Pair<U, U> cell : this.confTable.keySet()) {
            labels.add(cell.first());
            labels.add(cell.second());
        }
        return labels;
    }

    /**
     * Collapses the matrix into a 2x2 one-vs-rest contingency table for {@code positiveLabel}.
     *
     * <p>Both correct and incorrect predictions among the other labels count as true negatives.
     *
     * @param positiveLabel the label treated as the positive class
     * @return the tp/fp/tn/fn counts and the statistics derived from them
     */
    public Contingency getContingency(U positiveLabel) {
        int tp = 0;
        int fp = 0;
        int tn = 0;
        int fn = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            int count = entry.getValue();
            boolean guessPositive = entry.getKey().first().equals(positiveLabel);
            boolean goldPositive = entry.getKey().second().equals(positiveLabel);
            if (guessPositive && goldPositive) {
                tp += count;
            } else if (goldPositive) {
                fn += count;
            } else if (guessPositive) {
                fp += count;
            } else {
                tn += count;
            }
        }
        return new Contingency(tp, fp, tn, fn);
    }

    /**
     * Returns all labels in display order.
     *
     * <p>If every label is {@link Comparable}, labels are sorted by natural order. Otherwise they
     * are sorted by {@code toString()}. The sort is stable and operates on the labels themselves,
     * so distinct labels with equal {@code toString()} values are all kept. The previous
     * name-to-label lookup map dropped such labels.
     */
    private List<U> sortKeys() {
        List<U> sorted = new ArrayList<U>(this.uniqueLabels());
        if (sorted.isEmpty()) {
            return Collections.emptyList();
        }
        if (allComparable(sorted)) {
            Collections.sort(sorted, new Comparator<U>() {
                @Override
                @SuppressWarnings("unchecked")
                public int compare(U a, U b) {
                    // Safe only because allComparable() checked every element; mixed
                    // incompatible Comparable types still throw ClassCastException as before.
                    return ((Comparable<Object>) a).compareTo(b);
                }
            });
        } else {
            Collections.sort(sorted, new Comparator<U>() {
                @Override
                public int compare(U a, U b) {
                    return a.toString().compareTo(b.toString());
                }
            });
        }
        return sorted;
    }

    /** Returns true iff every element of {@code labels} implements {@link Comparable}. */
    private static boolean allComparable(Iterable<?> labels) {
        for (Object label : labels) {
            if (!(label instanceof Comparable)) {
                return false;
            }
        }
        return true;
    }

    /** Sums the counts of all cells whose gold label equals {@code gold} (a column total). */
    private int goldMarginal(U gold) {
        int sum = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            if (entry.getKey().second().equals(gold)) {
                sum += entry.getValue();
            }
        }
        return sum;
    }

    /** Sums the counts of all cells whose guess label equals {@code guess} (a row total). */
    private int guessMarginal(U guess) {
        int sum = 0;
        for (Map.Entry<Pair<U, U>, Integer> entry : this.confTable.entrySet()) {
            if (entry.getKey().first().equals(guess)) {
                sum += entry.getValue();
            }
        }
        return sum;
    }

    /**
     * Returns the header text for the label at {@code index} in display order.
     *
     * @return {@code label.toString()} if real labels are enabled, else {@code "C" + (index + 1)}
     */
    public String getPlaceHolder(int index, U label) {
        if (this.useRealLabels) {
            return label.toString();
        }
        return CLASS_PREFIX + (index + 1);
    }

    /**
     * Renders the matrix as plain text.
     *
     * <p>The output has rows for guesses and columns for gold labels, followed by the row
     * marginals, the column marginals, and one line per label with its one-vs-rest statistics.
     *
     * @return the table, or {@code "Empty table!"} if nothing has been added
     */
    public String printTable() {
        if (this.confTable.isEmpty()) {
            return "Empty table!";
        }
        List<U> sortedLabels = this.sortKeys();
        StringBuilder out = new StringBuilder();

        // Header row: one column per gold label, then the row-marginal column.
        out.append(StringUtils.padLeft("Guess/Gold", this.leftPadSize));
        for (int i = 0; i < sortedLabels.size(); ++i) {
            out.append(StringUtils.padLeft(this.getPlaceHolder(i, sortedLabels.get(i)), this.delimPadSize));
        }
        out.append("    Marg. (Guess)");
        out.append("\n");

        // One row per guess label: the cell counts followed by the row total.
        for (int guessI = 0; guessI < sortedLabels.size(); ++guessI) {
            U guess = sortedLabels.get(guessI);
            out.append(StringUtils.padLeft(this.getPlaceHolder(guessI, guess), this.leftPadSize));
            for (U gold : sortedLabels) {
                out.append(StringUtils.padLeft(String.valueOf(this.get(guess, gold)), this.delimPadSize));
            }
            out.append(StringUtils.padLeft(String.valueOf(this.guessMarginal(guess)), this.delimPadSize));
            out.append("\n");
        }

        // Column totals.
        out.append(StringUtils.padLeft("Marg. (Gold)", this.leftPadSize));
        for (U gold : sortedLabels) {
            out.append(StringUtils.padLeft(String.valueOf(this.goldMarginal(gold)), this.delimPadSize));
        }
        out.append("\n\n");

        // Legend (placeholder -> label, if placeholders are used) plus per-label statistics.
        for (int labelI = 0; labelI < sortedLabels.size(); ++labelI) {
            U classLabel = sortedLabels.get(labelI);
            out.append(StringUtils.padLeft(this.getPlaceHolder(labelI, classLabel), this.leftPadSize));
            if (!this.useRealLabels) {
                out.append(" = ");
                out.append(classLabel.toString());
            }
            out.append(StringUtils.padLeft("", this.delimPadSize));
            out.append(this.getContingency(classLabel).toString());
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * One-vs-rest 2x2 contingency counts for a single label, plus derived statistics.
     *
     * <p>Statistics whose denominator is zero are {@code NaN} (IEEE 0.0/0.0). {@link #toString()}
     * prints them as "n/a". This is an inner class because it formats with the enclosing
     * matrix's {@link ConfusionMatrix#format}.
     */
    public class Contingency {
        private final double tp;
        private final double fp;
        private final double tn;
        private final double fn;
        private final double prec;
        private final double recall;
        private final double spec;
        private final double f1;

        /**
         * Builds the table and eagerly computes precision, recall, specificity and F1.
         *
         * @param tp true positives
         * @param fp false positives
         * @param tn true negatives
         * @param fn false negatives
         */
        public Contingency(int tp, int fp, int tn, int fn) {
            this.tp = tp;
            this.fp = fp;
            this.tn = tn;
            this.fn = fn;
            this.prec = this.tp / (this.tp + this.fp);
            this.recall = this.tp / (this.tp + this.fn);
            this.spec = this.tn / (this.fp + this.tn);
            this.f1 = 2.0 * this.prec * this.recall / (this.prec + this.recall);
        }

        /** Returns e.g. {@code "prec=0.5, recall=1, spec=n/a, f1=0.66667"}. */
        @Override
        public String toString() {
            return StringUtils.join(Arrays.asList(
                    "prec=" + formatIfDefined(this.tp + this.fp > 0.0, this.prec),
                    "recall=" + formatIfDefined(this.tp + this.fn > 0.0, this.recall),
                    "spec=" + formatIfDefined(this.fp + this.tn > 0.0, this.spec),
                    // NaN + x > 0.0 is false, so an undefined prec or recall also yields "n/a".
                    "f1=" + formatIfDefined(this.prec + this.recall > 0.0, this.f1)), ", ");
        }

        /** Formats {@code value} with the matrix's formatter, or returns "n/a" if not defined. */
        private String formatIfDefined(boolean defined, double value) {
            return defined ? ConfusionMatrix.this.format.format(value) : "n/a";
        }

        /** Returns the harmonic mean of precision and recall; NaN if either is NaN or both are 0. */
        public double f1() {
            return this.f1;
        }

        /** Returns tp / (tp + fp); NaN if tp + fp == 0. */
        public double precision() {
            return this.prec;
        }

        /** Returns tp / (tp + fn); NaN if tp + fn == 0. */
        public double recall() {
            return this.recall;
        }

        /** Returns specificity tn / (fp + tn); NaN if fp + tn == 0. */
        public double spec() {
            return this.spec;
        }
    }
}

