| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516 |
- # -*- coding: utf-8 -*-
- # Natural Language Toolkit: Stack decoder
- #
- # Copyright (C) 2001-2020 NLTK Project
- # Author: Tah Wei Hoon <hoon.tw@gmail.com>
- # URL: <http://nltk.org/>
- # For license information, see LICENSE.TXT
- """
- A decoder that uses stacks to implement phrase-based translation.
- In phrase-based translation, the source sentence is segmented into
- phrases of one or more words, and translations for those phrases are
- used to build the target sentence.
- Hypothesis data structures are used to keep track of the source words
- translated so far and the partial output. A hypothesis can be expanded
- by selecting an untranslated phrase, looking up its translation in a
- phrase table, and appending that translation to the partial output.
- Translation is complete when a hypothesis covers all source words.
- The search space is huge because the source sentence can be segmented
- in different ways, the source phrases can be selected in any order,
- and there could be multiple translations for the same source phrase in
- the phrase table. To make decoding tractable, stacks are used to limit
- the number of candidate hypotheses by doing histogram and/or threshold
- pruning.
- Hypotheses with the same number of words translated are placed in the
- same stack. In histogram pruning, each stack has a size limit, and
- the hypothesis with the lowest score is removed when the stack is full.
- In threshold pruning, hypotheses that score below a certain threshold
- of the best hypothesis in that stack are removed.
- Hypothesis scoring can include various factors such as phrase
- translation probability, language model probability, length of
- translation, cost of remaining words to be translated, and so on.
- References:
- Philipp Koehn. 2010. Statistical Machine Translation.
- Cambridge University Press, New York.
- """
- import warnings
- from collections import defaultdict
- from math import log
class StackDecoder(object):
    """
    Phrase-based stack decoder for machine translation

    >>> from nltk.translate import PhraseTable
    >>> phrase_table = PhraseTable()
    >>> phrase_table.add(('niemand',), ('nobody',), log(0.8))
    >>> phrase_table.add(('niemand',), ('no', 'one'), log(0.2))
    >>> phrase_table.add(('erwartet',), ('expects',), log(0.8))
    >>> phrase_table.add(('erwartet',), ('expecting',), log(0.2))
    >>> phrase_table.add(('niemand', 'erwartet'), ('one', 'does', 'not', 'expect'), log(0.1))
    >>> phrase_table.add(('die', 'spanische', 'inquisition'), ('the', 'spanish', 'inquisition'), log(0.8))
    >>> phrase_table.add(('!',), ('!',), log(0.8))

    >>> # nltk.model should be used here once it is implemented
    >>> from collections import defaultdict
    >>> language_prob = defaultdict(lambda: -999.0)
    >>> language_prob[('nobody',)] = log(0.5)
    >>> language_prob[('expects',)] = log(0.4)
    >>> language_prob[('the', 'spanish', 'inquisition')] = log(0.2)
    >>> language_prob[('!',)] = log(0.1)
    >>> language_model = type('',(object,),{'probability_change': lambda self, context, phrase: language_prob[phrase], 'probability': lambda self, phrase: language_prob[phrase]})()

    >>> stack_decoder = StackDecoder(phrase_table, language_model)

    >>> stack_decoder.translate(['niemand', 'erwartet', 'die', 'spanische', 'inquisition', '!'])
    ['nobody', 'expects', 'the', 'spanish', 'inquisition', '!']

    """

    def __init__(self, phrase_table, language_model):
        """
        :param phrase_table: Table of translations for source language
            phrases and the log probabilities for those translations.
            Must support ``phrase in phrase_table`` and
            ``translations_for(phrase)``, whose entries expose
            ``log_prob`` and ``trg_phrase``.
        :type phrase_table: PhraseTable

        :param language_model: Target language model. Must define:

            - ``probability_change(context, phrase)``: the change in log
              probability when the target ``phrase`` is appended to the
              partial translation. ``context`` is the hypothesis being
              expanded.
            - ``probability(phrase)``: the log probability of a target
              ``phrase`` on its own, without context. It is used to
              estimate future scores.

            This interface is experimental and will likely be replaced
            with nltk.model once it is implemented.
        :type language_model: object
        """
        self.phrase_table = phrase_table
        self.language_model = language_model

        self.word_penalty = 0.0
        """
        float: Influences the translation length exponentially.
            If positive, shorter translations are preferred.
            If negative, longer translations are preferred.
            If zero, no penalty is applied.
        """

        self.beam_threshold = 0.0
        """
        float: Hypotheses that score below this factor of the best
            hypothesis in a stack are dropped from consideration.
            Value between 0.0 and 1.0.
        """

        self.stack_size = 100
        """
        int: Maximum number of hypotheses to consider in a stack.
            Higher values increase the likelihood of a good translation,
            but increases processing time.
        """

        self.__distortion_factor = 0.5
        self.__compute_log_distortion()

    @property
    def distortion_factor(self):
        """
        float: Amount of reordering of source phrases.
            Lower values favour monotone translation, suitable when
            word order is similar for both source and target languages.
            Value between 0.0 and 1.0. Default 0.5.
        """
        return self.__distortion_factor

    @distortion_factor.setter
    def distortion_factor(self, d):
        # Validate before storing anything. Otherwise a rejected value
        # would be kept while the cached logarithm still belonged to the
        # previous factor.
        if d < 0.0:
            raise ValueError(
                "distortion_factor must be non-negative, got {!r}".format(d)
            )
        self.__distortion_factor = d
        self.__compute_log_distortion()

    def __compute_log_distortion(self):
        # cache log(distortion_factor) so we don't have to recompute it
        # when scoring hypotheses
        if self.__distortion_factor == 0.0:
            self.__log_distortion_factor = log(1e-9)  # 1e-9 is almost zero
        else:
            self.__log_distortion_factor = log(self.__distortion_factor)

    def translate(self, src_sentence):
        """
        :param src_sentence: Sentence to be translated
        :type src_sentence: list(str)

        :return: Translated sentence. Empty, with a warning, if no
            hypothesis covers every source word.
        :rtype: list(str)
        """
        sentence = tuple(src_sentence)  # prevent accidental modification
        sentence_length = len(sentence)
        # stacks[i] holds hypotheses that have translated exactly i words
        stacks = [
            _Stack(self.stack_size, self.beam_threshold)
            for _ in range(sentence_length + 1)
        ]

        empty_hypothesis = _Hypothesis()
        stacks[0].push(empty_hypothesis)

        all_phrases = self.find_all_src_phrases(sentence)
        future_score_table = self.compute_future_scores(sentence)
        for stack in stacks:
            # Every expansion covers at least one more untranslated word, so
            # new hypotheses always go to a later stack. The stack being
            # iterated is never modified during its own iteration.
            for hypothesis in stack:
                possible_expansions = StackDecoder.valid_phrases(
                    all_phrases, hypothesis
                )
                for src_phrase_span in possible_expansions:
                    src_phrase = sentence[src_phrase_span[0] : src_phrase_span[1]]
                    for translation_option in self.phrase_table.translations_for(
                        src_phrase
                    ):
                        raw_score = self.expansion_score(
                            hypothesis, translation_option, src_phrase_span
                        )
                        new_hypothesis = _Hypothesis(
                            raw_score=raw_score,
                            src_phrase_span=src_phrase_span,
                            trg_phrase=translation_option.trg_phrase,
                            previous=hypothesis,
                        )
                        new_hypothesis.future_score = self.future_score(
                            new_hypothesis, future_score_table, sentence_length
                        )
                        total_words = new_hypothesis.total_translated_words()
                        stacks[total_words].push(new_hypothesis)

        if not stacks[sentence_length]:
            warnings.warn(
                "Unable to translate all words. "
                "The source sentence contains words not in "
                "the phrase table"
            )
            # Instead of returning empty output, perhaps a partial
            # translation could be returned
            return []

        best_hypothesis = stacks[sentence_length].best()
        return best_hypothesis.translation_so_far()

    def find_all_src_phrases(self, src_sentence):
        """
        Finds all subsequences in src_sentence that have a phrase
        translation in the translation table

        :type src_sentence: tuple(str)

        :return: Subsequences that have a phrase translation,
            represented as a table of lists of end positions.
            For example, if result[2] is [5, 6, 9], then there are
            three phrases starting from position 2 in ``src_sentence``,
            ending at positions 5, 6, and 9 exclusive. The list of
            ending positions are in ascending order.
        :rtype: list(list(int))
        """
        sentence_length = len(src_sentence)
        phrase_indices = [[] for _ in src_sentence]
        for start in range(sentence_length):
            for end in range(start + 1, sentence_length + 1):
                potential_phrase = src_sentence[start:end]
                if potential_phrase in self.phrase_table:
                    phrase_indices[start].append(end)
        return phrase_indices

    def compute_future_scores(self, src_sentence):
        """
        Determines the approximate scores for translating every
        subsequence in ``src_sentence``

        Future scores can be used as a look-ahead to determine the
        difficulty of translating the remaining parts of a src_sentence.

        A subsequence found in the phrase table is scored directly by its
        best translation option. That option maximises the translation
        log probability plus the target language model's context-free
        ``probability`` of the translated phrase. Any subsequence may
        also be scored by splitting it into two adjacent subsequences;
        the higher score wins. Subsequences that cannot be derived at all
        score ``float("-inf")``.

        :type src_sentence: tuple(str)

        :return: Scores of subsequences referenced by their start and
            end positions. For example, result[2][5] is the score of the
            subsequence covering positions 2, 3, and 4.
        :rtype: dict(int: dict(int: float))
        """
        scores = defaultdict(lambda: defaultdict(lambda: float("-inf")))
        for seq_length in range(1, len(src_sentence) + 1):
            for start in range(len(src_sentence) - seq_length + 1):
                end = start + seq_length
                phrase = src_sentence[start:end]
                if phrase in self.phrase_table:
                    # The language model is a *target* language model, so
                    # it must score the translation, not the source phrase.
                    best_score = float("-inf")
                    for option in self.phrase_table.translations_for(phrase):
                        # Warning: API of language_model is subject to change
                        option_score = option.log_prob + self.language_model.probability(
                            option.trg_phrase
                        )
                        if option_score > best_score:
                            best_score = option_score
                    scores[start][end] = best_score

                # check if a better score can be obtained by combining
                # two child subsequences
                for mid in range(start + 1, end):
                    combined_score = scores[start][mid] + scores[mid][end]
                    if combined_score > scores[start][end]:
                        scores[start][end] = combined_score
        return scores

    def future_score(self, hypothesis, future_score_table, sentence_length):
        """
        Determines the approximate score for translating the
        untranslated words in ``hypothesis``

        :return: Sum of ``future_score_table`` entries over every
            untranslated span of ``hypothesis``; 0.0 if none remain.
        :rtype: float
        """
        score = 0.0
        for span in hypothesis.untranslated_spans(sentence_length):
            score += future_score_table[span[0]][span[1]]
        return score

    def expansion_score(self, hypothesis, translation_option, src_phrase_span):
        """
        Calculate the score of expanding ``hypothesis`` with
        ``translation_option``

        :param hypothesis: Hypothesis being expanded
        :type hypothesis: _Hypothesis

        :param translation_option: Information about the proposed expansion
        :type translation_option: PhraseTableEntry

        :param src_phrase_span: Word position span of the source phrase
        :type src_phrase_span: tuple(int, int)

        :return: The hypothesis' raw score, plus the translation log
            probability, the language model change and the distortion
            score, minus ``word_penalty`` per target word.
        :rtype: float
        """
        score = hypothesis.raw_score
        score += translation_option.log_prob
        # The API of language_model is subject to change; it could accept
        # a string, a list of words, and/or some other type
        score += self.language_model.probability_change(
            hypothesis, translation_option.trg_phrase
        )
        score += self.distortion_score(hypothesis, src_phrase_span)
        score -= self.word_penalty * len(translation_option.trg_phrase)
        return score

    def distortion_score(self, hypothesis, next_src_phrase_span):
        """
        Log-space penalty for reordering.

        Computed as |next_start - prev_end| * log(distortion_factor).
        ``prev_end`` is the exclusive end of the phrase ``hypothesis``
        translated, so a phrase that directly follows it scores 0.0.
        The empty initial hypothesis (no span) also scores 0.0.
        """
        if not hypothesis.src_phrase_span:
            return 0.0
        next_src_phrase_start = next_src_phrase_span[0]
        prev_src_phrase_end = hypothesis.src_phrase_span[1]
        distortion_distance = next_src_phrase_start - prev_src_phrase_end
        return abs(distortion_distance) * self.__log_distortion_factor

    @staticmethod
    def valid_phrases(all_phrases_from, hypothesis):
        """
        Extract phrases from ``all_phrases_from`` that contains words
        that have not been translated by ``hypothesis``

        :param all_phrases_from: Phrases represented by their spans, in
            the same format as the return value of
            ``find_all_src_phrases``
        :type all_phrases_from: list(list(int))

        :type hypothesis: _Hypothesis

        :return: A list of phrases, represented by their spans, that
            lie entirely within untranslated positions.
        :rtype: list(tuple(int, int))
        """
        untranslated_spans = hypothesis.untranslated_spans(len(all_phrases_from))
        valid_phrases = []
        for available_span in untranslated_spans:
            start = available_span[0]
            available_end = available_span[1]
            while start < available_end:
                for phrase_end in all_phrases_from[start]:
                    if phrase_end > available_end:
                        # Subsequent elements in all_phrases_from[start]
                        # will also be > available_end, since the
                        # elements are in ascending order
                        break
                    valid_phrases.append((start, phrase_end))
                start += 1
        return valid_phrases
