/*
 * Decompiled with CFR 0.152.
 */
package edu.northwestern.at.morphadorner.corpuslinguistics.postagger.transitionmatrix;

import edu.northwestern.at.utils.CompoundKey;
import edu.northwestern.at.utils.Formatters;
import edu.northwestern.at.utils.IsCloseableObject;
import edu.northwestern.at.utils.Map2D;
import edu.northwestern.at.utils.Map2DFactory;
import edu.northwestern.at.utils.Map3D;
import edu.northwestern.at.utils.Map3DFactory;
import edu.northwestern.at.utils.MapFactory;
import edu.northwestern.at.utils.StringUtils;
import edu.northwestern.at.utils.UnicodeReader;
import edu.northwestern.at.utils.logger.DummyLogger;
import edu.northwestern.at.utils.logger.Logger;
import edu.northwestern.at.utils.logger.UsesLogger;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.Writer;
import java.net.URL;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;

/**
 * Part-of-speech tag transition matrix.
 *
 * <p>Holds unigram, bigram, and trigram tag counts, the maximum-likelihood
 * probabilities derived from those counts, and the interpolation weights used
 * to combine them. Counts are accumulated through the {@code incrementCount}
 * overloads or read from a delimited text file via
 * {@code loadTransitionMatrix}. Probabilities and weights are recomputed
 * lazily the next time they are requested after any count has changed.</p>
 *
 * <p>No method is synchronized; an instance should not be mutated from
 * several threads concurrently without external locking.</p>
 */
