package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.BobChrisTreeNormalizer;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.LabeledScoredTreeFactory;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeReader;
import edu.stanford.nlp.trees.TreeReaderFactory;
import edu.stanford.nlp.trees.TreeVisitor;
import edu.stanford.nlp.util.Pair;
import java.io.Reader;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:lib/stanford-postagger.jar:edu/stanford/nlp/parser/lexparser/SisterAnnotationStats.class */
/**
 * Collects statistics on how a node's expansion (the sequence of its children's labels)
 * changes when conditioned on the label of one of its left or right sisters, and prints
 * candidate "sister split" annotations ranked by {@code support * KL(conditioned || unconditioned)}.
 *
 * <p>Usage: apply an instance to a treebank as a {@link TreeVisitor}, then call
 * {@link #printStats()}. The instance accumulates counts in unsynchronized maps.
 */
public class SisterAnnotationStats implements TreeVisitor {

    /** When false, preterminal (tag) nodes and their subtrees are skipped by {@link #recurse}. */
    public static final boolean DO_TAGS = true;

    /** Node label -> counts of that node's child-label sequences. */
    private final Map<String, ClassicCounter<List<String>>> nodeRules =
        new HashMap<String, ClassicCounter<List<String>>>();

    /** Node label -> left-sister label -> counts of the node's child-label sequences. */
    private final Map<String, Map<String, ClassicCounter<List<String>>>> leftRules =
        new HashMap<String, Map<String, ClassicCounter<List<String>>>>();

    /** Node label -> right-sister label -> counts of the node's child-label sequences. */
    private final Map<String, Map<String, ClassicCounter<List<String>>>> rightRules =
        new HashMap<String, Map<String, ClassicCounter<List<String>>>>();

    /**
     * Score thresholds ({@code support * KL}) for the generated {@code sisterSplitN} arrays;
     * {@code sisterSplit(i+1)} holds every annotation scoring at least {@code CUTOFFS[i]}.
     */
    public static final double[] CUTOFFS = {250.0d, 500.0d, 1000.0d, 1500.0d};

    /** Support threshold constant; not referenced within this class. */
    public static final double SUPPCUTOFF = 100.0d;

    /** Orders (annotation, score) pairs by descending score; used for all sorting here. */
    private static final Comparator<Pair<String, Double>> BY_SCORE_DESCENDING =
        new Comparator<Pair<String, Double>>() {
            @Override // java.util.Comparator
            public int compare(Pair<String, Double> a, Pair<String, Double> b) {
                return b.second().compareTo(a.second());
            }
        };

    /** Accumulates statistics for every node of {@code tree}. */
    @Override // edu.stanford.nlp.trees.TreeVisitor
    public void visitTree(Tree tree) {
        recurse(tree, null);
    }

    /**
     * Recursively counts sister statistics for {@code node} and its descendants.
     * Leaves are ignored; a node labelled {@code ROOT} or without a parent is not
     * itself counted, but its children are.
     *
     * @param node the subtree to process
     * @param parent the parent of {@code node}, or {@code null} at the top
     */
    public void recurse(Tree node, Tree parent) {
        if (node.isLeaf()) {
            return;
        }
        if (!DO_TAGS && node.isPreTerminal()) {
            return;
        }
        if (parent != null && !node.label().value().equals("ROOT")) {
            sisterCounters(node, parent);
        }
        for (Tree kid : node.children()) {
            recurse(kid, node);
        }
    }

    /**
     * Returns the labels of the children of {@code parent} that precede {@code tree},
     * nearest sister first. Returns an empty list when {@code parent} is null.
     *
     * <p>The position of {@code tree} is located by reference identity: {@code Tree.equals}
     * compares tree structure, so a structurally identical sister (e.g. in coordination)
     * would otherwise be mistaken for {@code tree} itself.
     */
    public static List<String> leftSisterLabels(Tree tree, Tree parent) {
        List<String> labels = new ArrayList<String>();
        if (parent == null) {
            return labels;
        }
        for (Tree kid : parent.children()) {
            if (kid == tree) {
                break;
            }
            labels.add(0, kid.label().value());
        }
        return labels;
    }