class _Hypothesis(object):
    """
    Partial solution to a translation.

    Records the word positions of the phrase being translated, its
    translation, raw score, and the cost of the untranslated parts of
    the sentence. When the next phrase is selected to build upon the
    partial solution, a new _Hypothesis object is created, with a back
    pointer to the previous hypothesis.

    To find out which words have been translated so far, look at the
    ``src_phrase_span`` in the hypothesis chain. Similarly, the
    translation output can be found by traversing up the chain.
    The root of every chain is a hypothesis with no ``previous``; it
    contributes no span and no target words.
    """

    def __init__(
        self,
        raw_score=0.0,
        src_phrase_span=(),
        trg_phrase=(),
        previous=None,
        future_score=0.0,
    ):
        """
        :param raw_score: Likelihood of hypothesis so far.
            Higher is better. Does not account for untranslated words.
        :type raw_score: float

        :param src_phrase_span: Span of word positions covered by the
            source phrase in this hypothesis expansion, with an exclusive
            end. For example, (2, 5) covers positions 2, 3 and 4 of the
            source sentence.
        :type src_phrase_span: tuple(int)

        :param trg_phrase: Translation of the source phrase in this
            hypothesis expansion
        :type trg_phrase: tuple(str)

        :param previous: Previous hypothesis before expansion to this one
        :type previous: _Hypothesis

        :param future_score: Approximate score for translating the
            remaining words not covered by this hypothesis. Higher means
            that the remaining words are easier to translate.
        :type future_score: float
        """
        self.raw_score = raw_score
        self.src_phrase_span = src_phrase_span
        self.trg_phrase = trg_phrase
        self.previous = previous
        self.future_score = future_score

    def score(self):
        """
        Overall score of hypothesis after accounting for local and
        global features
        """
        return self.raw_score + self.future_score

    def _iter_chain(self):
        """
        Yield this hypothesis and its predecessors, newest first,
        stopping before the root hypothesis (the one with no
        ``previous``).

        Iterative rather than recursive, so long chains cannot exceed
        Python's recursion limit.
        """
        current = self
        while current.previous is not None:
            yield current
            current = current.previous

    def untranslated_spans(self, sentence_length):
        """
        Find the maximal runs of consecutive untranslated positions.

        :param sentence_length: Length of source sentence being
            translated by the hypothesis
        :type sentence_length: int

        :return: ``(start, end)`` spans with exclusive ends, in
            ascending order.
        :rtype: list(tuple(int, int))
        """
        translated_positions = self.translated_positions()
        translated_positions.sort()
        translated_positions.append(sentence_length)  # add sentinel position

        untranslated_spans = []
        start = 0
        # each untranslated span must end in one of the translated_positions
        for end in translated_positions:
            if start < end:
                untranslated_spans.append((start, end))
            start = end + 1

        return untranslated_spans

    def translated_positions(self):
        """
        List of positions in the source sentence of words already
        translated. The list is not sorted.

        :rtype: list(int)
        """
        translated_positions = []
        for hypothesis in self._iter_chain():
            translated_span = hypothesis.src_phrase_span
            translated_positions.extend(range(translated_span[0], translated_span[1]))
        return translated_positions

    def total_translated_words(self):
        """
        Number of source words translated by this hypothesis chain.

        :rtype: int
        """
        # Sum span lengths directly instead of building the position list.
        return sum(
            len(range(h.src_phrase_span[0], h.src_phrase_span[1]))
            for h in self._iter_chain()
        )

    def translation_so_far(self):
        """
        Target words produced by this hypothesis chain, in output order.

        :rtype: list(str)
        """
        translation = []
        self.__build_translation(self, translation)
        return translation

    def __build_translation(self, hypothesis, output):
        # The chain is walked newest-first, so collect the phrases and then
        # emit them oldest-first, which is the order they were appended in.
        phrases = [h.trg_phrase for h in hypothesis._iter_chain()]
        for phrase in reversed(phrases):
            output.extend(phrase)
