test_counter.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Natural Language Toolkit: Language Model Unit Tests
  2. #
  3. # Copyright (C) 2001-2020 NLTK Project
  4. # Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
  5. # URL: <http://nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. import unittest
  8. from nltk import FreqDist
  9. from nltk.lm import NgramCounter
  10. from nltk.util import everygrams
  11. class NgramCounterTests(unittest.TestCase):
  12. """Tests for NgramCounter that only involve lookup, no modification."""
  13. @classmethod
  14. def setUpClass(cls):
  15. text = [list("abcd"), list("egdbe")]
  16. cls.trigram_counter = NgramCounter(
  17. (everygrams(sent, max_len=3) for sent in text)
  18. )
  19. cls.bigram_counter = NgramCounter(
  20. (everygrams(sent, max_len=2) for sent in text)
  21. )
  22. def test_N(self):
  23. self.assertEqual(self.bigram_counter.N(), 16)
  24. self.assertEqual(self.trigram_counter.N(), 21)
  25. def test_counter_len_changes_with_lookup(self):
  26. self.assertEqual(len(self.bigram_counter), 2)
  27. _ = self.bigram_counter[50]
  28. self.assertEqual(len(self.bigram_counter), 3)
  29. def test_ngram_order_access_unigrams(self):
  30. self.assertEqual(self.bigram_counter[1], self.bigram_counter.unigrams)
  31. def test_ngram_conditional_freqdist(self):
  32. expected_trigram_contexts = [
  33. ("a", "b"),
  34. ("b", "c"),
  35. ("e", "g"),
  36. ("g", "d"),
  37. ("d", "b"),
  38. ]
  39. expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]
  40. bigrams = self.trigram_counter[2]
  41. trigrams = self.trigram_counter[3]
  42. self.assertCountEqual(expected_bigram_contexts, bigrams.conditions())
  43. self.assertCountEqual(expected_trigram_contexts, trigrams.conditions())
  44. def test_bigram_counts_seen_ngrams(self):
  45. b_given_a_count = 1
  46. unk_given_b_count = 1
  47. self.assertEqual(b_given_a_count, self.bigram_counter[["a"]]["b"])
  48. self.assertEqual(unk_given_b_count, self.bigram_counter[["b"]]["c"])
  49. def test_bigram_counts_unseen_ngrams(self):
  50. z_given_b_count = 0
  51. self.assertEqual(z_given_b_count, self.bigram_counter[["b"]]["z"])
  52. def test_unigram_counts_seen_words(self):
  53. expected_count_b = 2
  54. self.assertEqual(expected_count_b, self.bigram_counter["b"])
  55. def test_unigram_counts_completely_unseen_words(self):
  56. unseen_count = 0
  57. self.assertEqual(unseen_count, self.bigram_counter["z"])
  58. class NgramCounterTrainingTests(unittest.TestCase):
  59. def setUp(self):
  60. self.counter = NgramCounter()
  61. def test_empty_string(self):
  62. test = NgramCounter("")
  63. self.assertNotIn(2, test)
  64. self.assertEqual(test[1], FreqDist())
  65. def test_empty_list(self):
  66. test = NgramCounter([])
  67. self.assertNotIn(2, test)
  68. self.assertEqual(test[1], FreqDist())
  69. def test_None(self):
  70. test = NgramCounter(None)
  71. self.assertNotIn(2, test)
  72. self.assertEqual(test[1], FreqDist())
  73. def test_train_on_unigrams(self):
  74. words = list("abcd")
  75. counter = NgramCounter([[(w,) for w in words]])
  76. self.assertFalse(bool(counter[3]))
  77. self.assertFalse(bool(counter[2]))
  78. self.assertCountEqual(words, counter[1].keys())
  79. def test_train_on_illegal_sentences(self):
  80. str_sent = ["Check", "this", "out", "!"]
  81. list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]
  82. with self.assertRaises(TypeError):
  83. NgramCounter([str_sent])
  84. with self.assertRaises(TypeError):
  85. NgramCounter([list_sent])
  86. def test_train_on_bigrams(self):
  87. bigram_sent = [("a", "b"), ("c", "d")]
  88. counter = NgramCounter([bigram_sent])
  89. self.assertFalse(bool(counter[3]))
  90. def test_train_on_mix(self):
  91. mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
  92. counter = NgramCounter([mixed_sent])
  93. unigrams = ["h"]
  94. bigram_contexts = [("a",), ("c",)]
  95. trigram_contexts = [("e", "f")]
  96. self.assertCountEqual(unigrams, counter[1].keys())
  97. self.assertCountEqual(bigram_contexts, counter[2].keys())
  98. self.assertCountEqual(trigram_contexts, counter[3].keys())