public class TransitionMatrix
extends IsCloseableObject
implements UsesLogger {
    /** Debugging flag retained for compatibility; not read by this class. */
    protected static boolean debug = true;

    /** Count of each single tag. */
    protected Map<String, Integer> unigramCountMap = MapFactory.createNewMap();
    /** Count of each ordered tag pair (tag1, tag2). */
    protected Map2D<String, String, Integer> bigramCountMap = Map2DFactory.createNewMap2D();
    /** Count of each ordered tag triple (tag1, tag2, tag3). */
    protected Map3D<String, String, String, Integer> trigramCountMap = Map3DFactory.createNewMap3D();

    /** P(tag3) = c(tag3) / N, filled in by {@link #calculateProbabilities()}. */
    protected Map<String, Double> unigramProbMap = MapFactory.createNewMap();
    /** P(tag3 | tag2) = c(tag2, tag3) / c(tag2), filled in by {@link #calculateProbabilities()}. */
    protected Map2D<String, String, Double> bigramProbMap = Map2DFactory.createNewMap2D();
    /** P(tag3 | tag1, tag2) = c(tag1, tag2, tag3) / c(tag1, tag2), filled in by {@link #calculateProbabilities()}. */
    protected Map3D<String, String, String, Double> trigramProbMap = Map3DFactory.createNewMap3D();

    /** Sum of all increments, indexed by {@link #UNIGRAM}, {@link #BIGRAM}, {@link #TRIGRAM}. */
    protected int[] totalNGrams = new int[]{0, 0, 0};
    /** Number of distinct n-gram keys, indexed like {@link #totalNGrams}. */
    protected int[] uniqueNGrams = new int[]{0, 0, 0};
    /** Sum of all unigram increments. */
    protected int totalWords = 0;
    /** True when the probability maps and weights reflect the current counts. */
    protected boolean haveProbabilities = false;
    /** Interpolation weights {unigram, bigram}; null until computed. */
    protected double[] bigramWeights = null;
    /** Interpolation weights {unigram, bigram, trigram}; null until computed. */
    protected double[] trigramWeights = null;

    protected static final int UNIGRAM = 0;
    protected static final int BIGRAM = 1;
    protected static final int TRIGRAM = 2;

    /** Logger used by {@link #displayNGramCounts()}. */
    protected Logger logger = new DummyLogger();

    @Override
    public Logger getLogger() {
        return this.logger;
    }

    @Override
    public void setLogger(Logger logger) {
        this.logger = logger;
    }

    /**
     * Adds {@code increment} to the unigram count of {@code tag}.
     *
     * <p>Also adds {@code increment} to the total word count and marks
     * probabilities as stale.</p>
     *
     * @param tag       the tag
     * @param increment amount to add to the tag's count
     */
    public void incrementCount(String tag, int increment) {
        Integer count = this.unigramCountMap.get(tag);
        int newCount = increment;
        if (count != null) {
            newCount = count.intValue() + increment;
        } else {
            this.uniqueNGrams[UNIGRAM]++;
        }
        this.totalNGrams[UNIGRAM] += increment;
        this.unigramCountMap.put(tag, Integer.valueOf(newCount));
        this.totalWords += increment;
        this.haveProbabilities = false;
    }

    /**
     * Adds {@code increment} to the bigram count of (tag1, tag2) and marks
     * probabilities as stale.
     *
     * @param tag1      first tag
     * @param tag2      second tag
     * @param increment amount to add
     */
    public void incrementCount(String tag1, String tag2, int increment) {
        Integer count = this.bigramCountMap.get(tag1, tag2);
        int newCount = increment;
        if (count != null) {
            newCount = count.intValue() + increment;
        } else {
            this.uniqueNGrams[BIGRAM]++;
        }
        this.totalNGrams[BIGRAM] += increment;
        this.bigramCountMap.put(tag1, tag2, Integer.valueOf(newCount));
        this.haveProbabilities = false;
    }

    /**
     * Adds {@code increment} to the trigram count of (tag1, tag2, tag3) and
     * marks probabilities as stale.
     *
     * @param tag1      first tag
     * @param tag2      second tag
     * @param tag3      third tag
     * @param increment amount to add
     */
    public void incrementCount(String tag1, String tag2, String tag3, int increment) {
        Integer count = this.trigramCountMap.get(tag1, tag2, tag3);
        int newCount = increment;
        if (count != null) {
            newCount = count.intValue() + increment;
        } else {
            this.uniqueNGrams[TRIGRAM]++;
        }
        this.totalNGrams[TRIGRAM] += increment;
        this.trigramCountMap.put(tag1, tag2, tag3, Integer.valueOf(newCount));
        this.haveProbabilities = false;
    }

    /**
     * Divides two counts, returning 0.0 instead of failing when the
     * denominator is not positive.
     *
     * @param numerator   dividend
     * @param denominator divisor
     * @return {@code numerator / denominator}, or 0.0 if {@code denominator <= 0}
     */
    public double safelyDivideCount(int numerator, int denominator) {
        if (denominator > 0) {
            return (double) numerator / (double) denominator;
        }
        return 0.0;
    }

    /**
     * Divides two counts after subtracting one from each.
     *
     * @param numerator   dividend before subtracting one
     * @param denominator divisor before subtracting one
     * @return {@code (numerator - 1) / (denominator - 1)}, or 0.0 if
     *         {@code denominator <= 1}
     */
    public double safelyDivideSmoothedCount(int numerator, int denominator) {
        if (denominator > 1) {
            return (double) (numerator - 1) / (double) (denominator - 1);
        }
        return 0.0;
    }

    /**
     * Recomputes the interpolation weights and the probability maps from the
     * current counts.
     *
     * <p>Unigram and bigram probabilities are recorded only for n-grams that
     * occur as the suffix of some trigram with a positive count; other
     * n-grams report a probability of 0.0.</p>
     */
    public void calculateProbabilities() {
        this.computeBigramWeights();
        this.computeTrigramWeights();
        this.haveProbabilities = true;
    }

    /**
     * Fills the probability maps and derives the trigram interpolation
     * weights.
     *
     * <p>For every trigram (t1, t2, t3) with a positive count this method
     * stores three maximum-likelihood estimates: P(t3), P(t3 | t2) and
     * P(t3 | t1, t2). It then credits the trigram's count to whichever
     * estimate is largest. Ties favour the lower order: unigram beats bigram,
     * and bigram beats trigram. The three credited totals are normalized to
     * sum to one, unless all of them are zero.</p>
     */
    protected void computeTrigramWeights() {
        double lambda1 = 0.0;
        double lambda2 = 0.0;
        double lambda3 = 0.0;
        int nEntries = this.totalNGrams[UNIGRAM];
        Iterator<CompoundKey> iterator = this.trigramCountMap.iterator();
        while (iterator.hasNext()) {
            String[] tags = keyToTags(iterator.next());
            int trigramCount = this.getCount(tags[0], tags[1], tags[2]);
            if (trigramCount <= 0) {
                continue;
            }
            double unigramP = this.safelyDivideCount(this.getCount(tags[2]), nEntries);
            double bigramP = this.safelyDivideCount(this.getCount(tags[1], tags[2]), this.getCount(tags[1]));
            double trigramP = this.safelyDivideCount(trigramCount, this.getCount(tags[0], tags[1]));
            this.unigramProbMap.put(tags[2], Double.valueOf(unigramP));
            this.bigramProbMap.put(tags[1], tags[2], Double.valueOf(bigramP));
            this.trigramProbMap.put(tags[0], tags[1], tags[2], Double.valueOf(trigramP));
            double maxP = Math.max(Math.max(unigramP, bigramP), trigramP);
            if (maxP == unigramP) {
                lambda1 += trigramCount;
            } else if (maxP == bigramP) {
                lambda2 += trigramCount;
            } else {
                lambda3 += trigramCount;
            }
        }
        double sum = lambda1 + lambda2 + lambda3;
        if (sum > 0.0) {
            lambda1 /= sum;
            lambda2 /= sum;
            lambda3 /= sum;
        }
        this.trigramWeights = new double[]{lambda1, lambda2, lambda3};
    }

    /**
     * Sets the bigram interpolation weights to the fixed values
     * {0.03 (unigram), 0.97 (bigram)}.
     *
     * <p>A data-driven estimate used to be computed here but was always
     * overwritten by these constants. That unused computation has been
     * removed, and the resulting weights are unchanged.</p>
     */
    protected void computeBigramWeights() {
        this.bigramWeights = new double[]{0.03, 0.97};
    }

    /**
     * @param tag the tag
     * @return the unigram count of {@code tag}, or 0 if it has never been counted
     */
    public int getCount(String tag) {
        Integer count = this.unigramCountMap.get(tag);
        return count == null ? 0 : count.intValue();
    }

    /**
     * @param tag1 first tag
     * @param tag2 second tag
     * @return the bigram count of (tag1, tag2), or 0 if never counted
     */
    public int getCount(String tag1, String tag2) {
        Integer count = this.bigramCountMap.get(tag1, tag2);
        return count == null ? 0 : count.intValue();
    }

    /**
     * @param tag1 first tag
     * @param tag2 second tag
     * @param tag3 third tag
     * @return the trigram count of (tag1, tag2, tag3), or 0 if never counted
     */
    public int getCount(String tag1, String tag2, String tag3) {
        Integer count = this.trigramCountMap.get(tag1, tag2, tag3);
        return count == null ? 0 : count.intValue();
    }

    /**
     * Returns P(tag), recomputing probabilities first if counts changed.
     *
     * @param tag the tag
     * @return the stored unigram probability, or 0.0 if none was recorded
     */
    public double getProbability(String tag) {
        if (!this.haveProbabilities) {
            this.calculateProbabilities();
        }
        Double prob = this.unigramProbMap.get(tag);
        return prob == null ? 0.0 : prob.doubleValue();
    }

    /**
     * Returns P(tag2 | tag1), recomputing probabilities first if counts
     * changed.
     *
     * @param tag1 conditioning tag
     * @param tag2 predicted tag
     * @return the stored bigram probability, or 0.0 if none was recorded
     */
    public double getProbability(String tag1, String tag2) {
        if (!this.haveProbabilities) {
            this.calculateProbabilities();
        }
        Double prob = this.bigramProbMap.get(tag1, tag2);
        return prob == null ? 0.0 : prob.doubleValue();
    }

    /**
     * Returns P(tag3 | tag1, tag2), recomputing probabilities first if
     * counts changed.
     *
     * @param tag1 first conditioning tag
     * @param tag2 second conditioning tag
     * @param tag3 predicted tag
     * @return the stored trigram probability, or 0.0 if none was recorded
     */
    public double getProbability(String tag1, String tag2, String tag3) {
        if (!this.haveProbabilities) {
            this.calculateProbabilities();
        }
        Double prob = this.trigramProbMap.get(tag1, tag2, tag3);
        return prob == null ? 0.0 : prob.doubleValue();
    }

    /** @return the first-position keys of the trigram count map */
    public Set<String> rowKeySet() {
        return this.trigramCountMap.rowKeySet();
    }

    /** @return the second-position keys of the trigram count map */
    public Set<String> columnKeySet() {
        return this.trigramCountMap.columnKeySet();
    }

    /** @return the third-position keys of the trigram count map */
    public Set<String> sliceKeySet() {
        return this.trigramCountMap.sliceKeySet();
    }

    /** @return the sum of all unigram increments */
    public int getTotalWordCount() {
        return this.totalWords;
    }

    /**
     * Loads counts from a URL, optionally gzip-compressed.
     *
     * <p>The stream is closed when loading completes or fails.</p>
     *
     * @param url        location of the transition matrix file
     * @param compressed true if the content is gzip-compressed
     * @param encoding   character encoding of the file
     * @param delimChar  field delimiter
     * @throws IOException if the stream cannot be opened or read
     */
    public void loadTransitionMatrix(URL url, boolean compressed, String encoding, char delimChar) throws IOException {
        InputStream inputStream = url.openStream();
        Reader reader = null;
        try {
            InputStream source = compressed ? new GZIPInputStream(inputStream) : inputStream;
            reader = new UnicodeReader(source, encoding);
        } finally {
            // Wrapping failed (e.g. bad gzip header): release the raw stream.
            if (reader == null) {
                try {
                    inputStream.close();
                } catch (IOException ignored) {
                    // The exception already propagating is more informative.
                }
            }
        }
        this.loadTransitionMatrix(reader, delimChar);
    }

    /**
     * Loads counts from an uncompressed URL.
     *
     * @param url       location of the transition matrix file
     * @param encoding  character encoding of the file
     * @param delimChar field delimiter
     * @throws IOException if the stream cannot be opened or read
     */
    public void loadTransitionMatrix(URL url, String encoding, char delimChar) throws IOException {
        this.loadTransitionMatrix(url, false, encoding, delimChar);
    }

    /**
     * Replaces all counts with those read from {@code reader}, then
     * recomputes probabilities.
     *
     * <p>Each line is split on {@code delimChar}, which is matched literally
     * rather than as a regular expression. The number of fields selects the
     * kind of count:</p>
     * <ul>
     *   <li>2 fields: {@code tag, count}</li>
     *   <li>3 fields: {@code tag1, tag2, count}</li>
     *   <li>4 fields: {@code tag1, tag2, tag3, count}</li>
     * </ul>
     * <p>Lines with any other number of fields are ignored. The reader is
     * closed on return, whether loading succeeded or not.</p>
     *
     * @param reader    source of the matrix text
     * @param delimChar field delimiter
     * @throws IOException           if reading fails
     * @throws NumberFormatException if a count field is not an integer
     */
    public void loadTransitionMatrix(Reader reader, char delimChar) throws IOException {
        Pattern delimiter = Pattern.compile(Pattern.quote(String.valueOf(delimChar)));
        this.resetCounts();
        BufferedReader bufferedReader = new BufferedReader(reader);
        try {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String[] tokens = delimiter.split(line);
                switch (tokens.length) {
                    case 2:
                        this.incrementCount(tokens[0], Integer.parseInt(tokens[1]));
                        break;
                    case 3:
                        this.incrementCount(tokens[0], tokens[1], Integer.parseInt(tokens[2]));
                        break;
                    case 4:
                        this.incrementCount(tokens[0], tokens[1], tokens[2], Integer.parseInt(tokens[3]));
                        break;
                    default:
                        break;
                }
            }
        } finally {
            bufferedReader.close();
        }
        this.calculateProbabilities();
    }

    /**
     * Discards all counts, probabilities, and weights so that a load starts
     * from an empty matrix; otherwise the totals would disagree with the maps.
     */
    private void resetCounts() {
        this.unigramCountMap = MapFactory.createNewMap();
        this.bigramCountMap = Map2DFactory.createNewMap2D();
        this.trigramCountMap = Map3DFactory.createNewMap3D();
        this.unigramProbMap = MapFactory.createNewMap();
        this.bigramProbMap = Map2DFactory.createNewMap2D();
        this.trigramProbMap = Map3DFactory.createNewMap3D();
        Arrays.fill(this.totalNGrams, 0);
        Arrays.fill(this.uniqueNGrams, 0);
        this.totalWords = 0;
        this.bigramWeights = null;
        this.trigramWeights = null;
        this.haveProbabilities = false;
    }

    /** Logs the total and unique n-gram counts at debug level. */
    public void displayNGramCounts() {
        this.logger.logDebug("");
        this.logger.logDebug("Transition matrix total ngram counts");
        this.logger.logDebug("");
        this.logger.logDebug("   Unigram    Bigram   Trigram");
        this.logger.logDebug("");
        this.logger.logDebug(formatCountRow(this.totalNGrams));
        this.logger.logDebug("");
        this.logger.logDebug("Transition matrix unique ngram counts");
        this.logger.logDebug("");
        this.logger.logDebug("   Unigram    Bigram   Trigram");
        this.logger.logDebug("");
        this.logger.logDebug(formatCountRow(this.uniqueNGrams));
    }

    /** Formats each count comma-grouped and left-padded to 10 columns. */
    private static String formatCountRow(int[] counts) {
        StringBuilder row = new StringBuilder();
        for (int i = 0; i < counts.length; ++i) {
            row.append(StringUtils.lpad(Formatters.formatIntegerWithCommas(counts[i]), 10));
        }
        return row.toString();
    }

    /**
     * Logs the n-gram counts, then writes the matrix to a file (overwriting
     * it).
     *
     * @param transitionFileName output file path
     * @param encoding           character encoding for the output
     * @param delimChar          field delimiter
     * @throws IOException if the file cannot be created or written, or the
     *                     encoding is unsupported
     */
    public void saveTransitionMatrix(String transitionFileName, String encoding, char delimChar) throws IOException {
        this.displayNGramCounts();
        FileOutputStream fileOutputStream = new FileOutputStream(transitionFileName, false);
        Writer writer = null;
        try {
            writer = new OutputStreamWriter((OutputStream) fileOutputStream, encoding);
        } finally {
            // Writer construction failed (e.g. bad encoding): release the file.
            if (writer == null) {
                try {
                    fileOutputStream.close();
                } catch (IOException ignored) {
                    // The exception already propagating is more informative.
                }
            }
        }
        this.saveTransitionMatrix(writer, delimChar);
    }

    /**
     * Writes all positive counts in the format read by
     * {@link #loadTransitionMatrix(Reader, char)}, then closes the writer.
     *
     * <p>Rows come from the sorted unigram tags, columns from the sorted
     * second tags of the bigram map, and slices from the sorted third tags of
     * the trigram map. For each row the output is the unigram line, then for
     * each column that bigram line followed by its trigram lines. N-grams
     * whose leading tags are not among these key sets are not written.</p>
     *
     * @param writer    destination; closed on return
     * @param delimChar field delimiter
     * @throws IOException if writing fails
     */
    public void saveTransitionMatrix(Writer writer, char delimChar) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(writer);
        try {
            String[] rowTags = sortedTags(this.unigramCountMap.keySet());
            String[] columnTags = sortedTags(this.bigramCountMap.columnKeySet());
            String[] sliceTags = sortedTags(this.trigramCountMap.sliceKeySet());
            for (String rowTag : rowTags) {
                int count = this.getCount(rowTag);
                if (count > 0) {
                    writeCountLine(bufferedWriter, rowTag + delimChar, count);
                }
                for (String columnTag : columnTags) {
                    String bigramPrefix = rowTag + delimChar + columnTag + delimChar;
                    count = this.getCount(rowTag, columnTag);
                    if (count > 0) {
                        writeCountLine(bufferedWriter, bigramPrefix, count);
                    }
                    for (String sliceTag : sliceTags) {
                        count = this.getCount(rowTag, columnTag, sliceTag);
                        if (count > 0) {
                            writeCountLine(bufferedWriter, bigramPrefix + sliceTag + delimChar, count);
                        }
                    }
                }
            }
            bufferedWriter.flush();
        } finally {
            bufferedWriter.close();
        }
    }

    /** Returns the tags of {@code tagSet} as a naturally sorted array. */
    private static String[] sortedTags(Set<String> tagSet) {
        String[] tags = tagSet.toArray(new String[tagSet.size()]);
        Arrays.sort(tags);
        return tags;
    }

    /** Writes {@code prefix} followed by {@code count} and a line separator. */
    private static void writeCountLine(BufferedWriter out, String prefix, int count) throws IOException {
        out.write(prefix + count);
        out.newLine();
    }

    /** Converts a compound map key into its component tags via {@code toString()}. */
    private static String[] keyToTags(CompoundKey compoundKey) {
        Comparable[] keyValues = compoundKey.getKeyValues();
        String[] tags = new String[keyValues.length];
        for (int i = 0; i < keyValues.length; ++i) {
            tags[i] = keyValues[i].toString();
        }
        return tags;
    }

    /**
     * @return the bigram interpolation weights {unigram, bigram}, computing
     *         them first if necessary; the internal array is returned
     */
    public double[] getBigramWeights() {
        if (this.bigramWeights == null) {
            this.calculateProbabilities();
        }
        return this.bigramWeights;
    }

    /**
     * @return the trigram interpolation weights {unigram, bigram, trigram},
     *         computing them first if necessary; the internal array is returned
     */
    public double[] getTrigramWeights() {
        if (this.trigramWeights == null) {
            this.calculateProbabilities();
        }
        return this.trigramWeights;
    }
}

