# test_stack_decoder.py
# -*- coding: utf-8 -*-
# Natural Language Toolkit: Stack decoder
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Tah Wei Hoon <hoon.tw@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""
Tests for stack decoder
"""
import unittest
from collections import defaultdict
from math import log

from nltk.translate import PhraseTable
from nltk.translate import StackDecoder
from nltk.translate.stack_decoder import _Hypothesis, _Stack
  17. class TestStackDecoder(unittest.TestCase):
  18. def test_find_all_src_phrases(self):
  19. # arrange
  20. phrase_table = TestStackDecoder.create_fake_phrase_table()
  21. stack_decoder = StackDecoder(phrase_table, None)
  22. sentence = ('my', 'hovercraft', 'is', 'full', 'of', 'eels')
  23. # act
  24. src_phrase_spans = stack_decoder.find_all_src_phrases(sentence)
  25. # assert
  26. self.assertEqual(src_phrase_spans[0], [2]) # 'my hovercraft'
  27. self.assertEqual(src_phrase_spans[1], [2]) # 'hovercraft'
  28. self.assertEqual(src_phrase_spans[2], [3]) # 'is'
  29. self.assertEqual(src_phrase_spans[3], [5, 6]) # 'full of', 'full of eels'
  30. self.assertFalse(src_phrase_spans[4]) # no entry starting with 'of'
  31. self.assertEqual(src_phrase_spans[5], [6]) # 'eels'
  32. def test_distortion_score(self):
  33. # arrange
  34. stack_decoder = StackDecoder(None, None)
  35. stack_decoder.distortion_factor = 0.5
  36. hypothesis = _Hypothesis()
  37. hypothesis.src_phrase_span = (3, 5)
  38. # act
  39. score = stack_decoder.distortion_score(hypothesis, (8, 10))
  40. # assert
  41. expected_score = log(stack_decoder.distortion_factor) * (8 - 5)
  42. self.assertEqual(score, expected_score)
  43. def test_distortion_score_of_first_expansion(self):
  44. # arrange
  45. stack_decoder = StackDecoder(None, None)
  46. stack_decoder.distortion_factor = 0.5
  47. hypothesis = _Hypothesis()
  48. # act
  49. score = stack_decoder.distortion_score(hypothesis, (8, 10))
  50. # assert
  51. # expansion from empty hypothesis always has zero distortion cost
  52. self.assertEqual(score, 0.0)
  53. def test_compute_future_costs(self):
  54. # arrange
  55. phrase_table = TestStackDecoder.create_fake_phrase_table()
  56. language_model = TestStackDecoder.create_fake_language_model()
  57. stack_decoder = StackDecoder(phrase_table, language_model)
  58. sentence = ('my', 'hovercraft', 'is', 'full', 'of', 'eels')
  59. # act
  60. future_scores = stack_decoder.compute_future_scores(sentence)
  61. # assert
  62. self.assertEqual(
  63. future_scores[1][2],
  64. (
  65. phrase_table.translations_for(('hovercraft',))[0].log_prob
  66. + language_model.probability(('hovercraft',))
  67. ),
  68. )
  69. self.assertEqual(
  70. future_scores[0][2],
  71. (
  72. phrase_table.translations_for(('my', 'hovercraft'))[0].log_prob
  73. + language_model.probability(('my', 'hovercraft'))
  74. ),
  75. )
  76. def test_compute_future_costs_for_phrases_not_in_phrase_table(self):
  77. # arrange
  78. phrase_table = TestStackDecoder.create_fake_phrase_table()
  79. language_model = TestStackDecoder.create_fake_language_model()
  80. stack_decoder = StackDecoder(phrase_table, language_model)
  81. sentence = ('my', 'hovercraft', 'is', 'full', 'of', 'eels')
  82. # act
  83. future_scores = stack_decoder.compute_future_scores(sentence)
  84. # assert
  85. self.assertEqual(
  86. future_scores[1][3], # 'hovercraft is' is not in phrase table
  87. future_scores[1][2] + future_scores[2][3],
  88. ) # backoff
  89. def test_future_score(self):
  90. # arrange: sentence with 8 words; words 2, 3, 4 already translated
  91. hypothesis = _Hypothesis()
  92. hypothesis.untranslated_spans = lambda _: [(0, 2), (5, 8)] # mock
  93. future_score_table = defaultdict(lambda: defaultdict(float))
  94. future_score_table[0][2] = 0.4
  95. future_score_table[5][8] = 0.5
  96. stack_decoder = StackDecoder(None, None)
  97. # act
  98. future_score = stack_decoder.future_score(hypothesis, future_score_table, 8)
  99. # assert
  100. self.assertEqual(future_score, 0.4 + 0.5)
  101. def test_valid_phrases(self):
  102. # arrange
  103. hypothesis = _Hypothesis()
  104. # mock untranslated_spans method
  105. hypothesis.untranslated_spans = lambda _: [(0, 2), (3, 6)]
  106. all_phrases_from = [[1, 4], [2], [], [5], [5, 6, 7], [], [7]]
  107. # act
  108. phrase_spans = StackDecoder.valid_phrases(all_phrases_from, hypothesis)
  109. # assert
  110. self.assertEqual(phrase_spans, [(0, 1), (1, 2), (3, 5), (4, 5), (4, 6)])
  111. @staticmethod
  112. def create_fake_phrase_table():
  113. phrase_table = PhraseTable()
  114. phrase_table.add(('hovercraft',), ('',), 0.8)
  115. phrase_table.add(('my', 'hovercraft'), ('', ''), 0.7)
  116. phrase_table.add(('my', 'cheese'), ('', ''), 0.7)
  117. phrase_table.add(('is',), ('',), 0.8)
  118. phrase_table.add(('is',), ('',), 0.5)
  119. phrase_table.add(('full', 'of'), ('', ''), 0.01)
  120. phrase_table.add(('full', 'of', 'eels'), ('', '', ''), 0.5)
  121. phrase_table.add(('full', 'of', 'spam'), ('', ''), 0.5)
  122. phrase_table.add(('eels',), ('',), 0.5)
  123. phrase_table.add(('spam',), ('',), 0.5)
  124. return phrase_table
  125. @staticmethod
  126. def create_fake_language_model():
  127. # nltk.model should be used here once it is implemented
  128. language_prob = defaultdict(lambda: -999.0)
  129. language_prob[('my',)] = log(0.1)
  130. language_prob[('hovercraft',)] = log(0.1)
  131. language_prob[('is',)] = log(0.1)
  132. language_prob[('full',)] = log(0.1)
  133. language_prob[('of',)] = log(0.1)
  134. language_prob[('eels',)] = log(0.1)
  135. language_prob[('my', 'hovercraft')] = log(0.3)
  136. language_model = type(
  137. '', (object,), {'probability': lambda _, phrase: language_prob[phrase]}
  138. )()
  139. return language_model
  140. class TestHypothesis(unittest.TestCase):
  141. def setUp(self):
  142. root = _Hypothesis()
  143. child = _Hypothesis(
  144. raw_score=0.5,
  145. src_phrase_span=(3, 7),
  146. trg_phrase=('hello', 'world'),
  147. previous=root,
  148. )
  149. grandchild = _Hypothesis(
  150. raw_score=0.4,
  151. src_phrase_span=(1, 2),
  152. trg_phrase=('and', 'goodbye'),
  153. previous=child,
  154. )
  155. self.hypothesis_chain = grandchild
  156. def test_translation_so_far(self):
  157. # act
  158. translation = self.hypothesis_chain.translation_so_far()
  159. # assert
  160. self.assertEqual(translation, ['hello', 'world', 'and', 'goodbye'])
  161. def test_translation_so_far_for_empty_hypothesis(self):
  162. # arrange
  163. hypothesis = _Hypothesis()
  164. # act
  165. translation = hypothesis.translation_so_far()
  166. # assert
  167. self.assertEqual(translation, [])
  168. def test_total_translated_words(self):
  169. # act
  170. total_translated_words = self.hypothesis_chain.total_translated_words()
  171. # assert
  172. self.assertEqual(total_translated_words, 5)
  173. def test_translated_positions(self):
  174. # act
  175. translated_positions = self.hypothesis_chain.translated_positions()
  176. # assert
  177. translated_positions.sort()
  178. self.assertEqual(translated_positions, [1, 3, 4, 5, 6])
  179. def test_untranslated_spans(self):
  180. # act
  181. untranslated_spans = self.hypothesis_chain.untranslated_spans(10)
  182. # assert
  183. self.assertEqual(untranslated_spans, [(0, 1), (2, 3), (7, 10)])
  184. def test_untranslated_spans_for_empty_hypothesis(self):
  185. # arrange
  186. hypothesis = _Hypothesis()
  187. # act
  188. untranslated_spans = hypothesis.untranslated_spans(10)
  189. # assert
  190. self.assertEqual(untranslated_spans, [(0, 10)])
  191. class TestStack(unittest.TestCase):
  192. def test_push_bumps_off_worst_hypothesis_when_stack_is_full(self):
  193. # arrange
  194. stack = _Stack(3)
  195. poor_hypothesis = _Hypothesis(0.01)
  196. # act
  197. stack.push(_Hypothesis(0.2))
  198. stack.push(poor_hypothesis)
  199. stack.push(_Hypothesis(0.1))
  200. stack.push(_Hypothesis(0.3))
  201. # assert
  202. self.assertFalse(poor_hypothesis in stack)
  203. def test_push_removes_hypotheses_that_fall_below_beam_threshold(self):
  204. # arrange
  205. stack = _Stack(3, 0.5)
  206. poor_hypothesis = _Hypothesis(0.01)
  207. worse_hypothesis = _Hypothesis(0.009)
  208. # act
  209. stack.push(poor_hypothesis)
  210. stack.push(worse_hypothesis)
  211. stack.push(_Hypothesis(0.9)) # greatly superior hypothesis
  212. # assert
  213. self.assertFalse(poor_hypothesis in stack)
  214. self.assertFalse(worse_hypothesis in stack)
  215. def test_push_does_not_add_hypothesis_that_falls_below_beam_threshold(self):
  216. # arrange
  217. stack = _Stack(3, 0.5)
  218. poor_hypothesis = _Hypothesis(0.01)
  219. # act
  220. stack.push(_Hypothesis(0.9)) # greatly superior hypothesis
  221. stack.push(poor_hypothesis)
  222. # assert
  223. self.assertFalse(poor_hypothesis in stack)
  224. def test_best_returns_the_best_hypothesis(self):
  225. # arrange
  226. stack = _Stack(3)
  227. best_hypothesis = _Hypothesis(0.99)
  228. # act
  229. stack.push(_Hypothesis(0.0))
  230. stack.push(best_hypothesis)
  231. stack.push(_Hypothesis(0.5))
  232. # assert
  233. self.assertEqual(stack.best(), best_hypothesis)
  234. def test_best_returns_none_when_stack_is_empty(self):
  235. # arrange
  236. stack = _Stack(3)
  237. # assert
  238. self.assertEqual(stack.best(), None)