/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.machinereading;

import edu.stanford.nlp.ie.machinereading.ResultsPrinter;
import edu.stanford.nlp.ie.machinereading.structure.AnnotationUtils;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Prints per-label and aggregate precision / recall / F1 scores for relation extraction.
 *
 * <p>Output is a tab-separated table with one row per label followed by a {@code Total} row.
 * Labels for which {@link RelationMention#isUnrelatedLabel(String)} returns {@code true} get
 * their own row but are excluded from the {@code Total} row.
 */
public class RelationExtractorResultsPrinter
extends ResultsPrinter {
    /** Passed through to {@code AnnotationUtils.getAllRelations} when collecting gold relations. */
    protected boolean createUnrelatedRelations;
    /** Factory handed to {@code AnnotationUtils}; must be non-null for {@link #printResults}. */
    protected final RelationMentionFactory relationMentionFactory;
    /** Width passed to {@code StringUtils.padOrTrim} for the label column. */
    private static final int MAX_LABEL_LENGTH = 31;

    /**
     * Creates a printer using the given factory, with {@code createUnrelatedRelations} set to {@code true}.
     *
     * @param factory factory used when collecting relation mentions
     */
    public RelationExtractorResultsPrinter(RelationMentionFactory factory) {
        this(factory, true);
    }

    /** Creates a printer with a default {@link RelationMentionFactory} and {@code createUnrelatedRelations} set to {@code true}. */
    public RelationExtractorResultsPrinter() {
        this(new RelationMentionFactory(), true);
    }

    /**
     * Creates a printer with a default {@link RelationMentionFactory}.
     *
     * @param createUnrelatedRelations value passed to {@code AnnotationUtils.getAllRelations}
     */
    public RelationExtractorResultsPrinter(boolean createUnrelatedRelations) {
        this(new RelationMentionFactory(), createUnrelatedRelations);
    }

    /**
     * @param factory factory used when collecting relation mentions
     * @param createUnrelatedRelations value passed to {@code AnnotationUtils.getAllRelations}
     */
    public RelationExtractorResultsPrinter(RelationMentionFactory factory, boolean createUnrelatedRelations) {
        this.createUnrelatedRelations = createUnrelatedRelations;
        this.relationMentionFactory = factory;
    }

    /**
     * Scores the relations in {@code extractorOutput} against those in {@code goldStandard} and
     * prints the result table to {@code pw}.
     *
     * <p>Both lists are first passed to {@code ResultsPrinter.align} and are then indexed in
     * parallel. For every gold relation, each extractor relation over the same two arguments in
     * the corresponding extractor sentence is tallied as a (predicted type, gold type) pair, and
     * the gold type's count is incremented once.
     *
     * @param pw destination for the table
     * @param goldStandard gold sentences
     * @param extractorOutput system sentences, index-aligned with {@code goldStandard} after alignment
     */
    @Override
    public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
        // Check before align() so a misconfigured printer fails without touching the inputs.
        assert (this.relationMentionFactory != null) : "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!";
        ResultsPrinter.align(goldStandard, extractorOutput);
        ClassicCounter<Pair<String, String>> results = new ClassicCounter<Pair<String, String>>();
        ClassicCounter<String> labelCount = new ClassicCounter<String>();
        for (int goldSentenceIndex = 0; goldSentenceIndex < goldStandard.size(); ++goldSentenceIndex) {
            for (RelationMention goldRelation : AnnotationUtils.getAllRelations(this.relationMentionFactory, goldStandard.get(goldSentenceIndex), this.createUnrelatedRelations)) {
                CoreMap extractorSentence = extractorOutput.get(goldSentenceIndex);
                List<RelationMention> extractorRelations = AnnotationUtils.getRelations(this.relationMentionFactory, extractorSentence, goldRelation.getArg(0), goldRelation.getArg(1));
                labelCount.incrementCount(goldRelation.getType());
                for (RelationMention extractorRelation : extractorRelations) {
                    results.incrementCount(new Pair<String, String>(extractorRelation.getType(), goldRelation.getType()));
                }
            }
        }
        this.printResultsInternal(pw, results, labelCount);
    }

    /**
     * Tallies correct / predicted / gold counts per label and prints one row per label followed
     * by a {@code Total} row.
     *
     * <p>Every label that occurs either as a gold label or as a prediction is reported. A label
     * that is only ever predicted contributes false positives, so it must appear in the table and
     * in the Total precision.
     *
     * @param pw destination for the table
     * @param results counts keyed by (predicted label, gold label)
     * @param labelCount gold count per label, or {@code null} to derive gold counts from the
     *     gold element of each {@code results} key; never modified when non-null
     */
    private void printResultsInternal(PrintWriter pw, Counter<Pair<String, String>> results, ClassicCounter<String> labelCount) {
        ClassicCounter<String> correct = new ClassicCounter<String>();
        ClassicCounter<String> predictionCount = new ClassicCounter<String>();
        boolean countGoldLabels = labelCount == null;
        if (countGoldLabels) {
            labelCount = new ClassicCounter<String>();
        }
        for (Pair<String, String> predictedActual : results.keySet()) {
            String predicted = predictedActual.first;
            String actual = predictedActual.second;
            double count = results.getCount(predictedActual);
            if (predicted.equals(actual)) {
                correct.incrementCount(actual, count);
            }
            predictionCount.incrementCount(predicted, count);
            if (countGoldLabels) {
                labelCount.incrementCount(actual, count);
            }
        }

        // Union of gold and predicted labels; predicted-only labels would otherwise be silently
        // dropped and their false positives would not count against the Total precision.
        List<String> labels = new ArrayList<String>(labelCount.keySet());
        for (String predictedLabel : predictionCount.keySet()) {
            if (!labelCount.containsKey(predictedLabel)) {
                labels.add(predictedLabel);
            }
        }
        Collections.sort(labels);

        DecimalFormat formatter = new DecimalFormat();
        formatter.setMaximumFractionDigits(1);
        formatter.setMinimumFractionDigits(1);
        double totalCount = 0.0;
        double totalCorrect = 0.0;
        double totalPredicted = 0.0;
        pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
        for (String label : labels) {
            double numCorrect = correct.getCount(label);
            double predicted = predictionCount.getCount(label);
            double trueCount = labelCount.getCount(label);
            pw.println(formatRow(StringUtils.padOrTrim(label, MAX_LABEL_LENGTH), numCorrect, predicted, trueCount, ratio(numCorrect, predicted), ratio(numCorrect, trueCount), formatter));
            // Unrelated labels get their own row but do not contribute to the aggregate scores.
            if (RelationMention.isUnrelatedLabel(label)) {
                continue;
            }
            totalCount += trueCount;
            totalCorrect += numCorrect;
            totalPredicted += predicted;
        }
        pw.println(formatRow("Total", totalCorrect, totalPredicted, totalCount, ratio(totalCorrect, totalPredicted), ratio(totalCorrect, totalCount), formatter));
    }

    /**
     * Returns {@code numerator / denominator}, or 0 when the denominator is not positive, so that
     * empty categories print as 0.0 instead of NaN.
     */
    private static double ratio(double numerator, double denominator) {
        return denominator > 0.0 ? numerator / denominator : 0.0;
    }

    /**
     * Formats one tab-separated table row. Precision, recall and their harmonic mean (F1, or 0
     * when both are 0) are printed as percentages with one fractional digit.
     */
    private static String formatRow(String name, double numCorrect, double predicted, double actual, double precision, double recall, DecimalFormat formatter) {
        double f = precision + recall > 0.0 ? 2.0 * precision * recall / (precision + recall) : 0.0;
        return name + "\t" + numCorrect + "\t" + predicted + "\t" + actual + "\t" + formatter.format(100.0 * precision) + "\t" + formatter.format(100.0 * recall) + "\t" + formatter.format(100.0 * f);
    }

    /**
     * Scores two parallel lists of labels and prints the result table to {@code pw}. The gold
     * count of each label is the number of times it occurs in {@code goldStandard}.
     *
     * @param pw destination for the table
     * @param goldStandard gold label per item
     * @param extractorOutput predicted label per item; expected to be the same length as {@code goldStandard}
     */
    @Override
    public void printResultsUsingLabels(PrintWriter pw, List<String> goldStandard, List<String> extractorOutput) {
        assert (goldStandard.size() == extractorOutput.size());
        ClassicCounter<Pair<String, String>> results = new ClassicCounter<Pair<String, String>>();
        for (int i = 0; i < goldStandard.size(); ++i) {
            results.incrementCount(new Pair<String, String>(extractorOutput.get(i), goldStandard.get(i)));
        }
        this.printResultsInternal(pw, results, null);
    }
}

