test_ibm3.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests for IBM Model 3 training methods
  4. """
  5. import unittest
  6. from collections import defaultdict
  7. from nltk.translate import AlignedSent
  8. from nltk.translate import IBMModel
  9. from nltk.translate import IBMModel3
  10. from nltk.translate.ibm_model import AlignmentInfo
  11. class TestIBMModel3(unittest.TestCase):
  12. def test_set_uniform_distortion_probabilities(self):
  13. # arrange
  14. corpus = [
  15. AlignedSent(['ham', 'eggs'], ['schinken', 'schinken', 'eier']),
  16. AlignedSent(['spam', 'spam', 'spam', 'spam'], ['spam', 'spam']),
  17. ]
  18. model3 = IBMModel3(corpus, 0)
  19. # act
  20. model3.set_uniform_probabilities(corpus)
  21. # assert
  22. # expected_prob = 1.0 / length of target sentence
  23. self.assertEqual(model3.distortion_table[1][0][3][2], 1.0 / 2)
  24. self.assertEqual(model3.distortion_table[4][2][2][4], 1.0 / 4)
  25. def test_set_uniform_distortion_probabilities_of_non_domain_values(self):
  26. # arrange
  27. corpus = [
  28. AlignedSent(['ham', 'eggs'], ['schinken', 'schinken', 'eier']),
  29. AlignedSent(['spam', 'spam', 'spam', 'spam'], ['spam', 'spam']),
  30. ]
  31. model3 = IBMModel3(corpus, 0)
  32. # act
  33. model3.set_uniform_probabilities(corpus)
  34. # assert
  35. # examine i and j values that are not in the training data domain
  36. self.assertEqual(model3.distortion_table[0][0][3][2], IBMModel.MIN_PROB)
  37. self.assertEqual(model3.distortion_table[9][2][2][4], IBMModel.MIN_PROB)
  38. self.assertEqual(model3.distortion_table[2][9][2][4], IBMModel.MIN_PROB)
  39. def test_prob_t_a_given_s(self):
  40. # arrange
  41. src_sentence = ["ich", 'esse', 'ja', 'gern', 'räucherschinken']
  42. trg_sentence = ['i', 'love', 'to', 'eat', 'smoked', 'ham']
  43. corpus = [AlignedSent(trg_sentence, src_sentence)]
  44. alignment_info = AlignmentInfo(
  45. (0, 1, 4, 0, 2, 5, 5),
  46. [None] + src_sentence,
  47. ['UNUSED'] + trg_sentence,
  48. [[3], [1], [4], [], [2], [5, 6]],
  49. )
  50. distortion_table = defaultdict(
  51. lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
  52. )
  53. distortion_table[1][1][5][6] = 0.97 # i -> ich
  54. distortion_table[2][4][5][6] = 0.97 # love -> gern
  55. distortion_table[3][0][5][6] = 0.97 # to -> NULL
  56. distortion_table[4][2][5][6] = 0.97 # eat -> esse
  57. distortion_table[5][5][5][6] = 0.97 # smoked -> räucherschinken
  58. distortion_table[6][5][5][6] = 0.97 # ham -> räucherschinken
  59. translation_table = defaultdict(lambda: defaultdict(float))
  60. translation_table['i']['ich'] = 0.98
  61. translation_table['love']['gern'] = 0.98
  62. translation_table['to'][None] = 0.98
  63. translation_table['eat']['esse'] = 0.98
  64. translation_table['smoked']['räucherschinken'] = 0.98
  65. translation_table['ham']['räucherschinken'] = 0.98
  66. fertility_table = defaultdict(lambda: defaultdict(float))
  67. fertility_table[1]['ich'] = 0.99
  68. fertility_table[1]['esse'] = 0.99
  69. fertility_table[0]['ja'] = 0.99
  70. fertility_table[1]['gern'] = 0.99
  71. fertility_table[2]['räucherschinken'] = 0.999
  72. fertility_table[1][None] = 0.99
  73. probabilities = {
  74. 'p1': 0.167,
  75. 'translation_table': translation_table,
  76. 'distortion_table': distortion_table,
  77. 'fertility_table': fertility_table,
  78. 'alignment_table': None,
  79. }
  80. model3 = IBMModel3(corpus, 0, probabilities)
  81. # act
  82. probability = model3.prob_t_a_given_s(alignment_info)
  83. # assert
  84. null_generation = 5 * pow(0.167, 1) * pow(0.833, 4)
  85. fertility = 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 1 * 0.99 * 2 * 0.999
  86. lexical_translation = 0.98 * 0.98 * 0.98 * 0.98 * 0.98 * 0.98
  87. distortion = 0.97 * 0.97 * 0.97 * 0.97 * 0.97 * 0.97
  88. expected_probability = (
  89. null_generation * fertility * lexical_translation * distortion
  90. )
  91. self.assertEqual(round(probability, 4), round(expected_probability, 4))