test_ibm5.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests for IBM Model 5 training methods
  4. """
  5. import unittest
  6. from collections import defaultdict
  7. from nltk.translate import AlignedSent
  8. from nltk.translate import IBMModel
  9. from nltk.translate import IBMModel4
  10. from nltk.translate import IBMModel5
  11. from nltk.translate.ibm_model import AlignmentInfo
  12. class TestIBMModel5(unittest.TestCase):
  13. def test_set_uniform_vacancy_probabilities_of_max_displacements(self):
  14. # arrange
  15. src_classes = {'schinken': 0, 'eier': 0, 'spam': 1}
  16. trg_classes = {'ham': 0, 'eggs': 1, 'spam': 2}
  17. corpus = [
  18. AlignedSent(['ham', 'eggs'], ['schinken', 'schinken', 'eier']),
  19. AlignedSent(['spam', 'spam', 'spam', 'spam'], ['spam', 'spam']),
  20. ]
  21. model5 = IBMModel5(corpus, 0, src_classes, trg_classes)
  22. # act
  23. model5.set_uniform_probabilities(corpus)
  24. # assert
  25. # number of vacancy difference values =
  26. # 2 * number of words in longest target sentence
  27. expected_prob = 1.0 / (2 * 4)
  28. # examine the boundary values for (dv, max_v, trg_class)
  29. self.assertEqual(model5.head_vacancy_table[4][4][0], expected_prob)
  30. self.assertEqual(model5.head_vacancy_table[-3][1][2], expected_prob)
  31. self.assertEqual(model5.non_head_vacancy_table[4][4][0], expected_prob)
  32. self.assertEqual(model5.non_head_vacancy_table[-3][1][2], expected_prob)
  33. def test_set_uniform_vacancy_probabilities_of_non_domain_values(self):
  34. # arrange
  35. src_classes = {'schinken': 0, 'eier': 0, 'spam': 1}
  36. trg_classes = {'ham': 0, 'eggs': 1, 'spam': 2}
  37. corpus = [
  38. AlignedSent(['ham', 'eggs'], ['schinken', 'schinken', 'eier']),
  39. AlignedSent(['spam', 'spam', 'spam', 'spam'], ['spam', 'spam']),
  40. ]
  41. model5 = IBMModel5(corpus, 0, src_classes, trg_classes)
  42. # act
  43. model5.set_uniform_probabilities(corpus)
  44. # assert
  45. # examine dv and max_v values that are not in the training data domain
  46. self.assertEqual(model5.head_vacancy_table[5][4][0], IBMModel.MIN_PROB)
  47. self.assertEqual(model5.head_vacancy_table[-4][1][2], IBMModel.MIN_PROB)
  48. self.assertEqual(model5.head_vacancy_table[4][0][0], IBMModel.MIN_PROB)
  49. self.assertEqual(model5.non_head_vacancy_table[5][4][0], IBMModel.MIN_PROB)
  50. self.assertEqual(model5.non_head_vacancy_table[-4][1][2], IBMModel.MIN_PROB)
  51. def test_prob_t_a_given_s(self):
  52. # arrange
  53. src_sentence = ["ich", 'esse', 'ja', 'gern', 'räucherschinken']
  54. trg_sentence = ['i', 'love', 'to', 'eat', 'smoked', 'ham']
  55. src_classes = {'räucherschinken': 0, 'ja': 1, 'ich': 2, 'esse': 3, 'gern': 4}
  56. trg_classes = {'ham': 0, 'smoked': 1, 'i': 3, 'love': 4, 'to': 2, 'eat': 4}
  57. corpus = [AlignedSent(trg_sentence, src_sentence)]
  58. alignment_info = AlignmentInfo(
  59. (0, 1, 4, 0, 2, 5, 5),
  60. [None] + src_sentence,
  61. ['UNUSED'] + trg_sentence,
  62. [[3], [1], [4], [], [2], [5, 6]],
  63. )
  64. head_vacancy_table = defaultdict(
  65. lambda: defaultdict(lambda: defaultdict(float))
  66. )
  67. head_vacancy_table[1 - 0][6][3] = 0.97 # ich -> i
  68. head_vacancy_table[3 - 0][5][4] = 0.97 # esse -> eat
  69. head_vacancy_table[1 - 2][4][4] = 0.97 # gern -> love
  70. head_vacancy_table[2 - 0][2][1] = 0.97 # räucherschinken -> smoked
  71. non_head_vacancy_table = defaultdict(
  72. lambda: defaultdict(lambda: defaultdict(float))
  73. )
  74. non_head_vacancy_table[1 - 0][1][0] = 0.96 # räucherschinken -> ham
  75. translation_table = defaultdict(lambda: defaultdict(float))
  76. translation_table['i']['ich'] = 0.98
  77. translation_table['love']['gern'] = 0.98
  78. translation_table['to'][None] = 0.98
  79. translation_table['eat']['esse'] = 0.98
  80. translation_table['smoked']['räucherschinken'] = 0.98
  81. translation_table['ham']['räucherschinken'] = 0.98
  82. fertility_table = defaultdict(lambda: defaultdict(float))
  83. fertility_table[1]['ich'] = 0.99
  84. fertility_table[1]['esse'] = 0.99
  85. fertility_table[0]['ja'] = 0.99
  86. fertility_table[1]['gern'] = 0.99
  87. fertility_table[2]['räucherschinken'] = 0.999
  88. fertility_table[1][None] = 0.99
  89. probabilities = {
  90. 'p1': 0.167,
  91. 'translation_table': translation_table,
  92. 'fertility_table': fertility_table,
  93. 'head_vacancy_table': head_vacancy_table,
  94. 'non_head_vacancy_table': non_head_vacancy_table,
  95. 'head_distortion_table': None,
  96. 'non_head_distortion_table': None,
  97. 'alignment_table': None,
  98. }
  99. model5 = IBMModel5(corpus, 0, src_classes, trg_classes, probabilities)
  100. # act
  101. probability = model5.prob_t_a_given_s(alignment_info)
  102. # assert
  103. null_generation = 5 * pow(0.167, 1) * pow(0.833, 4)
  104. fertility = 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 2 * 0.999
  105. lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
  106. vacancy = 0.97 * 0.97 * 1 * 0.97 * 0.97 * 0.96
  107. expected_probability = (
  108. null_generation * fertility * lexical_translation * vacancy
  109. )
  110. self.assertEqual(round(probability, 4), round(expected_probability, 4))
  111. def test_prune(self):
  112. # arrange
  113. alignment_infos = [
  114. AlignmentInfo((1, 1), None, None, None),
  115. AlignmentInfo((1, 2), None, None, None),
  116. AlignmentInfo((2, 1), None, None, None),
  117. AlignmentInfo((2, 2), None, None, None),
  118. AlignmentInfo((0, 0), None, None, None),
  119. ]
  120. min_factor = IBMModel5.MIN_SCORE_FACTOR
  121. best_score = 0.9
  122. scores = {
  123. (1, 1): min(min_factor * 1.5, 1) * best_score, # above threshold
  124. (1, 2): best_score,
  125. (2, 1): min_factor * best_score, # at threshold
  126. (2, 2): min_factor * best_score * 0.5, # low score
  127. (0, 0): min(min_factor * 1.1, 1) * 1.2, # above threshold
  128. }
  129. corpus = [AlignedSent(['a'], ['b'])]
  130. original_prob_function = IBMModel4.model4_prob_t_a_given_s
  131. # mock static method
  132. IBMModel4.model4_prob_t_a_given_s = staticmethod(
  133. lambda a, model: scores[a.alignment]
  134. )
  135. model5 = IBMModel5(corpus, 0, None, None)
  136. # act
  137. pruned_alignments = model5.prune(alignment_infos)
  138. # assert
  139. self.assertEqual(len(pruned_alignments), 3)
  140. # restore static method
  141. IBMModel4.model4_prob_t_a_given_s = original_prob_function