/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.WordEmbeddingCallable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Formatter;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Trains continuous word embeddings with the skip-gram method and negative sampling.
 *
 * <p>Word vectors live in {@link #weights} and context vectors in {@link #negativeWeights}.
 * Both are flattened row-major matrices with {@link #numWords} rows of {@link #stride} doubles.
 * The per-document training loop is delegated to {@code WordEmbeddingCallable}. This class
 * holds the shared model state, builds the frequency statistics and sampling table, and handles
 * nearest-neighbour queries and text output.
 */
public class WordEmbeddings {
    /** Sentinel value of {@code --output-context} meaning "do not write context vectors". */
    private static final String NO_OUTPUT = "NONE";
    /** How many nearest neighbours {@link #findClosest(double[])} prints at most. */
    private static final int NUM_NEIGHBORS_SHOWN = 10;

    static CommandOption.String inputFile = new CommandOption.String(WordEmbeddings.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String outputFile = new CommandOption.String(WordEmbeddings.class, "output", "FILENAME", true, "weights.txt", "The filename to write text-formatted word vectors.", null);
    static CommandOption.String outputContextFile = new CommandOption.String(WordEmbeddings.class, "output-context", "FILENAME", true, NO_OUTPUT, "The filename to write text-formatted context vectors.", null);
    static CommandOption.Boolean outputStatsLine = new CommandOption.Boolean(WordEmbeddings.class, "output-stats-prefix", "TRUE/FALSE", false, false, "Whether to include a line at the beginning of the output with the vocab size and vector dimension, for compatibility with other packages.", null);
    static CommandOption.Integer numDimensions = new CommandOption.Integer(WordEmbeddings.class, "num-dimensions", "INTEGER", true, 50, "The number of dimensions to fit.", null);
    static CommandOption.Integer windowSizeOption = new CommandOption.Integer(WordEmbeddings.class, "window-size", "INTEGER", true, 5, "The number of adjacent words to consider.", null);
    static CommandOption.Integer numThreads = new CommandOption.Integer(WordEmbeddings.class, "num-threads", "INTEGER", true, 1, "The number of threads for parallel training.", null);
    static CommandOption.Integer numIterationsOption = new CommandOption.Integer(WordEmbeddings.class, "num-iters", "INTEGER", true, 3, "The number of passes through the training data.", null);
    static CommandOption.Double samplingFactorOption = new CommandOption.Double(WordEmbeddings.class, "frequency-factor", "NUMBER", true, 1.0E-4, "Down-sample words that account for more than ~2.5x this proportion of the corpus.", null);
    static CommandOption.Integer numSamples = new CommandOption.Integer(WordEmbeddings.class, "num-samples", "INTEGER", true, 5, "The number of negative samples to use in training.", null);
    static CommandOption.String exampleWord = new CommandOption.String(WordEmbeddings.class, "example-word", "STRING", true, null, "If defined, periodically show the closest vectors to this word.", null);
    static CommandOption.String orderingOption = new CommandOption.String(WordEmbeddings.class, "ordering", "STRING", true, "linear", "\"linear\" reads documents in order, \"shuffled\" reads in random order, \"random\" selects documents at random and may repeat/drop documents", null);

    /** Document ordering: read documents in corpus order. */
    public static final int LINEAR_ORDERING = 0;
    /** Document ordering: read documents in a random permutation. */
    public static final int SHUFFLED_ORDERING = 1;
    /** Document ordering: draw documents at random (with possible repeats/drops). */
    public static final int RANDOM_ORDERING = 2;

    /** Vocabulary; row {@code i} of the matrices belongs to {@code vocabulary.lookupObject(i)}. */
    Alphabet vocabulary;
    /** Number of rows in the weight matrices (vocabulary size at construction time). */
    int numWords;
    /** Embedding dimensionality. */
    int numColumns;
    /** Word vectors, flattened row-major: {@code weights[word * stride + col]}. */
    double[] weights;
    /** Context ("negative") vectors, same layout as {@link #weights}. */
    double[] negativeWeights;
    /** Row stride of the flattened matrices; currently always equal to {@link #numColumns}. */
    int stride;
    /** Number of training passes performed by {@link #train}. */
    int numIterations;
    /** Corpus frequency of each word, filled by {@link #countWords}. */
    int[] wordCounts;
    /** Per-word keep probability for frequent-word down-sampling, filled by {@link #countWords}. */
    double[] retentionProbability;
    /** Cumulative count^0.75 mass over words sorted by {@link #sortedWords}. */
    double[] samplingDistribution;
    /** Lookup table of word ids drawn proportionally to count^0.75 (presumably for negative sampling). */
    int[] samplingTable;
    int samplingTableSize = 100000000;
    double samplingSum = 0.0;
    int totalWords = 0;
    /** Upper bound of the tabulated sigmoid domain. */
    double maxExpValue = 6.0;
    /** Lower bound of the tabulated sigmoid domain. */
    double minExpValue = -6.0;
    /** Sigmoid sampled at {@code sigmoidCacheSize + 1} evenly spaced points in [min, max]. */
    double[] sigmoidCache;
    int sigmoidCacheSize = 1000;
    int windowSize = 5;
    /** Reusable sorter array shared by {@link #countWords} and {@link #findClosest}. */
    IDSorter[] sortedWords = null;
    /** One of {@link #LINEAR_ORDERING}, {@link #SHUFFLED_ORDERING}, {@link #RANDOM_ORDERING}. */
    int orderingStrategy = LINEAR_ORDERING;
    private int minDocumentLength = 10;
    String queryWord = "the";
    Randoms random = new Randoms();

    /** @return the minimum document length setting (at least 1) */
    public int getMinDocumentLength() {
        return this.minDocumentLength;
    }

    /**
     * Sets the minimum document length.
     *
     * @param minDocumentLength new minimum, must be at least 1
     * @throws IllegalArgumentException if {@code minDocumentLength <= 0}
     */
    public void setMinDocumentLength(int minDocumentLength) {
        if (minDocumentLength <= 0) {
            throw new IllegalArgumentException("Minimum document length must be at least 1.");
        }
        this.minDocumentLength = minDocumentLength;
    }

    /** Sets the number of passes {@link #train} makes over the data. */
    public void setNumIterations(int i) {
        this.numIterations = i;
    }

    /** @return the word whose neighbours are printed after each iteration, or {@code null} */
    public String getQueryWord() {
        return this.queryWord;
    }

    /** Sets the word whose neighbours are printed after each iteration ({@code null} disables). */
    public void setQueryWord(String queryWord) {
        this.queryWord = queryWord;
    }

    /** Creates an empty model; fields must be populated before use. */
    public WordEmbeddings() {
    }

    /**
     * Creates a model for the given vocabulary.
     *
     * <p>Word vectors are initialised uniformly in (-0.5, 0.5) / numColumns. Context vectors
     * start at zero.
     *
     * @param a vocabulary; its current size fixes the number of matrix rows
     * @param numColumns embedding dimensionality
     * @param windowSize number of adjacent words considered as context
     */
    public WordEmbeddings(Alphabet a, int numColumns, int windowSize) {
        this.vocabulary = a;
        this.numWords = this.vocabulary.size();
        System.out.format("Vocab size: %d\n", this.numWords);
        this.numColumns = numColumns;
        this.stride = numColumns;
        this.weights = new double[this.numWords * this.stride];
        // Java zero-initialises arrays, so the context vectors already start at 0.0.
        this.negativeWeights = new double[this.numWords * this.stride];
        for (int word = 0; word < this.numWords; ++word) {
            for (int col = 0; col < numColumns; ++col) {
                this.weights[word * this.stride + col] = (this.random.nextDouble() - 0.5) / (double) numColumns;
            }
        }
        this.wordCounts = new int[this.numWords];
        this.samplingDistribution = new double[this.numWords];
        this.retentionProbability = new double[this.numWords];
        this.samplingTable = new int[this.samplingTableSize];
        this.windowSize = windowSize;
        // Fill every slot, including the last one (value == maxExpValue); it was previously left at 0.0.
        this.sigmoidCache = new double[this.sigmoidCacheSize + 1];
        for (int i = 0; i <= this.sigmoidCacheSize; ++i) {
            double value = (double) i / (double) this.sigmoidCacheSize * (this.maxExpValue - this.minExpValue) + this.minExpValue;
            this.sigmoidCache[i] = 1.0 / (1.0 + Math.exp(-value));
        }
    }

    /** Allocates one {@link IDSorter} per vocabulary entry for reuse in sorting. */
    public void initializeSortables() {
        this.sortedWords = new IDSorter[this.numWords];
        for (int word = 0; word < this.numWords; ++word) {
            this.sortedWords[word] = new IDSorter(word, 0.0);
        }
    }

    /**
     * Tallies word frequencies and derives the down-sampling and negative-sampling statistics.
     *
     * @param instances corpus whose data must be {@link FeatureSequence}s
     * @param samplingFactor frequency threshold for down-sampling frequent words
     */
    public void countWords(InstanceList instances, double samplingFactor) {
        for (Instance instance : instances) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            int length = tokens.getLength();
            for (int position = 0; position < length; ++position) {
                this.wordCounts[tokens.getIndexAtPosition(position)]++;
            }
            this.totalWords += length;
        }
        for (int word = 0; word < this.numWords; ++word) {
            double frequencyScore = (double) this.wordCounts[word] / (samplingFactor * (double) this.totalWords);
            // A zero or undefined score (e.g. empty corpus gives NaN) means "always keep".
            this.retentionProbability[word] = frequencyScore > 0.0
                    ? Math.min((Math.sqrt(frequencyScore) + 1.0) / frequencyScore, 1.0)
                    : 1.0;
        }
        this.buildSamplingTable();
        System.out.println("done counting: " + this.totalWords);
    }

    /**
     * Builds {@link #samplingTable} so that word ids appear in proportion to count^0.75.
     * Does nothing for an empty vocabulary.
     */
    private void buildSamplingTable() {
        if (this.numWords == 0) {
            return;
        }
        if (this.sortedWords == null) {
            this.initializeSortables();
        }
        for (int word = 0; word < this.numWords; ++word) {
            this.sortedWords[word].set(word, this.wordCounts[word]);
        }
        Arrays.sort(this.sortedWords);
        double cumulative = 0.0;
        for (int rank = 0; rank < this.numWords; ++rank) {
            cumulative += Math.pow(this.wordCounts[this.sortedWords[rank].getID()], 0.75);
            this.samplingDistribution[rank] = cumulative;
        }
        this.samplingSum = cumulative;
        int order = 0;
        for (int i = 0; i < this.samplingTableSize; ++i) {
            this.samplingTable[i] = this.sortedWords[order].getID();
            // The bound keeps 'order' a valid index even if rounding pushes the threshold past the last bucket.
            while (order < this.numWords - 1
                    && this.samplingSum * (double) i / (double) this.samplingTableSize > this.samplingDistribution[order]) {
                ++order;
            }
        }
    }

    /**
     * Runs {@link #numIterations} training passes, each one split over {@code numThreads}
     * {@code WordEmbeddingCallable}s.
     *
     * <p>After each pass, prints the word count reported by the callables, the elapsed
     * milliseconds, a words/ms rate and the mean absolute weight. If the query word is in the
     * vocabulary, its nearest neighbours are printed as well. If the thread is interrupted,
     * training stops early and the interrupt flag is restored.
     *
     * @param instances training corpus
     * @param numThreads number of worker threads (must be positive)
     * @param numSamples number of negative samples per training pair
     */
    public void train(InstanceList instances, int numThreads, int numSamples) {
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        try {
            WordEmbeddingCallable[] callables = new WordEmbeddingCallable[numThreads];
            for (int thread = 0; thread < numThreads; ++thread) {
                callables[thread] = new WordEmbeddingCallable(this, instances, numSamples, numThreads, thread);
                callables[thread].setOrdering(this.orderingStrategy);
            }
            long startTime = System.currentTimeMillis();
            for (int iteration = 0; iteration < this.numIterations; ++iteration) {
                long wordsSoFar = 0L;
                try {
                    List<? extends Future<?>> futures = executor.invokeAll(Arrays.asList(callables));
                    for (Future<?> future : futures) {
                        wordsSoFar += ((Long) future.get()).longValue();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    System.err.println("Training interrupted during iteration " + iteration + "; stopping.");
                    break;
                } catch (ExecutionException e) {
                    // Best effort, as before: report the failed worker and carry on with the next pass.
                    e.printStackTrace();
                }
                long runningMillis = System.currentTimeMillis() - startTime;
                // NOTE(review): runningMillis is cumulative across passes; whether wordsSoFar is also
                // cumulative depends on WordEmbeddingCallable - confirm the rate's meaning there.
                System.out.format("%d\t%d\t%fk w/s %.3f avg\n", wordsSoFar, runningMillis,
                        (double) wordsSoFar / (double) Math.max(1L, runningMillis), this.averageAbsWeight());
                if (this.queryWord != null && this.vocabulary.contains(this.queryWord)) {
                    this.findClosest(this.copy(this.queryWord));
                }
            }
        } finally {
            executor.shutdownNow();
        }
    }

    /**
     * Scores every word by cosine similarity to {@code targetVector}, sorts the scores, and prints
     * the first (at most ten) entries of the sorted order.
     *
     * <p>NOTE(review): this assumes IDSorter's natural order is descending by weight, so the
     * printed rows are the closest words. Confirm against IDSorter.
     *
     * @param targetVector query vector of at least {@link #numColumns} entries
     */
    public void findClosest(double[] targetVector) {
        if (this.sortedWords == null) {
            this.initializeSortables();
        }
        double targetSquaredSum = 0.0;
        for (int col = 0; col < this.numColumns; ++col) {
            targetSquaredSum += targetVector[col] * targetVector[col];
        }
        double targetNormalizer = 1.0 / Math.sqrt(targetSquaredSum);
        for (int word = 0; word < this.numWords; ++word) {
            int rowStart = word * this.stride;
            double innerProduct = 0.0;
            double wordSquaredSum = 0.0;
            for (int col = 0; col < this.numColumns; ++col) {
                double w = this.weights[rowStart + col];
                wordSquaredSum += w * w;
                innerProduct += targetVector[col] * w;
            }
            double wordNormalizer = 1.0 / Math.sqrt(wordSquaredSum);
            this.sortedWords[word].set(word, innerProduct * (targetNormalizer * wordNormalizer));
        }
        Arrays.sort(this.sortedWords);
        // Guard small vocabularies: previously this always read ten entries.
        int shown = Math.min(NUM_NEIGHBORS_SHOWN, this.numWords);
        for (int i = 0; i < shown; ++i) {
            System.out.format("%f\t%d\t%s\n", this.sortedWords[i].getWeight(), this.sortedWords[i].getID(), this.vocabulary.lookupObject(this.sortedWords[i].getID()));
        }
    }

    /** @return mean absolute value over all word-vector entries (NaN for an empty model) */
    public double averageAbsWeight() {
        double sum = 0.0;
        for (int word = 0; word < this.numWords; ++word) {
            for (int col = 0; col < this.numColumns; ++col) {
                sum += Math.abs(this.weights[word * this.stride + col]);
            }
        }
        return sum / (double) (this.numWords * this.numColumns);
    }

    /**
     * Computes the sample variance (divisor {@code numWords - 1}) of each word-vector column,
     * prints the variances tab-separated on one line, and returns them.
     *
     * @return per-column variances
     */
    public double[] variances() {
        double[] means = new double[this.numColumns];
        for (int word = 0; word < this.numWords; ++word) {
            for (int col = 0; col < this.numColumns; ++col) {
                means[col] += this.weights[word * this.stride + col];
            }
        }
        for (int col = 0; col < this.numColumns; ++col) {
            means[col] /= this.numWords;
        }
        double[] squaredSums = new double[this.numColumns];
        for (int word = 0; word < this.numWords; ++word) {
            for (int col = 0; col < this.numColumns; ++col) {
                double diff = this.weights[word * this.stride + col] - means[col];
                squaredSums[col] += diff * diff;
            }
        }
        for (int col = 0; col < this.numColumns; ++col) {
            squaredSums[col] /= (this.numWords - 1);
            System.out.format("%f\t", squaredSums[col]);
        }
        System.out.println();
        return squaredSums;
    }

    /** Writes one line per word: the word followed by its vector, six decimals, US locale. */
    public void write(PrintWriter out) {
        this.writeVectors(out, this.weights);
    }

    /** Writes one line per word: the word followed by its context vector, same format as {@link #write}. */
    public void writeContext(PrintWriter out) {
        this.writeVectors(out, this.negativeWeights);
    }

    /** Shared text serialiser for a flattened numWords x stride matrix. */
    private void writeVectors(PrintWriter out, double[] matrix) {
        StringBuilder line = new StringBuilder();
        try (Formatter formatter = new Formatter(line, Locale.US)) {
            for (int word = 0; word < this.numWords; ++word) {
                line.setLength(0);
                formatter.format("%s", this.vocabulary.lookupObject(word));
                for (int col = 0; col < this.numColumns; ++col) {
                    formatter.format(" %.6f", matrix[word * this.stride + col]);
                }
                out.println(line);
            }
        }
    }

    /**
     * Resolves a word to its matrix row without growing the vocabulary.
     *
     * @throws IllegalArgumentException if the word has no row in the matrices
     */
    private int indexOf(String word) {
        int index = this.vocabulary.lookupIndex(word, false);
        if (index < 0 || index >= this.numWords) {
            throw new IllegalArgumentException("Word not in embedding vocabulary: " + word);
        }
        return index;
    }

    /**
     * @return a fresh copy of the vector for {@code word}
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] copy(String word) {
        return this.copy(this.indexOf(word));
    }

    /** @return a fresh copy of the vector in row {@code word} */
    public double[] copy(int word) {
        double[] result = new double[this.numColumns];
        System.arraycopy(this.weights, word * this.stride, result, 0, this.numColumns);
        return result;
    }

    /**
     * Adds the vector for {@code word} to {@code result} in place.
     *
     * @return {@code result}
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] add(double[] result, String word) {
        return this.add(result, this.indexOf(word));
    }

    /** Adds row {@code word} to {@code result} in place and returns {@code result}. */
    public double[] add(double[] result, int word) {
        for (int col = 0; col < this.numColumns; ++col) {
            result[col] += this.weights[word * this.stride + col];
        }
        return result;
    }

    /**
     * Subtracts the vector for {@code word} from {@code result} in place.
     *
     * @return {@code result}
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] subtract(double[] result, String word) {
        return this.subtract(result, this.indexOf(word));
    }

    /** Subtracts row {@code word} from {@code result} in place and returns {@code result}. */
    public double[] subtract(double[] result, int word) {
        for (int col = 0; col < this.numColumns; ++col) {
            result[col] -= this.weights[word * this.stride + col];
        }
        return result;
    }

    /** Maps the {@code --ordering} option to an ordering constant (linear when null/unknown). */
    private static int parseOrdering(String value) {
        if (value == null) {
            return LINEAR_ORDERING;
        }
        if (value.startsWith("s")) {
            return SHUFFLED_ORDERING;
        }
        if (value.startsWith("l")) {
            return LINEAR_ORDERING;
        }
        if (value.startsWith("r")) {
            return RANDOM_ORDERING;
        }
        System.err.println("Unrecognized ordering: " + value + ", using linear.");
        return LINEAR_ORDERING;
    }

    /** Writes the optional "vocabSize dimension" header line. */
    private static void writeStatsLine(PrintWriter out, WordEmbeddings matrix) {
        if (WordEmbeddings.outputStatsLine.value) {
            out.write(matrix.numWords + " " + matrix.numColumns + "\n");
        }
    }

    /** Command-line entry point: load instances, train, and write word (and optionally context) vectors. */
    public static void main(String[] args) throws Exception {
        CommandOption.setSummary(WordEmbeddings.class, "Train continuous word embeddings using the skip-gram method with negative sampling.");
        CommandOption.process(WordEmbeddings.class, args);
        InstanceList instances = InstanceList.load(new File(WordEmbeddings.inputFile.value));
        WordEmbeddings matrix = new WordEmbeddings(instances.getDataAlphabet(), WordEmbeddings.numDimensions.value, WordEmbeddings.windowSizeOption.value);
        matrix.setQueryWord(WordEmbeddings.exampleWord.value);
        matrix.setNumIterations(WordEmbeddings.numIterationsOption.value);
        matrix.countWords(instances, WordEmbeddings.samplingFactorOption.value);
        matrix.orderingStrategy = parseOrdering(WordEmbeddings.orderingOption.value);
        matrix.train(instances, WordEmbeddings.numThreads.value, WordEmbeddings.numSamples.value);
        try (PrintWriter out = new PrintWriter(WordEmbeddings.outputFile.value, Charset.defaultCharset().name())) {
            writeStatsLine(out, matrix);
            matrix.write(out);
        }
        String contextPath = WordEmbeddings.outputContextFile.value;
        // The option defaults to "NONE"; only a real path should produce a context-vector file.
        if (contextPath != null && !NO_OUTPUT.equals(contextPath)) {
            try (PrintWriter out = new PrintWriter(contextPath, Charset.defaultCharset().name())) {
                writeStatsLine(out, matrix);
                matrix.writeContext(out);
            }
        }
    }
}

