test_models.py

# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

import math
import unittest

from nltk.lm import (
    Vocabulary,
    MLE,
    Lidstone,
    Laplace,
    WittenBellInterpolated,
    KneserNeyInterpolated,
)
from nltk.lm.preprocessing import padded_everygrams


def _prepare_test_data(ngram_order):
    return (
        Vocabulary(["a", "b", "c", "d", "z", "<s>", "</s>"], unk_cutoff=1),
        [
            list(padded_everygrams(ngram_order, sent))
            for sent in (list("abcd"), list("egadbe"))
        ],
    )
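

# For orientation (a sketch, not something the tests execute): padded_everygrams
# pads each sentence with "<s>"/"</s>" markers and yields every ngram up to the
# requested order, so for order 2 and list("abcd") the training stream contains
# unigrams and bigrams such as
#   ("<s>",), ("<s>", "a"), ("a", "b"), ..., ("d", "</s>"), ("</s>",)
# Tokens missing from the vocabulary above (e.g. "e", "g") are treated as
# "<UNK>" by the models, since the vocabulary was built with unk_cutoff=1.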


class ParametrizeTestsMeta(type):
    """Metaclass for generating parametrized tests."""

    def __new__(cls, name, bases, dct):
        contexts = (
            ("a",),
            ("c",),
            ("<s>",),
            ("b",),
            ("<UNK>",),
            ("d",),
            ("e",),
            ("r",),
            ("w",),
        )
        for i, c in enumerate(contexts):
            dct["test_sumto1_{0}".format(i)] = cls.add_sum_to_1_test(c)
        scores = dct.get("score_tests", [])
        for i, (word, context, expected_score) in enumerate(scores):
            dct["test_score_{0}".format(i)] = cls.add_score_test(
                word, context, expected_score
            )
        return super().__new__(cls, name, bases, dct)

    @classmethod
    def add_score_test(cls, word, context, expected_score):
        message = "word='{word}', context={context}"

        def test_method(self):
            score = self.model.score(word, context)
            self.assertAlmostEqual(
                score, expected_score, msg=message.format(**locals()), places=4
            )

        return test_method

    @classmethod
    def add_sum_to_1_test(cls, context):
        def test(self):
            s = sum(self.model.score(w, context) for w in self.model.vocab)
            self.assertAlmostEqual(s, 1.0, msg="The context is {}".format(context))

        return test
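

# For each test class below, the metaclass turns every (word, context,
# expected_score) triple in score_tests into a test_score_<i> method, and every
# context in its contexts tuple into a test_sumto1_<i> method that checks the
# conditional distribution sums to 1.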


class MleBigramTests(unittest.TestCase, metaclass=ParametrizeTestsMeta):
    """Unit tests for MLE ngram model."""

    score_tests = [
        ("d", ["c"], 1),
        # Unseen ngrams should yield 0
        ("d", ["e"], 0),
        # Unigrams should also be 0
        ("z", None, 0),
        # N unigrams = 14
        # count('a') = 2
        ("a", None, 2.0 / 14),
        # "y" is out of vocabulary, so it falls back on count('<UNK>') = 3
        ("y", None, 3.0 / 14),
    ]

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = MLE(2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_logscore_zero_score(self):
        # logscore of unseen ngrams should be -inf
        logscore = self.model.logscore("d", ["e"])
        self.assertTrue(math.isinf(logscore))

    def test_entropy_perplexity_seen(self):
        # ngrams seen during training
        trained = [
            ("<s>", "a"),
            ("a", "b"),
            ("b", "<UNK>"),
            ("<UNK>", "a"),
            ("a", "d"),
            ("d", "</s>"),
        ]
        # Ngram = Log score
        # <s>, a = -1
        # a, b = -1
        # b, UNK = -1
        # UNK, a = -1.585
        # a, d = -1
        # d, </s> = -1
        # TOTAL logscores = -6.585
        # - AVG logscores = 1.0975
        H = 1.0975
        perplexity = 2.1398
        self.assertAlmostEqual(H, self.model.entropy(trained), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(trained), places=4)
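
    # Note: the perplexity figure follows from the entropy above; assuming the
    # usual relation perplexity = 2 ** entropy, we get 2 ** 1.0975 ~= 2.1398.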

    def test_entropy_perplexity_unseen(self):
        # In MLE, even one unseen ngram should make entropy and perplexity infinite
        untrained = [("<s>", "a"), ("a", "c"), ("c", "d"), ("d", "</s>")]
        self.assertTrue(math.isinf(self.model.entropy(untrained)))
        self.assertTrue(math.isinf(self.model.perplexity(untrained)))

    def test_entropy_perplexity_unigrams(self):
        # word = score, log score
        # <s> = 0.1429, -2.8074
        # a = 0.1429, -2.8074
        # c = 0.0714, -3.8073
        # UNK = 0.2143, -2.2224
        # d = 0.1429, -2.8074
        # c = 0.0714, -3.8073
        # </s> = 0.1429, -2.8074
        # TOTAL logscores = -21.0666
        # - AVG logscores = 3.0095
        H = 3.0095
        perplexity = 8.0529
        text = [("<s>",), ("a",), ("c",), ("-",), ("d",), ("c",), ("</s>",)]
        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


class MleTrigramTests(unittest.TestCase, metaclass=ParametrizeTestsMeta):
    """MLE trigram model tests"""

    score_tests = [
        # count(d | b, c) = 1
        # count(b, c) = 1
        ("d", ("b", "c"), 1),
        # count(d | c) = 1
        # count(c) = 1
        ("d", ["c"], 1),
        # total number of tokens is 18, of which "a" occurred 2 times
        ("a", None, 2.0 / 18),
        # in vocabulary but unseen
        ("z", None, 0),
        # out of vocabulary should use "UNK" score
        ("y", None, 3.0 / 18),
    ]

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = MLE(3, vocabulary=vocab)
        self.model.fit(training_text)


class LidstoneBigramTests(unittest.TestCase, metaclass=ParametrizeTestsMeta):
    """Unit tests for Lidstone class"""

    score_tests = [
        # count(d | c) = 1
        # *count(d | c) = 1.1
        # Count(w | c for w in vocab) = 1
        # *Count(w | c for w in vocab) = 1.8
        ("d", ["c"], 1.1 / 1.8),
        # Total unigrams: 14
        # Vocab size: 8
        # Denominator: 14 + 0.8 = 14.8
        # count("a") = 2
        # *count("a") = 2.1
        ("a", None, 2.1 / 14.8),
        # in vocabulary but unseen
        # count("z") = 0
        # *count("z") = 0.1
        ("z", None, 0.1 / 14.8),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        # *count("<UNK>") = 3.1
        ("y", None, 3.1 / 14.8),
    ]
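
    # A worked sketch of the expected values above, assuming the usual
    # Lidstone/add-gamma formula with gamma = 0.1 and vocabulary size 8:
    #
    #   score(w | ctx) = (count(ctx + w) + gamma) / (count(ctx) + gamma * len(vocab))
    #
    #   score("d" | "c") = (1 + 0.1) / (1 + 0.1 * 8) = 1.1 / 1.8
    #   score("a")       = (2 + 0.1) / (14 + 0.1 * 8) = 2.1 / 14.8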

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = Lidstone(0.1, 2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_gamma(self):
        self.assertEqual(0.1, self.model.gamma)

    def test_entropy_perplexity(self):
        text = [
            ("<s>", "a"),
            ("a", "c"),
            ("c", "<UNK>"),
            ("<UNK>", "d"),
            ("d", "c"),
            ("c", "</s>"),
        ]
        # Unlike MLE this should be able to handle completely novel ngrams
        # Ngram = score, log score
        # <s>, a = 0.3929, -1.3479
        # a, c = 0.0357, -4.8074
        # c, UNK = 0.0(5), -4.1699
        # UNK, d = 0.0263, -5.2479
        # d, c = 0.0357, -4.8074
        # c, </s> = 0.0(5), -4.1699
        # TOTAL logscore: -24.5504
        # - AVG logscore: 4.0917
        H = 4.0917
        perplexity = 17.0504
        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


class LidstoneTrigramTests(unittest.TestCase, metaclass=ParametrizeTestsMeta):
    score_tests = [
        # Logic behind this is the same as for the bigram model
        ("d", ["c"], 1.1 / 1.8),
        # a word that has not appeared after "c"
        ("e", ["c"], 0.1 / 1.8),
        # Now the same scores with trigram contexts
        ("d", ["b", "c"], 1.1 / 1.8),
        ("e", ["b", "c"], 0.1 / 1.8),
    ]

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = Lidstone(0.1, 3, vocabulary=vocab)
        self.model.fit(training_text)


class LaplaceBigramTests(unittest.TestCase, metaclass=ParametrizeTestsMeta):
    """Unit tests for Laplace class"""

    score_tests = [
        # basic sanity-check:
        # count(d | c) = 1
        # *count(d | c) = 2
        # Count(w | c for w in vocab) = 1
        # *Count(w | c for w in vocab) = 9
        ("d", ["c"], 2.0 / 9),
        # Total unigrams: 14
        # Vocab size: 8
        # Denominator: 14 + 8 = 22
        # count("a") = 2
        # *count("a") = 3
        ("a", None, 3.0 / 22),
        # in vocabulary but unseen
        # count("z") = 0
        # *count("z") = 1
        ("z", None, 1.0 / 22),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        # *count("<UNK>") = 4
        ("y", None, 4.0 / 22),
    ]
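
    # Laplace is Lidstone with gamma fixed at 1 (see test_gamma below), so the
    # same sketch applies:
    #
    #   score("d" | "c") = (1 + 1) / (1 + 1 * 8) = 2 / 9
    #   score("a")       = (2 + 1) / (14 + 1 * 8) = 3 / 22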

    def setUp(self):
        vocab, training_text = _prepare_test_data(2)
        self.model = Laplace(2, vocabulary=vocab)
        self.model.fit(training_text)

    def test_gamma(self):
        # Make sure the gamma is set to 1
        self.assertEqual(1, self.model.gamma)

    def test_entropy_perplexity(self):
        text = [
            ("<s>", "a"),
            ("a", "c"),
            ("c", "<UNK>"),
            ("<UNK>", "d"),
            ("d", "c"),
            ("c", "</s>"),
        ]
        # Unlike MLE this should be able to handle completely novel ngrams
        # Ngram = score, log score
        # <s>, a = 0.2, -2.3219
        # a, c = 0.1, -3.3219
        # c, UNK = 0.(1), -3.1699
        # UNK, d = 0.(09), -3.4594
        # d, c = 0.1, -3.3219
        # c, </s> = 0.(1), -3.1699
        # Total logscores: -18.7651
        # - AVG logscores: 3.1275
        H = 3.1275
        perplexity = 8.7393
        self.assertAlmostEqual(H, self.model.entropy(text), places=4)
        self.assertAlmostEqual(perplexity, self.model.perplexity(text), places=4)


class WittenBellInterpolatedTrigramTests(
    unittest.TestCase, metaclass=ParametrizeTestsMeta
):
    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = WittenBellInterpolated(3, vocabulary=vocab)
        self.model.fit(training_text)

    score_tests = [
        # For unigram scores by default revert to MLE
        # Total unigrams: 18
        # count('c'): 1
        ("c", None, 1.0 / 18),
        # in vocabulary but unseen
        # count("z") = 0
        ("z", None, 0.0 / 18),
        # out of vocabulary should use "UNK" score
        # count("<UNK>") = 3
        ("y", None, 3.0 / 18),
        # gamma(['b']) = 0.1111
        # mle.score('c', ['b']) = 0.5
        # (1 - gamma) * mle + gamma * mle('c') ~= 0.4445 + 0.1111 / 18
        ("c", ["b"], (1 - 0.1111) * 0.5 + 0.1111 * 1 / 18),
        # building on that, let's try 'a b c' as the trigram
        # gamma(['a', 'b']) = 0.0667
        # mle("c", ["a", "b"]) = 1
        ("c", ["a", "b"], (1 - 0.0667) + 0.0667 * ((1 - 0.1111) * 0.5 + 0.1111 / 18)),
        # The ngram 'z b c' was not seen, so we should simply revert to
        # the score of the ngram 'b c'. See issue #2332.
        ("c", ["z", "b"], ((1 - 0.1111) * 0.5 + 0.1111 / 18)),
    ]
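
    # Sketch of the interpolation structure the expected values above follow
    # (the gamma values are taken from the comments, not rederived here):
    #
    #   score(w | ctx) = (1 - gamma(ctx)) * MLE(w | ctx)
    #                    + gamma(ctx) * score(w | shortened ctx)
    #
    #   score("c" | "b") = (1 - 0.1111) * 0.5 + 0.1111 * (1 / 18)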


class KneserNeyInterpolatedTrigramTests(
    unittest.TestCase, metaclass=ParametrizeTestsMeta
):
    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = KneserNeyInterpolated(3, vocabulary=vocab)
        self.model.fit(training_text)

    score_tests = [
        # For unigram scores revert to uniform
        # Vocab size: 8
        # count('c'): 1
        ("c", None, 1.0 / 8),
        # in vocabulary but unseen, still uses uniform
        ("z", None, 1.0 / 8),
        # out of vocabulary should use "UNK" score, i.e. again uniform
        ("y", None, 1.0 / 8),
        # alpha = count('bc') - discount = 1 - 0.1 = 0.9
        # gamma(['b']) = discount * number of unique words that follow ['b'] = 0.1 * 2
        # normalizer = total number of bigrams with this context = 2
        # the final score is: (alpha + gamma * unigram_score("c")) / normalizer
        ("c", ["b"], (0.9 + 0.2 * (1 / 8)) / 2),
        # building on that, let's try 'a b c' as the trigram
        # alpha = count('abc') - discount = 1 - 0.1 = 0.9
        # gamma(['a', 'b']) = 0.1 * 1
        # normalizer = total number of trigrams with prefix "ab" = 1 => we can ignore it!
        ("c", ["a", "b"], 0.9 + 0.1 * ((0.9 + 0.2 * (1 / 8)) / 2)),
        # The ngram 'z b c' was not seen, so we should simply revert to
        # the score of the ngram 'b c'. See issue #2332.
        ("c", ["z", "b"], ((0.9 + 0.2 * (1 / 8)) / 2)),
    ]
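
    # Sketch of the interpolated Kneser-Ney recursion the values above follow,
    # with discount = 0.1 and alpha, gamma, normalizer as given in the comments:
    #
    #   score(w | ctx) = (max(count(ctx + w) - discount, 0)
    #                     + gamma(ctx) * score(w | shortened ctx)) / normalizer(ctx)
    #
    #   score("c" | "b") = (0.9 + 0.2 * (1 / 8)) / 2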


class NgramModelTextGenerationTests(unittest.TestCase):
    """Using MLE model, generate some text."""

    def setUp(self):
        vocab, training_text = _prepare_test_data(3)
        self.model = MLE(3, vocabulary=vocab)
        self.model.fit(training_text)

    def test_generate_one_no_context(self):
        self.assertEqual(self.model.generate(random_seed=3), "<UNK>")

    def test_generate_one_limiting_context(self):
        # We don't need random_seed for contexts with only one continuation
        self.assertEqual(self.model.generate(text_seed=["c"]), "d")
        self.assertEqual(self.model.generate(text_seed=["b", "c"]), "d")
        self.assertEqual(self.model.generate(text_seed=["a", "c"]), "d")

    def test_generate_one_varied_context(self):
        # When context doesn't limit our options enough, seed the random choice
        self.assertEqual(
            self.model.generate(text_seed=("a", "<s>"), random_seed=2), "a"
        )

    def test_generate_cycle(self):
        # Add a cycle to the model: bd -> b, db -> d
        more_training_text = [
            list(padded_everygrams(self.model.order, list("bdbdbd")))
        ]
        self.model.fit(more_training_text)
        # Test that we can escape the cycle
        self.assertEqual(
            self.model.generate(7, text_seed=("b", "d"), random_seed=5),
            ["b", "d", "b", "d", "b", "d", "</s>"],
        )

    def test_generate_with_text_seed(self):
        self.assertEqual(
            self.model.generate(5, text_seed=("<s>", "e"), random_seed=3),
            ["<UNK>", "a", "d", "b", "<UNK>"],
        )

    def test_generate_oov_text_seed(self):
        self.assertEqual(
            self.model.generate(text_seed=("aliens",), random_seed=3),
            self.model.generate(text_seed=("<UNK>",), random_seed=3),
        )

    def test_generate_None_text_seed(self):
        # should crash with type error when we try to look it up in vocabulary
        with self.assertRaises(TypeError):
            self.model.generate(text_seed=(None,))
        # This will work
        self.assertEqual(
            self.model.generate(text_seed=None, random_seed=3),
            self.model.generate(random_seed=3),
        )