api.py

# Natural Language Toolkit: Language Models
#
# Copyright (C) 2001-2020 NLTK Project
# Authors: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Interface."""

import random
from abc import ABCMeta, abstractmethod
from bisect import bisect
from itertools import accumulate

from nltk.lm.counter import NgramCounter
from nltk.lm.util import log_base2
from nltk.lm.vocabulary import Vocabulary


class Smoothing(metaclass=ABCMeta):
    """Ngram Smoothing Interface

    Implements Chen & Goodman 1995's idea that all smoothing algorithms have
    certain features in common. This should ideally allow smoothing algorithms
    to work both with Backoff and Interpolation.
    """

    def __init__(self, vocabulary, counter):
        """
        :param vocabulary: The Ngram vocabulary object.
        :type vocabulary: nltk.lm.vocabulary.Vocabulary
        :param counter: The counts of the vocabulary items.
        :type counter: nltk.lm.counter.NgramCounter
        """
        self.vocab = vocabulary
        self.counts = counter

    @abstractmethod
    def unigram_score(self, word):
        """Score a word given no context, i.e. its unigram probability."""
        raise NotImplementedError()

    @abstractmethod
    def alpha_gamma(self, word, context):
        """Return the (alpha, gamma) pair an interpolated model combines with
        the lower-order score."""
        raise NotImplementedError()
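

# A minimal sketch of a concrete `Smoothing` subclass, for illustration only
# (it is not part of this module): additive (add-k) smoothing expressed
# through the two abstract methods above. An interpolated model would combine
# the pair as `alpha + gamma * lower_order_score`.
class _AddKSmoothingSketch(Smoothing):
    """Hypothetical add-k smoothing sketch."""

    def __init__(self, vocabulary, counter, k=1.0):
        super().__init__(vocabulary, counter)
        self.k = k

    def unigram_score(self, word):
        # Relative frequency of `word` with k added to every vocabulary item.
        total = self.counts.unigrams.N() + self.k * len(self.vocab)
        return (self.counts.unigrams[word] + self.k) / total

    def alpha_gamma(self, word, context):
        # Same context lookup as `LanguageModel.context_counts` below.
        counts = self.counts[len(context) + 1][context]
        total = counts.N() + self.k * len(self.vocab)
        # alpha: discounted higher-order estimate; gamma: mass reserved for
        # the lower-order estimate. Summed over the vocabulary, alpha + gamma
        # equals 1, so the interpolated score remains a probability.
        return counts[word] / total, self.k * len(self.vocab) / total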


def _mean(items):
    """Return average (aka mean) for sequence of items."""
    return sum(items) / len(items)


def _random_generator(seed_or_generator):
    """Return `seed_or_generator` unchanged if it is already a `random.Random`
    instance, otherwise treat it as a seed for a new one."""
    if isinstance(seed_or_generator, random.Random):
        return seed_or_generator
    return random.Random(seed_or_generator)


def _weighted_choice(population, weights, random_generator=None):
    """Like random.choice, but with weights.

    Heavily inspired by python 3.6 `random.choices`.
    """
    if not population:
        raise ValueError("Can't choose from empty population")
    if len(population) != len(weights):
        raise ValueError("The number of weights does not match the population")
    # Coerce the default None (or a bare seed) to a Random instance so the
    # `.random()` call below cannot fail with an AttributeError.
    random_generator = _random_generator(random_generator)
    cum_weights = list(accumulate(weights))
    total = cum_weights[-1]
    threshold = random_generator.random()
    return population[bisect(cum_weights, total * threshold)]
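
# Worked illustration of the choice above (hypothetical numbers): with weights
# (0.2, 0.3, 0.5) the cumulative weights are [0.2, 0.5, 1.0]. A uniform draw
# u = 0.6, scaled by the total 1.0, selects index bisect([0.2, 0.5, 1.0], 0.6)
# == 2, i.e. the third element; each element is thus chosen with probability
# proportional to its weight.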


class LanguageModel(metaclass=ABCMeta):
    """ABC for Language Models.

    Cannot be directly instantiated itself.
    """

    def __init__(self, order, vocabulary=None, counter=None):
        """Creates new LanguageModel.

        :param int order: Largest ngram order the model keeps track of.
        :param vocabulary: If provided, this vocabulary will be used instead
            of creating a new one when training.
        :type vocabulary: `nltk.lm.Vocabulary` or None
        :param counter: If provided, use this object to count ngrams.
        :type counter: `nltk.lm.NgramCounter` or None
        """
        self.order = order
        self.vocab = Vocabulary() if vocabulary is None else vocabulary
        self.counts = NgramCounter() if counter is None else counter

    def fit(self, text, vocabulary_text=None):
        """Trains the model on a text.

        :param text: Training text as a sequence of sentences, where each
            sentence is a sequence of ngram tuples.
        :param vocabulary_text: Text to build the vocabulary from; required
            if the model's vocabulary is still empty.
        """
        if not self.vocab:
            if vocabulary_text is None:
                raise ValueError(
                    "Cannot fit without a vocabulary or text to create it from."
                )
            self.vocab.update(vocabulary_text)
        self.counts.update(self.vocab.lookup(sent) for sent in text)
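
    # A typical `fit` call (illustrative sketch): both arguments are usually
    # built with `nltk.lm.preprocessing.padded_everygram_pipeline`, e.g.
    #
    #     from nltk.lm.preprocessing import padded_everygram_pipeline
    #     sents = [["a", "b", "c"], ["a", "c"]]
    #     train_ngrams, vocab_text = padded_everygram_pipeline(2, sents)
    #     model.fit(train_ngrams, vocab_text)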

    def score(self, word, context=None):
        """Masks out-of-vocabulary (OOV) words and computes their model score.

        For model-specific logic of calculating scores, see the `unmasked_score`
        method.
        """
        return self.unmasked_score(
            self.vocab.lookup(word), self.vocab.lookup(context) if context else None
        )

    @abstractmethod
    def unmasked_score(self, word, context=None):
        """Score a word given some optional context.

        Concrete models are expected to provide an implementation.
        Note that this method does not mask its arguments with the OOV label.
        Use the `score` method for that.

        :param str word: Word for which we want the score
        :param context: Context the word is in.
            If `None`, compute unigram score.
        :type context: tuple(str) or None
        :rtype: float
        """
        raise NotImplementedError()
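
    # For reference, the simplest concrete implementation (maximum likelihood,
    # as in `nltk.lm.MLE`) scores a word by its relative frequency in the
    # given context:
    #
    #     def unmasked_score(self, word, context=None):
    #         return self.context_counts(context).freq(word)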

    def logscore(self, word, context=None):
        """Evaluate the log score of this word in this context.

        The arguments are the same as for `score` and `unmasked_score`.
        """
        return log_base2(self.score(word, context))

    def context_counts(self, context):
        """Helper method for retrieving counts for a given context.

        Assumes context has been checked and OOV words in it masked.

        :type context: tuple(str) or None
        """
        return (
            self.counts[len(context) + 1][context] if context else self.counts.unigrams
        )

    def entropy(self, text_ngrams):
        """Calculate cross-entropy of model for given evaluation text.

        :param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
        :rtype: float
        """
        return -1 * _mean(
            [self.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
        )

    def perplexity(self, text_ngrams):
        """Calculates the perplexity of the given text.

        This is simply 2 ** cross-entropy for the text, so the arguments are the same.
        """
        return pow(2.0, self.entropy(text_ngrams))
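
    # Worked example (hypothetical numbers): if the model assigns probability
    # 0.25 to the final word of each of four evaluation ngrams, every logscore
    # is log2(0.25) == -2.0, the cross-entropy is 2.0 bits, and the perplexity
    # is 2 ** 2.0 == 4.0, i.e. the model is as uncertain as a uniform choice
    # among four words.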

    def generate(self, num_words=1, text_seed=None, random_seed=None):
        """Generate words from the model.

        :param int num_words: How many words to generate. By default 1.
        :param text_seed: Generation can be conditioned on preceding context.
        :param random_seed: A random seed or an instance of `random.Random`.
            If provided, makes the random sampling part of generation reproducible.
        :return: One (str) word or a list of words generated from the model.

        Examples:

        >>> from nltk.lm import MLE
        >>> lm = MLE(2)
        >>> lm.fit([[("a", "b"), ("b", "c")]], vocabulary_text=['a', 'b', 'c'])
        >>> lm.fit([[("a",), ("b",), ("c",)]])
        >>> lm.generate(random_seed=3)
        'a'
        >>> lm.generate(text_seed=['a'])
        'b'
        """
        text_seed = [] if text_seed is None else list(text_seed)
        random_generator = _random_generator(random_seed)
        # This is the base recursion case.
        if num_words == 1:
            context = (
                text_seed[-self.order + 1 :]
                if len(text_seed) >= self.order
                else text_seed
            )
            samples = self.context_counts(self.vocab.lookup(context))
            # Back off to shorter contexts until one of them has samples.
            while context and not samples:
                context = context[1:] if len(context) > 1 else []
                samples = self.context_counts(self.vocab.lookup(context))
            # Sorting samples achieves two things:
            # - reproducible randomness when sampling
            # - turns Mapping into Sequence which `_weighted_choice` expects
            samples = sorted(samples)
            return _weighted_choice(
                samples,
                tuple(self.score(w, context) for w in samples),
                random_generator,
            )
        # We build up text one word at a time using the preceding context.
        generated = []
        for _ in range(num_words):
            generated.append(
                self.generate(
                    num_words=1,
                    text_seed=text_seed + generated,
                    random_seed=random_generator,
                )
            )
        return generated
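
    # Usage sketch for multi-word generation (extends the doctest above; exact
    # outputs depend on the trained model and the seed):
    #
    #     lm.generate(num_words=4, text_seed=["a"], random_seed=3)
    #
    # Each word is sampled via the num_words == 1 branch with the previously
    # generated words appended to `text_seed`, so every sample is conditioned
    # on everything generated so far.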