diff --git a/requirements.txt b/requirements.txt index 58c53ba..6fb272a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pandas +pandas==0.16.2 scikit-learn statsmodels gensim @@ -7,3 +7,6 @@ pyth pymongo MySQL-python scipy +unidecode +multiprocess +nltk diff --git a/rosetta/tests/test_text.py b/rosetta/tests/test_text.py index d26d470..3bf9f90 100644 --- a/rosetta/tests/test_text.py +++ b/rosetta/tests/test_text.py @@ -323,16 +323,6 @@ def test_dirichlet_expectation(self): [-0.13470677, -13.32429878]]).T assert_allclose(result, benchmark, atol=1e-4) - def test_expElogbeta(self): - # Make sure equal to exponential of dirichlet_expectation when we - # pass in all ones - lda = self.choose_lda('lda') - lda._lambda_word_sums = pd.Series( - np.ones(lda.num_topics), index=lda.topics) - result = lda._expElogbeta - benchmark = np.exp(lda._dirichlet_expectation(lda.pr_token_topic)) - assert_frame_equal(result, benchmark) - def test_predict_1(self): # Use fact that w0 <--> topic_0, w1 <--> topic_1 lda = self.choose_lda('lda_2') diff --git a/rosetta/text/vw_helpers.py b/rosetta/text/vw_helpers.py index dfcb857..93f8efc 100644 --- a/rosetta/text/vw_helpers.py +++ b/rosetta/text/vw_helpers.py @@ -332,6 +332,9 @@ def _set_probabilities(self, topics, predictions): self.pr_doc = doc_sums / doc_sums.sum() self.pr_doc_topic = predictions / predictions.sum().sum() + lam = self._lambda_word_sums * self.pr_token_topic + self._constExpElogbeta = np.exp(self._dirichlet_expectation(lam + EPS)) + def prob_token_topic(self, token=None, topic=None, c_token=None, c_topic=None): """ @@ -565,9 +568,7 @@ def _expElogbeta(self): topic-word weights. """ # Get lambda, the dirichlet parameter originally returned by VW. - lam = self._lambda_word_sums * self.pr_token_topic - - return np.exp(self._dirichlet_expectation(lam + EPS)) + return self._constExpElogbeta def _dirichlet_expectation(self, alpha): """