/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.FactorTable;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CRFCliqueTree<E>
implements ListeningSequenceModel {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFCliqueTree.class);
    private final FactorTable[] factorTables;
    private final double z;
    private final Index<E> classIndex;
    private final E backgroundSymbol;
    private final int backgroundIndex;
    private final int windowSize;
    private final int numClasses;
    private final int[] possibleValues;

    public CRFCliqueTree(FactorTable[] factorTables, Index<E> classIndex, E backgroundSymbol) {
        this(factorTables, classIndex, backgroundSymbol, factorTables[0].totalMass());
    }

    CRFCliqueTree(FactorTable[] factorTables, Index<E> classIndex, E backgroundSymbol, double z) {
        this.factorTables = factorTables;
        this.z = z;
        this.classIndex = classIndex;
        this.backgroundSymbol = backgroundSymbol;
        this.backgroundIndex = classIndex.indexOf(backgroundSymbol);
        this.windowSize = factorTables[0].windowSize();
        this.numClasses = classIndex.size();
        this.possibleValues = new int[this.numClasses];
        for (int i = 0; i < this.numClasses; ++i) {
            this.possibleValues[i] = i;
        }
    }

    public FactorTable[] getFactorTables() {
        return this.factorTables;
    }

    public Index<E> classIndex() {
        return this.classIndex;
    }

    @Override
    public int length() {
        return this.factorTables.length;
    }

    @Override
    public int leftWindow() {
        return this.windowSize;
    }

    @Override
    public int rightWindow() {
        return 0;
    }

    @Override
    public int[] getPossibleValues(int position) {
        return this.possibleValues;
    }

    @Override
    public double scoreOf(int[] sequence, int pos) {
        return this.scoresOf(sequence, pos)[sequence[pos]];
    }

    @Override
    public double[] scoresOf(int[] sequence, int position) {
        int i;
        if (position >= this.factorTables.length) {
            throw new RuntimeException("Index out of bounds: " + position);
        }
        double[] probThisGivenPrev = new double[this.numClasses];
        double[] probNextGivenThis = new double[this.numClasses];
        int prevLength = this.windowSize - 1;
        int[] prev = new int[prevLength + 1];
        for (i = 0; i < prevLength - position; ++i) {
            prev[i] = this.classIndex.indexOf(this.backgroundSymbol);
        }
        while (i < prevLength) {
            prev[i] = sequence[position - prevLength + i];
            ++i;
        }
        for (int label = 0; label < this.numClasses; ++label) {
            prev[prev.length - 1] = label;
            probThisGivenPrev[label] = this.factorTables[position].unnormalizedLogProb(prev);
        }
        int nextLength = this.windowSize - 1;
        if (position + nextLength >= this.length()) {
            nextLength = this.length() - position - 1;
        }
        FactorTable nextFactorTable = this.factorTables[position + nextLength];
        if (nextLength != this.windowSize - 1) {
            for (int j = 0; j < this.windowSize - 1 - nextLength; ++j) {
                nextFactorTable = nextFactorTable.sumOutFront();
            }
        }
        if (nextLength == 0) {
            Arrays.fill(probNextGivenThis, 1.0);
        } else {
            int[] next = new int[nextLength];
            System.arraycopy(sequence, position + 1, next, 0, nextLength);
            for (int label = 0; label < this.numClasses; ++label) {
                probNextGivenThis[label] = nextFactorTable.unnormalizedConditionalLogProbGivenFirst(label, next);
            }
        }
        return ArrayMath.pairwiseAdd(probThisGivenPrev, probNextGivenThis);
    }

    @Override
    public double scoreOf(int[] sequence) {
        int[] given = new int[this.window() - 1];
        Arrays.fill(given, this.classIndex.indexOf(this.backgroundSymbol));
        double logProb = 0.0;
        int length = this.length();
        for (int i = 0; i < length; ++i) {
            int label = sequence[i];
            logProb += this.condLogProbGivenPrevious(i, label, given);
            System.arraycopy(given, 1, given, 0, given.length - 1);
            given[given.length - 1] = label;
        }
        return logProb;
    }

    public int window() {
        return this.windowSize;
    }

    public int getNumClasses() {
        return this.numClasses;
    }

    public double totalMass() {
        return this.z;
    }

    public int backgroundIndex() {
        return this.backgroundIndex;
    }

    public E backgroundSymbol() {
        return this.backgroundSymbol;
    }

    public double[][] logProbTable() {
        double[][] result = new double[this.length()][this.classIndex.size()];
        for (int i = 0; i < this.length(); ++i) {
            result[i] = new double[this.classIndex.size()];
            for (int j = 0; j < this.classIndex.size(); ++j) {
                result[i][j] = this.logProb(i, j);
            }
        }
        return result;
    }

    public double logProbStartPos() {
        double u = this.factorTables[0].unnormalizedLogProbFront(this.backgroundIndex);
        return u - this.z;
    }

    public double logProb(int position, int label) {
        double u = this.factorTables[position].unnormalizedLogProbEnd(label);
        return u - this.z;
    }

    public double prob(int position, int label) {
        return Math.exp(this.logProb(position, label));
    }

    public double logProb(int position, E label) {
        return this.logProb(position, this.classIndex.indexOf(label));
    }

    public double prob(int position, E label) {
        return Math.exp(this.logProb(position, label));
    }

    public double[] probsToDoubleArr(int position) {
        double[] probs = new double[this.classIndex.size()];
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            probs[i] = this.prob(position, i);
        }
        return probs;
    }

    public double[] logProbsToDoubleArr(int position) {
        double[] probs = new double[this.classIndex.size()];
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            probs[i] = this.logProb(position, i);
        }
        return probs;
    }

    public Counter<E> probs(int position) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.prob(position, i));
        }
        return c;
    }

    public Counter<E> logProbs(int position) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.logProb(position, i));
        }
        return c;
    }

    public double logProb(int position, int[] labels) {
        if (labels.length < this.windowSize) {
            return this.factorTables[position].unnormalizedLogProbEnd(labels) - this.z;
        }
        if (labels.length == this.windowSize) {
            return this.factorTables[position].unnormalizedLogProb(labels) - this.z;
        }
        int[] l = new int[this.windowSize];
        System.arraycopy(labels, 0, l, 0, l.length);
        int position1 = position - labels.length + this.windowSize;
        double p = this.factorTables[position1].unnormalizedLogProb(l) - this.z;
        l = new int[this.windowSize - 1];
        System.arraycopy(labels, 1, l, 0, l.length);
        ++position1;
        for (int i = this.windowSize; i < labels.length; ++i) {
            p += this.condLogProbGivenPrevious(position1++, labels[i], l);
            System.arraycopy(l, 1, l, 0, l.length - 1);
            l[this.windowSize - 2] = labels[i];
        }
        return p;
    }

    public double prob(int position, int[] labels) {
        return Math.exp(this.logProb(position, labels));
    }

    public double logProb(int position, E[] labels) {
        return this.logProb(position, this.objectArrayToIntArray(labels));
    }

    public double prob(int position, E[] labels) {
        return Math.exp(this.logProb(position, labels));
    }

    /*
     * Unable to fully structure code
     */
    public GeneralizedCounter<E> logProbs(int position, int window) {
        gc = new GeneralizedCounter<E>(window);
        labels = new int[window];
        block0: while (true) {
            labelsList = this.intArrayToListE(labels);
            gc.incrementCount(labelsList, this.logProb(position, labels));
            i = 0;
            while (true) {
                if (i >= labels.length) continue block0;
                v0 = i;
                labels[v0] = labels[v0] + 1;
                if (labels[i] >= this.numClasses) ** break;
                continue block0;
                if (i == labels.length - 1) break block0;
                labels[i] = 0;
                ++i;
            }
            break;
        }
        return gc;
    }

    /*
     * Unable to fully structure code
     */
    public GeneralizedCounter<E> probs(int position, int window) {
        gc = new GeneralizedCounter<E>(window);
        labels = new int[window];
        block0: while (true) {
            labelsList = this.intArrayToListE(labels);
            gc.incrementCount(labelsList, this.prob(position, labels));
            i = 0;
            while (true) {
                if (i >= labels.length) continue block0;
                v0 = i;
                labels[v0] = labels[v0] + 1;
                if (labels[i] >= this.numClasses) ** break;
                continue block0;
                if (i == labels.length - 1) break block0;
                labels[i] = 0;
                ++i;
            }
            break;
        }
        return gc;
    }

    private int[] objectArrayToIntArray(E[] os) {
        int[] is = new int[os.length];
        for (int i = 0; i < os.length; ++i) {
            is[i] = this.classIndex.indexOf(os[i]);
        }
        return is;
    }

    private List<E> intArrayToListE(int[] is) {
        ArrayList<E> os = new ArrayList<E>(is.length);
        for (int i : is) {
            os.add(this.classIndex.get(i));
        }
        return os;
    }

    public double condLogProbGivenPrevious(int position, int label, int[] prevLabels) {
        if (prevLabels.length + 1 == this.windowSize) {
            return this.factorTables[position].conditionalLogProbGivenPrevious(prevLabels, label);
        }
        if (prevLabels.length + 1 < this.windowSize) {
            FactorTable ft = this.factorTables[position].sumOutFront();
            while (ft.windowSize() > prevLabels.length + 1) {
                ft = ft.sumOutFront();
            }
            return ft.conditionalLogProbGivenPrevious(prevLabels, label);
        }
        int[] p = new int[this.windowSize - 1];
        System.arraycopy(prevLabels, prevLabels.length - p.length, p, 0, p.length);
        return this.factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }

    public double condLogProbGivenPrevious(int position, E label, E[] prevLabels) {
        return this.condLogProbGivenPrevious(position, this.classIndex.indexOf(label), this.objectArrayToIntArray(prevLabels));
    }

    public double condProbGivenPrevious(int position, int label, int[] prevLabels) {
        return Math.exp(this.condLogProbGivenPrevious(position, label, prevLabels));
    }

    public double condProbGivenPrevious(int position, E label, E[] prevLabels) {
        return Math.exp(this.condLogProbGivenPrevious(position, label, prevLabels));
    }

    public Counter<E> condLogProbsGivenPrevious(int position, int[] prevlabels) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenPrevious(position, i, prevlabels));
        }
        return c;
    }

    public Counter<E> condLogProbsGivenPrevious(int position, E[] prevlabels) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenPrevious(position, label, prevlabels));
        }
        return c;
    }

    public double condLogProbGivenNext(int position, int label, int[] nextLabels) {
        position += nextLabels.length;
        if (nextLabels.length + 1 == this.windowSize) {
            return this.factorTables[position].conditionalLogProbGivenNext(nextLabels, label);
        }
        if (nextLabels.length + 1 < this.windowSize) {
            FactorTable ft = this.factorTables[position].sumOutFront();
            while (ft.windowSize() > nextLabels.length + 1) {
                ft = ft.sumOutFront();
            }
            return ft.conditionalLogProbGivenPrevious(nextLabels, label);
        }
        int[] p = new int[this.windowSize - 1];
        System.arraycopy(nextLabels, 0, p, 0, p.length);
        return this.factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }

    public double condLogProbGivenNext(int position, E label, E[] nextLabels) {
        return this.condLogProbGivenNext(position, this.classIndex.indexOf(label), this.objectArrayToIntArray(nextLabels));
    }

    public double condProbGivenNext(int position, int label, int[] nextLabels) {
        return Math.exp(this.condLogProbGivenNext(position, label, nextLabels));
    }

    public double condProbGivenNext(int position, E label, E[] nextLabels) {
        return Math.exp(this.condLogProbGivenNext(position, label, nextLabels));
    }

    public Counter<E> condLogProbsGivenNext(int position, int[] nextlabels) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenNext(position, i, nextlabels));
        }
        return c;
    }

    public Counter<E> condLogProbsGivenNext(int position, E[] nextlabels) {
        ClassicCounter<E> c = new ClassicCounter<E>();
        int sz = this.classIndex.size();
        for (int i = 0; i < sz; ++i) {
            E label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenNext(position, label, nextlabels));
        }
        return c;
    }

    public static <E> CRFCliqueTree<E> getCalibratedCliqueTree(int[][][] data, List<Index<CRFLabel>> labelIndices, int numClasses, Index<E> classIndex, E backgroundSymbol, CliquePotentialFunction cliquePotentialFunc, double[][][] featureVals) {
        int i;
        FactorTable[] factorTables = new FactorTable[data.length];
        FactorTable[] messages = new FactorTable[data.length - 1];
        for (i = 0; i < data.length; ++i) {
            double[][] featureValByCliqueSize = null;
            if (featureVals != null) {
                featureValByCliqueSize = featureVals[i];
            }
            factorTables[i] = CRFCliqueTree.getFactorTable(data[i], labelIndices, numClasses, cliquePotentialFunc, featureValByCliqueSize, i);
            if (i <= 0) continue;
            messages[i - 1] = factorTables[i - 1].sumOutFront();
            factorTables[i].multiplyInFront(messages[i - 1]);
        }
        for (i = factorTables.length - 2; i >= 0; --i) {
            FactorTable summedOut = factorTables[i + 1].sumOutEnd();
            summedOut.divideBy(messages[i]);
            factorTables[i].multiplyInEnd(summedOut);
        }
        return new CRFCliqueTree<E>(factorTables, classIndex, backgroundSymbol);
    }

    public static <E> CRFCliqueTree<E> getCalibratedCliqueTree(double[] weights, double wscale, int[][] weightIndices, int[][][] data, List<Index<CRFLabel>> labelIndices, int numClasses, Index<E> classIndex, E backgroundSymbol) {
        int i;
        FactorTable[] factorTables = new FactorTable[data.length];
        FactorTable[] messages = new FactorTable[data.length - 1];
        for (i = 0; i < data.length; ++i) {
            factorTables[i] = CRFCliqueTree.getFactorTable(weights, wscale, weightIndices, data[i], labelIndices, numClasses);
            if (i <= 0) continue;
            messages[i - 1] = factorTables[i - 1].sumOutFront();
            factorTables[i].multiplyInFront(messages[i - 1]);
        }
        for (i = factorTables.length - 2; i >= 0; --i) {
            FactorTable summedOut = factorTables[i + 1].sumOutEnd();
            summedOut.divideBy(messages[i]);
            factorTables[i].multiplyInEnd(summedOut);
        }
        return new CRFCliqueTree<E>(factorTables, classIndex, backgroundSymbol);
    }

    private static FactorTable getFactorTable(double[] weights, double wScale, int[][] weightIndices, int[][] data, List<Index<CRFLabel>> labelIndices, int numClasses) {
        FactorTable factorTable = null;
        int sz = labelIndices.size();
        for (int j = 0; j < sz; ++j) {
            Index<CRFLabel> labelIndex = labelIndices.get(j);
            FactorTable ft = new FactorTable(numClasses, j + 1);
            int liSize = labelIndex.size();
            for (int k = 0; k < liSize; ++k) {
                int[] label = labelIndex.get(k).getLabel();
                double weight = 0.0;
                for (int m = 0; m < data[j].length; ++m) {
                    int wi = weightIndices[data[j][m]][k];
                    weight += wScale * weights[wi];
                }
                ft.setValue(label, weight);
            }
            if (j > 0) {
                ft.multiplyInEnd(factorTable);
            }
            factorTable = ft;
        }
        return factorTable;
    }

    static FactorTable getFactorTable(int[][] data, List<Index<CRFLabel>> labelIndices, int numClasses, CliquePotentialFunction cliquePotentialFunc, double[][] featureValByCliqueSize, int posInSent) {
        FactorTable factorTable = null;
        int sz = labelIndices.size();
        for (int j = 0; j < sz; ++j) {
            Index<CRFLabel> labelIndex = labelIndices.get(j);
            FactorTable ft = new FactorTable(numClasses, j + 1);
            double[] featureVal = null;
            if (featureValByCliqueSize != null) {
                featureVal = featureValByCliqueSize[j];
            }
            int liSize = labelIndex.size();
            for (int k = 0; k < liSize; ++k) {
                int[] label = labelIndex.get(k).getLabel();
                double cliquePotential = cliquePotentialFunc.computeCliquePotential(j + 1, k, data[j], featureVal, posInSent);
                ft.setValue(label, cliquePotential);
            }
            if (j > 0) {
                ft.multiplyInEnd(factorTable);
            }
            factorTable = ft;
        }
        return factorTable;
    }

    public double[] getConditionalDistribution(int[] sequence, int position) {
        double[] result = this.scoresOf(sequence, position);
        ArrayMath.logNormalize(result);
        result = ArrayMath.exp(result);
        return result;
    }

    @Override
    public void updateSequenceElement(int[] sequence, int pos, int oldVal) {
    }

    @Override
    public void setInitialSequence(int[] sequence) {
    }

    public int getNumValues() {
        return this.numClasses;
    }
}