    /**
     * Returns the labels of the children of {@code parent} that follow {@code tree},
     * nearest sister last (scanned from the rightmost child inward). Returns an empty list
     * when {@code parent} is null. As in {@link #leftSisterLabels}, {@code tree} is located
     * by reference identity rather than structural equality.
     */
    public static List<String> rightSisterLabels(Tree tree, Tree parent) {
        List<String> labels = new ArrayList<String>();
        if (parent == null) {
            return labels;
        }
        Tree[] kids = parent.children();
        for (int i = kids.length - 1; i >= 0; i--) {
            if (kids[i] == tree) {
                break;
            }
            labels.add(kids[i].label().value());
        }
        return labels;
    }

    /** Returns the label values of the children of {@code tree}, in order. */
    public static List<String> kidLabels(Tree tree) {
        Tree[] kids = tree.children();
        List<String> labels = new ArrayList<String>(kids.length);
        for (Tree kid : kids) {
            labels.add(kid.label().value());
        }
        return labels;
    }

    /**
     * Records the expansion of {@code tree} unconditionally and once per left/right sister label.
     *
     * @param tree the node whose expansion is counted
     * @param parent the parent of {@code tree}, used to find its sisters
     */
    protected void sisterCounters(Tree tree, Tree parent) {
        List<String> kids = kidLabels(tree);
        String label = tree.label().value();
        ClassicCounter<List<String>> counter = nodeRules.get(label);
        if (counter == null) {
            counter = new ClassicCounter<List<String>>();
            nodeRules.put(label, counter);
        }
        counter.incrementCount(kids);
        sideCounters(label, kids, leftSisterLabels(tree, parent), leftRules);
        sideCounters(label, kids, rightSisterLabels(tree, parent), rightRules);
    }

    /**
     * Increments the count of {@code kids} under {@code sideRules[label][sister]} for every
     * label in {@code sisters}. Repeated sister labels are counted once per occurrence.
     * The inner map for {@code label} is created if absent, even when {@code sisters} is empty,
     * so every counted node has an entry.
     *
     * @param label the label of the node being counted
     * @param kids the node's child-label sequence
     * @param sisters the sister labels on one side of the node
     * @param sideRules the left- or right-sister statistics to update
     */
    protected void sideCounters(String label, List<String> kids, List<String> sisters,
                                Map<String, Map<String, ClassicCounter<List<String>>>> sideRules) {
        Map<String, ClassicCounter<List<String>>> bySister = sideRules.get(label);
        if (bySister == null) {
            bySister = new HashMap<String, ClassicCounter<List<String>>>();
            sideRules.put(label, bySister);
        }
        for (String sister : sisters) {
            ClassicCounter<List<String>> counter = bySister.get(sister);
            if (counter == null) {
                counter = new ClassicCounter<List<String>>();
                bySister.put(sister, counter);
            }
            counter.incrementCount(kids);
        }
    }

    /**
     * Prints the collected statistics to standard output. The output has four parts:
     * <ul>
     *   <li>per node, the KL divergence and support of each sister-conditioned expansion;</li>
     *   <li>per node, its candidate annotations sorted by {@code support * KL};</li>
     *   <li>all candidate annotations sorted by {@code support * KL};</li>
     *   <li>Java source for the {@code sisterSplitN} arrays and a selector expression.</li>
     * </ul>
     */
    public void printStats() {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(2);
        List<Pair<String, Double>> allScores = new ArrayList<Pair<String, Double>>();
        for (Map.Entry<String, ClassicCounter<List<String>>> entry : nodeRules.entrySet()) {
            String node = entry.getKey();
            ClassicCounter<List<String>> nodeCounter = entry.getValue();
            System.out.println("Node " + node + " support is " + nodeCounter.totalCount());
            List<Pair<String, Double>> nodeScores = new ArrayList<Pair<String, Double>>();
            scoreSisters(node, "=l=", leftRules.get(node), nodeCounter, nf, nodeScores);
            scoreSisters(node, "=r=", rightRules.get(node), nodeCounter, nf, nodeScores);
            // Add before sorting so the global list keeps insertion order for equal scores
            // (Collections.sort is stable).
            allScores.addAll(nodeScores);
            System.out.println("----");
            System.out.println("Sorted descending support * KL");
            Collections.sort(nodeScores, BY_SCORE_DESCENDING);
            printScores(nodeScores, nf);
            System.out.println();
        }
        Collections.sort(allScores, BY_SCORE_DESCENDING);
        printScores(allScores, nf);
        System.out.println();
        printJavaCode(allScores);
    }

