stack_decoder.py 20 KB

# -*- coding: utf-8 -*-
# Natural Language Toolkit: Stack decoder
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Tah Wei Hoon <hoon.tw@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

"""
A decoder that uses stacks to implement phrase-based translation.

In phrase-based translation, the source sentence is segmented into
phrases of one or more words, and translations for those phrases are
used to build the target sentence.

Hypothesis data structures are used to keep track of the source words
translated so far and the partial output. A hypothesis can be expanded
by selecting an untranslated phrase, looking up its translation in a
phrase table, and appending that translation to the partial output.
Translation is complete when a hypothesis covers all source words.

The search space is huge because the source sentence can be segmented
in different ways, the source phrases can be selected in any order,
and there could be multiple translations for the same source phrase in
the phrase table. To make decoding tractable, stacks are used to limit
the number of candidate hypotheses by doing histogram and/or threshold
pruning.

Hypotheses with the same number of words translated are placed in the
same stack. In histogram pruning, each stack has a size limit, and
the hypothesis with the lowest score is removed when the stack is full.
In threshold pruning, hypotheses that score below a certain threshold
of the best hypothesis in that stack are removed.

Hypothesis scoring can include various factors such as phrase
translation probability, language model probability, length of
translation, cost of remaining words to be translated, and so on.

References:
Philipp Koehn. 2010. Statistical Machine Translation.
Cambridge University Press, New York.
"""
import warnings
from collections import defaultdict
from math import log
  39. class StackDecoder(object):
  40. """
  41. Phrase-based stack decoder for machine translation
  42. >>> from nltk.translate import PhraseTable
  43. >>> phrase_table = PhraseTable()
  44. >>> phrase_table.add(('niemand',), ('nobody',), log(0.8))
  45. >>> phrase_table.add(('niemand',), ('no', 'one'), log(0.2))
  46. >>> phrase_table.add(('erwartet',), ('expects',), log(0.8))
  47. >>> phrase_table.add(('erwartet',), ('expecting',), log(0.2))
  48. >>> phrase_table.add(('niemand', 'erwartet'), ('one', 'does', 'not', 'expect'), log(0.1))
  49. >>> phrase_table.add(('die', 'spanische', 'inquisition'), ('the', 'spanish', 'inquisition'), log(0.8))
  50. >>> phrase_table.add(('!',), ('!',), log(0.8))
  51. >>> # nltk.model should be used here once it is implemented
  52. >>> from collections import defaultdict
  53. >>> language_prob = defaultdict(lambda: -999.0)
  54. >>> language_prob[('nobody',)] = log(0.5)
  55. >>> language_prob[('expects',)] = log(0.4)
  56. >>> language_prob[('the', 'spanish', 'inquisition')] = log(0.2)
  57. >>> language_prob[('!',)] = log(0.1)
  58. >>> language_model = type('',(object,),{'probability_change': lambda self, context, phrase: language_prob[phrase], 'probability': lambda self, phrase: language_prob[phrase]})()
  59. >>> stack_decoder = StackDecoder(phrase_table, language_model)
  60. >>> stack_decoder.translate(['niemand', 'erwartet', 'die', 'spanische', 'inquisition', '!'])
  61. ['nobody', 'expects', 'the', 'spanish', 'inquisition', '!']
  62. """
  63. def __init__(self, phrase_table, language_model):
  64. """
  65. :param phrase_table: Table of translations for source language
  66. phrases and the log probabilities for those translations.
  67. :type phrase_table: PhraseTable
  68. :param language_model: Target language model. Must define a
  69. ``probability_change`` method that calculates the change in
  70. log probability of a sentence, if a given string is appended
  71. to it.
  72. This interface is experimental and will likely be replaced
  73. with nltk.model once it is implemented.
  74. :type language_model: object
  75. """
  76. self.phrase_table = phrase_table
  77. self.language_model = language_model
  78. self.word_penalty = 0.0
  79. """
  80. float: Influences the translation length exponentially.
  81. If positive, shorter translations are preferred.
  82. If negative, longer translations are preferred.
  83. If zero, no penalty is applied.
  84. """
  85. self.beam_threshold = 0.0
  86. """
  87. float: Hypotheses that score below this factor of the best
  88. hypothesis in a stack are dropped from consideration.
  89. Value between 0.0 and 1.0.
  90. """
  91. self.stack_size = 100
  92. """
  93. int: Maximum number of hypotheses to consider in a stack.
  94. Higher values increase the likelihood of a good translation,
  95. but increases processing time.
  96. """
  97. self.__distortion_factor = 0.5
  98. self.__compute_log_distortion()
  99. @property
  100. def distortion_factor(self):
  101. """
  102. float: Amount of reordering of source phrases.
  103. Lower values favour monotone translation, suitable when
  104. word order is similar for both source and target languages.
  105. Value between 0.0 and 1.0. Default 0.5.
  106. """
  107. return self.__distortion_factor
  108. @distortion_factor.setter
  109. def distortion_factor(self, d):
  110. self.__distortion_factor = d
  111. self.__compute_log_distortion()
  112. def __compute_log_distortion(self):
  113. # cache log(distortion_factor) so we don't have to recompute it
  114. # when scoring hypotheses
  115. if self.__distortion_factor == 0.0:
  116. self.__log_distortion_factor = log(1e-9) # 1e-9 is almost zero
  117. else:
  118. self.__log_distortion_factor = log(self.__distortion_factor)
  119. def translate(self, src_sentence):
  120. """
  121. :param src_sentence: Sentence to be translated
  122. :type src_sentence: list(str)
  123. :return: Translated sentence
  124. :rtype: list(str)
  125. """
  126. sentence = tuple(src_sentence) # prevent accidental modification
  127. sentence_length = len(sentence)
  128. stacks = [
  129. _Stack(self.stack_size, self.beam_threshold)
  130. for _ in range(0, sentence_length + 1)
  131. ]
  132. empty_hypothesis = _Hypothesis()
  133. stacks[0].push(empty_hypothesis)
  134. all_phrases = self.find_all_src_phrases(sentence)
  135. future_score_table = self.compute_future_scores(sentence)
  136. for stack in stacks:
  137. for hypothesis in stack:
  138. possible_expansions = StackDecoder.valid_phrases(
  139. all_phrases, hypothesis
  140. )
  141. for src_phrase_span in possible_expansions:
  142. src_phrase = sentence[src_phrase_span[0] : src_phrase_span[1]]
  143. for translation_option in self.phrase_table.translations_for(
  144. src_phrase
  145. ):
  146. raw_score = self.expansion_score(
  147. hypothesis, translation_option, src_phrase_span
  148. )
  149. new_hypothesis = _Hypothesis(
  150. raw_score=raw_score,
  151. src_phrase_span=src_phrase_span,
  152. trg_phrase=translation_option.trg_phrase,
  153. previous=hypothesis,
  154. )
  155. new_hypothesis.future_score = self.future_score(
  156. new_hypothesis, future_score_table, sentence_length
  157. )
  158. total_words = new_hypothesis.total_translated_words()
  159. stacks[total_words].push(new_hypothesis)
  160. if not stacks[sentence_length]:
  161. warnings.warn(
  162. "Unable to translate all words. "
  163. "The source sentence contains words not in "
  164. "the phrase table"
  165. )
  166. # Instead of returning empty output, perhaps a partial
  167. # translation could be returned
  168. return []
  169. best_hypothesis = stacks[sentence_length].best()
  170. return best_hypothesis.translation_so_far()
  171. def find_all_src_phrases(self, src_sentence):
  172. """
  173. Finds all subsequences in src_sentence that have a phrase
  174. translation in the translation table
  175. :type src_sentence: tuple(str)
  176. :return: Subsequences that have a phrase translation,
  177. represented as a table of lists of end positions.
  178. For example, if result[2] is [5, 6, 9], then there are
  179. three phrases starting from position 2 in ``src_sentence``,
  180. ending at positions 5, 6, and 9 exclusive. The list of
  181. ending positions are in ascending order.
  182. :rtype: list(list(int))
  183. """
  184. sentence_length = len(src_sentence)
  185. phrase_indices = [[] for _ in src_sentence]
  186. for start in range(0, sentence_length):
  187. for end in range(start + 1, sentence_length + 1):
  188. potential_phrase = src_sentence[start:end]
  189. if potential_phrase in self.phrase_table:
  190. phrase_indices[start].append(end)
  191. return phrase_indices
  192. def compute_future_scores(self, src_sentence):
  193. """
  194. Determines the approximate scores for translating every
  195. subsequence in ``src_sentence``
  196. Future scores can be used a look-ahead to determine the
  197. difficulty of translating the remaining parts of a src_sentence.
  198. :type src_sentence: tuple(str)
  199. :return: Scores of subsequences referenced by their start and
  200. end positions. For example, result[2][5] is the score of the
  201. subsequence covering positions 2, 3, and 4.
  202. :rtype: dict(int: (dict(int): float))
  203. """
  204. scores = defaultdict(lambda: defaultdict(lambda: float("-inf")))
  205. for seq_length in range(1, len(src_sentence) + 1):
  206. for start in range(0, len(src_sentence) - seq_length + 1):
  207. end = start + seq_length
  208. phrase = src_sentence[start:end]
  209. if phrase in self.phrase_table:
  210. score = self.phrase_table.translations_for(phrase)[
  211. 0
  212. ].log_prob # pick best (first) translation
  213. # Warning: API of language_model is subject to change
  214. score += self.language_model.probability(phrase)
  215. scores[start][end] = score
  216. # check if a better score can be obtained by combining
  217. # two child subsequences
  218. for mid in range(start + 1, end):
  219. combined_score = scores[start][mid] + scores[mid][end]
  220. if combined_score > scores[start][end]:
  221. scores[start][end] = combined_score
  222. return scores
  223. def future_score(self, hypothesis, future_score_table, sentence_length):
  224. """
  225. Determines the approximate score for translating the
  226. untranslated words in ``hypothesis``
  227. """
  228. score = 0.0
  229. for span in hypothesis.untranslated_spans(sentence_length):
  230. score += future_score_table[span[0]][span[1]]
  231. return score
  232. def expansion_score(self, hypothesis, translation_option, src_phrase_span):
  233. """
  234. Calculate the score of expanding ``hypothesis`` with
  235. ``translation_option``
  236. :param hypothesis: Hypothesis being expanded
  237. :type hypothesis: _Hypothesis
  238. :param translation_option: Information about the proposed expansion
  239. :type translation_option: PhraseTableEntry
  240. :param src_phrase_span: Word position span of the source phrase
  241. :type src_phrase_span: tuple(int, int)
  242. """
  243. score = hypothesis.raw_score
  244. score += translation_option.log_prob
  245. # The API of language_model is subject to change; it could accept
  246. # a string, a list of words, and/or some other type
  247. score += self.language_model.probability_change(
  248. hypothesis, translation_option.trg_phrase
  249. )
  250. score += self.distortion_score(hypothesis, src_phrase_span)
  251. score -= self.word_penalty * len(translation_option.trg_phrase)
  252. return score
  253. def distortion_score(self, hypothesis, next_src_phrase_span):
  254. if not hypothesis.src_phrase_span:
  255. return 0.0
  256. next_src_phrase_start = next_src_phrase_span[0]
  257. prev_src_phrase_end = hypothesis.src_phrase_span[1]
  258. distortion_distance = next_src_phrase_start - prev_src_phrase_end
  259. return abs(distortion_distance) * self.__log_distortion_factor
  260. @staticmethod
  261. def valid_phrases(all_phrases_from, hypothesis):
  262. """
  263. Extract phrases from ``all_phrases_from`` that contains words
  264. that have not been translated by ``hypothesis``
  265. :param all_phrases_from: Phrases represented by their spans, in
  266. the same format as the return value of
  267. ``find_all_src_phrases``
  268. :type all_phrases_from: list(list(int))
  269. :type hypothesis: _Hypothesis
  270. :return: A list of phrases, represented by their spans, that
  271. cover untranslated positions.
  272. :rtype: list(tuple(int, int))
  273. """
  274. untranslated_spans = hypothesis.untranslated_spans(len(all_phrases_from))
  275. valid_phrases = []
  276. for available_span in untranslated_spans:
  277. start = available_span[0]
  278. available_end = available_span[1]
  279. while start < available_end:
  280. for phrase_end in all_phrases_from[start]:
  281. if phrase_end > available_end:
  282. # Subsequent elements in all_phrases_from[start]
  283. # will also be > available_end, since the
  284. # elements are in ascending order
  285. break
  286. valid_phrases.append((start, phrase_end))
  287. start += 1
  288. return valid_phrases
  289. class _Hypothesis(object):
  290. """
  291. Partial solution to a translation.
  292. Records the word positions of the phrase being translated, its
  293. translation, raw score, and the cost of the untranslated parts of
  294. the sentence. When the next phrase is selected to build upon the
  295. partial solution, a new _Hypothesis object is created, with a back
  296. pointer to the previous hypothesis.
  297. To find out which words have been translated so far, look at the
  298. ``src_phrase_span`` in the hypothesis chain. Similarly, the
  299. translation output can be found by traversing up the chain.
  300. """
  301. def __init__(
  302. self,
  303. raw_score=0.0,
  304. src_phrase_span=(),
  305. trg_phrase=(),
  306. previous=None,
  307. future_score=0.0,
  308. ):
  309. """
  310. :param raw_score: Likelihood of hypothesis so far.
  311. Higher is better. Does not account for untranslated words.
  312. :type raw_score: float
  313. :param src_phrase_span: Span of word positions covered by the
  314. source phrase in this hypothesis expansion. For example,
  315. (2, 5) means that the phrase is from the second word up to,
  316. but not including the fifth word in the source sentence.
  317. :type src_phrase_span: tuple(int)
  318. :param trg_phrase: Translation of the source phrase in this
  319. hypothesis expansion
  320. :type trg_phrase: tuple(str)
  321. :param previous: Previous hypothesis before expansion to this one
  322. :type previous: _Hypothesis
  323. :param future_score: Approximate score for translating the
  324. remaining words not covered by this hypothesis. Higher means
  325. that the remaining words are easier to translate.
  326. :type future_score: float
  327. """
  328. self.raw_score = raw_score
  329. self.src_phrase_span = src_phrase_span
  330. self.trg_phrase = trg_phrase
  331. self.previous = previous
  332. self.future_score = future_score
  333. def score(self):
  334. """
  335. Overall score of hypothesis after accounting for local and
  336. global features
  337. """
  338. return self.raw_score + self.future_score
  339. def untranslated_spans(self, sentence_length):
  340. """
  341. Starting from each untranslated word, find the longest
  342. continuous span of untranslated positions
  343. :param sentence_length: Length of source sentence being
  344. translated by the hypothesis
  345. :type sentence_length: int
  346. :rtype: list(tuple(int, int))
  347. """
  348. translated_positions = self.translated_positions()
  349. translated_positions.sort()
  350. translated_positions.append(sentence_length) # add sentinel position
  351. untranslated_spans = []
  352. start = 0
  353. # each untranslated span must end in one of the translated_positions
  354. for end in translated_positions:
  355. if start < end:
  356. untranslated_spans.append((start, end))
  357. start = end + 1
  358. return untranslated_spans
  359. def translated_positions(self):
  360. """
  361. List of positions in the source sentence of words already
  362. translated. The list is not sorted.
  363. :rtype: list(int)
  364. """
  365. translated_positions = []
  366. current_hypothesis = self
  367. while current_hypothesis.previous is not None:
  368. translated_span = current_hypothesis.src_phrase_span
  369. translated_positions.extend(range(translated_span[0], translated_span[1]))
  370. current_hypothesis = current_hypothesis.previous
  371. return translated_positions
  372. def total_translated_words(self):
  373. return len(self.translated_positions())
  374. def translation_so_far(self):
  375. translation = []
  376. self.__build_translation(self, translation)
  377. return translation
  378. def __build_translation(self, hypothesis, output):
  379. if hypothesis.previous is None:
  380. return
  381. self.__build_translation(hypothesis.previous, output)
  382. output.extend(hypothesis.trg_phrase)
  383. class _Stack(object):
  384. """
  385. Collection of _Hypothesis objects
  386. """
  387. def __init__(self, max_size=100, beam_threshold=0.0):
  388. """
  389. :param beam_threshold: Hypotheses that score less than this
  390. factor of the best hypothesis are discarded from the stack.
  391. Value must be between 0.0 and 1.0.
  392. :type beam_threshold: float
  393. """
  394. self.max_size = max_size
  395. self.items = []
  396. if beam_threshold == 0.0:
  397. self.__log_beam_threshold = float("-inf")
  398. else:
  399. self.__log_beam_threshold = log(beam_threshold)
  400. def push(self, hypothesis):
  401. """
  402. Add ``hypothesis`` to the stack.
  403. Removes lowest scoring hypothesis if the stack is full.
  404. After insertion, hypotheses that score less than
  405. ``beam_threshold`` times the score of the best hypothesis
  406. are removed.
  407. """
  408. self.items.append(hypothesis)
  409. self.items.sort(key=lambda h: h.score(), reverse=True)
  410. while len(self.items) > self.max_size:
  411. self.items.pop()
  412. self.threshold_prune()
  413. def threshold_prune(self):
  414. if not self.items:
  415. return
  416. # log(score * beam_threshold) = log(score) + log(beam_threshold)
  417. threshold = self.items[0].score() + self.__log_beam_threshold
  418. for hypothesis in reversed(self.items):
  419. if hypothesis.score() < threshold:
  420. self.items.pop()
  421. else:
  422. break
  423. def best(self):
  424. """
  425. :return: Hypothesis with the highest score in the stack
  426. :rtype: _Hypothesis
  427. """
  428. if self.items:
  429. return self.items[0]
  430. return None
  431. def __iter__(self):
  432. return iter(self.items)
  433. def __contains__(self, hypothesis):
  434. return hypothesis in self.items
  435. def __bool__(self):
  436. return len(self.items) != 0
  437. __nonzero__ = __bool__