class _Stack(object):
    """
    Collection of _Hypothesis objects, kept ordered from the highest to
    the lowest score
    """

    def __init__(self, max_size=100, beam_threshold=0.0):
        """
        :param max_size: Maximum number of hypotheses kept in the stack.
            When it is exceeded, the lowest scoring hypotheses are dropped.
        :type max_size: int

        :param beam_threshold: Hypotheses that score less than this
            factor of the best hypothesis are discarded from the stack.
            Value must be between 0.0 and 1.0; 0.0 disables threshold
            pruning.
        :type beam_threshold: float

        :raises ValueError: If ``beam_threshold`` is outside [0.0, 1.0].
            A factor above 1.0 would put the threshold above the best
            score and prune every hypothesis, including the best one.
        """
        if not 0.0 <= beam_threshold <= 1.0:
            raise ValueError(
                "beam_threshold must be between 0.0 and 1.0, "
                "got {!r}".format(beam_threshold)
            )
        self.max_size = max_size
        self.items = []

        # Scores are treated as log values, so "a factor of the best score"
        # is applied by adding log(beam_threshold). A factor of 0.0 maps to
        # -inf, which never prunes anything.
        if beam_threshold == 0.0:
            self.__log_beam_threshold = float("-inf")
        else:
            self.__log_beam_threshold = log(beam_threshold)

    def push(self, hypothesis):
        """
        Add ``hypothesis`` to the stack.
        Removes lowest scoring hypothesis if the stack is full.
        After insertion, hypotheses that score less than
        ``beam_threshold`` times the score of the best hypothesis
        are removed.
        """
        self.items.append(hypothesis)
        self.items.sort(key=lambda h: h.score(), reverse=True)
        while len(self.items) > self.max_size:
            self.items.pop()
        self.threshold_prune()

    def threshold_prune(self):
        """
        Drop hypotheses scoring below the best score plus
        log(beam_threshold). Assumes ``items`` is sorted best-first, as
        ``push`` guarantees.
        """
        if not self.items:
            return
        # log(score * beam_threshold) = log(score) + log(beam_threshold)
        threshold = self.items[0].score() + self.__log_beam_threshold
        # The weakest hypotheses are at the tail. Pop them until one meets
        # the threshold, rather than mutating the list while iterating it.
        while self.items and self.items[-1].score() < threshold:
            self.items.pop()

    def best(self):
        """
        :return: Hypothesis with the highest score in the stack, or
            ``None`` if the stack is empty
        :rtype: _Hypothesis
        """
        if self.items:
            return self.items[0]
        return None

    def __iter__(self):
        return iter(self.items)

    def __contains__(self, hypothesis):
        return hypothesis in self.items

    def __bool__(self):
        return len(self.items) != 0

    __nonzero__ = __bool__
|