    /**
     * Scores each sister-conditioned expansion of {@code node} against its unconditioned
     * expansion, prints the KL divergence and support, and appends
     * {@code (node + direction + sister, KL * support)} to {@code scores}.
     * Does nothing if {@code bySister} is null.
     */
    private static void scoreSisters(String node, String direction,
                                     Map<String, ClassicCounter<List<String>>> bySister,
                                     ClassicCounter<List<String>> nodeCounter, NumberFormat nf,
                                     List<Pair<String, Double>> scores) {
        if (bySister == null) {
            return;
        }
        for (Map.Entry<String, ClassicCounter<List<String>>> entry : bySister.entrySet()) {
            String sister = entry.getKey();
            ClassicCounter<List<String>> sisterCounter = entry.getValue();
            double support = sisterCounter.totalCount();
            double kl = Counters.klDivergence(sisterCounter, nodeCounter);
            String annotated = node + direction + sister;
            System.out.println("KL(" + annotated + "||" + node + ") = " + nf.format(kl)
                + "\tsupport(" + sister + ") = " + support);
            scores.add(new Pair<String, Double>(annotated, Double.valueOf(kl * support)));
        }
    }

    /** Prints one {@code "annotation: score"} line per pair. */
    private static void printScores(List<Pair<String, Double>> scores, NumberFormat nf) {
        for (Pair<String, Double> pair : scores) {
            System.out.println(pair.first() + ": " + nf.format(pair.second().doubleValue()));
        }
    }

    /**
     * Prints Java declarations for the {@code sisterSplitN} arrays and an expression that
     * selects one of them.
     *
     * <p>Array {@code sisterSplit(i+1)} contains every annotation whose score is at least
     * {@code CUTOFFS[i]}. The arrays are cumulative so that whichever array the selector picks
     * includes all higher-scoring annotations. Entries keep the descending order of
     * {@code sortedScores}.
     */
    private static void printJavaCode(List<Pair<String, Double>> sortedScores) {
        System.out.println("  // Automatically generated by SisterAnnotationStats -- preferably don't edit");
        StringBuilder[] arrays = new StringBuilder[CUTOFFS.length];
        for (int i = 0; i < CUTOFFS.length; i++) {
            arrays[i] = new StringBuilder("  private static String[] sisterSplit" + (i + 1) + " = new String[] {");
        }
        for (Pair<String, Double> pair : sortedScores) {
            double score = pair.second().doubleValue();
            for (int i = 0; i < CUTOFFS.length; i++) {
                if (score >= CUTOFFS[i]) {
                    arrays[i].append('"').append(pair.first()).append("\",");
                }
            }
        }
        for (StringBuilder array : arrays) {
            // Drop only the trailing separator (keeping the closing quote); an empty array
            // still ends with "{", so "};" is simply appended.
            int last = array.length() - 1;
            if (array.charAt(last) == ',') {
                array.setLength(last);
            }
            array.append("};");
            System.out.println(array);
        }
        System.out.print("  public static String[] sisterSplit = ");
        for (int i = CUTOFFS.length; i > 1; i--) {
            System.out.print("selectiveSisterSplit" + i + " ? sisterSplit" + i + " : (");
        }
        if (CUTOFFS.length > 0) {
            System.out.print("sisterSplit1");
        }
        // Close exactly the parentheses opened above.
        for (int i = CUTOFFS.length; i > 1; i--) {
            System.out.print(")");
        }
        System.out.println(";");
    }

    /**
     * Loads a Penn-format treebank and prints its sister-annotation statistics.
     *
     * @param args {@code treebankPath [encoding]}; the encoding defaults to UTF-8
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println("Usage: SisterAnnotationStats treebankPath [encoding]");
            return;
        }
        String encoding = args.length > 1 ? args[1] : "UTF-8";
        SisterAnnotationStats stats = new SisterAnnotationStats();
        DiskTreebank treebank = new DiskTreebank(new TreeReaderFactory() {
            @Override // edu.stanford.nlp.trees.TreeReaderFactory
            public TreeReader newTreeReader(Reader reader) {
                return new PennTreeReader(reader, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer());
            }
        }, encoding);
        treebank.loadPath(args[0]);
        treebank.apply(stats);
        stats.printStats();
    }
}
