vocabulary.py

# Natural Language Toolkit
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Vocabulary"""

import sys
from collections import Counter
from collections.abc import Iterable
from itertools import chain
from functools import singledispatch


@singledispatch
def _dispatched_lookup(words, vocab):
    raise TypeError(
        "Unsupported type for looking up in vocabulary: {0}".format(type(words))
    )


@_dispatched_lookup.register(Iterable)
def _(words, vocab):
    """Look up a sequence of words in the vocabulary.

    Returns a tuple of the looked up words.
    """
    return tuple(_dispatched_lookup(w, vocab) for w in words)


@_dispatched_lookup.register(str)
def _string_lookup(word, vocab):
    """Looks up one word in the vocabulary."""
    return word if word in vocab else vocab.unk_label
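

# A minimal sketch of how the singledispatch-based lookup above picks a handler
# based on the type of its first argument (assuming a Vocabulary built with the
# default cutoff of 1; the sample values are illustrative only):
#
#     >>> from nltk.lm import Vocabulary
#     >>> vocab = Vocabulary(["a", "b", "a"])
#     >>> _dispatched_lookup("a", vocab)            # str register
#     'a'
#     >>> _dispatched_lookup(["a", "z"], vocab)     # Iterable register, recurses
#     ('a', '<UNK>')
#     >>> _dispatched_lookup(3.14, vocab)           # falls back to the base case
#     Traceback (most recent call last):
#         ...
#     TypeError: Unsupported type for looking up in vocabulary: <class 'float'>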


class Vocabulary:
    """Stores language model vocabulary.

    Satisfies two common language modeling requirements for a vocabulary:

    - When checking membership and calculating its size, filters items
      by comparing their counts to a cutoff value.
    - Adds a special "unknown" token which unseen words are mapped to.

    >>> words = ['a', 'c', '-', 'd', 'c', 'a', 'b', 'r', 'a', 'c', 'd']
    >>> from nltk.lm import Vocabulary
    >>> vocab = Vocabulary(words, unk_cutoff=2)

    Tokens with counts greater than or equal to the cutoff value will
    be considered part of the vocabulary.

    >>> vocab['c']
    3
    >>> 'c' in vocab
    True
    >>> vocab['d']
    2
    >>> 'd' in vocab
    True

    Tokens with frequency counts less than the cutoff value will be considered not
    part of the vocabulary even though their entries in the count dictionary are
    preserved.

    >>> vocab['b']
    1
    >>> 'b' in vocab
    False
    >>> vocab['aliens']
    0
    >>> 'aliens' in vocab
    False

    Keeping the count entries for seen words allows us to change the cutoff value
    without having to recalculate the counts.

    >>> vocab2 = Vocabulary(vocab.counts, unk_cutoff=1)
    >>> "b" in vocab2
    True

    The cutoff value influences not only membership checking but also the result of
    getting the size of the vocabulary using the built-in `len`.
    Note that while the number of keys in the vocabulary's counter stays the same,
    the items in the vocabulary differ depending on the cutoff.
    We use `sorted` to demonstrate because it keeps the order consistent.

    >>> sorted(vocab2.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab2)
    ['-', '<UNK>', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab)
    ['<UNK>', 'a', 'c', 'd']

    In addition to items it gets populated with, the vocabulary stores a special
    token that stands in for so-called "unknown" items. By default it's "<UNK>".

    >>> "<UNK>" in vocab
    True

    We can look up words in a vocabulary using its `lookup` method.
    "Unseen" words (with counts less than cutoff) are looked up as the unknown label.
    If given one word (a string) as an input, this method will return a string.

    >>> vocab.lookup("a")
    'a'
    >>> vocab.lookup("aliens")
    '<UNK>'

    If given a sequence, it will return a tuple of the looked up words.

    >>> vocab.lookup(["p", 'a', 'r', 'd', 'b', 'c'])
    ('<UNK>', 'a', '<UNK>', 'd', '<UNK>', 'c')

    It's possible to update the counts after the vocabulary has been created.
    In general, the interface is the same as that of `collections.Counter`.

    >>> vocab['b']
    1
    >>> vocab.update(["b", "b", "c"])
    >>> vocab['b']
    3
    """

    def __init__(self, counts=None, unk_cutoff=1, unk_label="<UNK>"):
        """Create a new Vocabulary.

        :param counts: Optional iterable or `collections.Counter` instance to
            pre-seed the Vocabulary. In case it is iterable, counts
            are calculated.
        :param int unk_cutoff: Words that occur less frequently than this value
            are not considered part of the vocabulary.
        :param unk_label: Label for marking words not part of vocabulary.
        """
        if isinstance(counts, Counter):
            self.counts = counts
        else:
            self.counts = Counter()
            if isinstance(counts, Iterable):
                self.counts.update(counts)
        self.unk_label = unk_label
        if unk_cutoff < 1:
            raise ValueError(
                "Cutoff value cannot be less than 1. Got: {0}".format(unk_cutoff)
            )
        self._cutoff = unk_cutoff
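
    # A small illustrative sketch of the seeding behaviour in __init__ above:
    # a `collections.Counter` passed as `counts` is reused as-is, while any
    # other iterable is tallied into a fresh Counter. (The name `shared` is
    # just an example.)
    #
    #     >>> from collections import Counter
    #     >>> shared = Counter(["a", "a", "b"])
    #     >>> Vocabulary(shared).counts is shared
    #     True
    #     >>> Vocabulary(["a", "a", "b"]).counts == shared
    #     True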

    @property
    def cutoff(self):
        """Cutoff value.

        Items with count below this value are not considered part of vocabulary.
        """
        return self._cutoff

    def update(self, *counter_args, **counter_kwargs):
        """Update vocabulary counts.

        Wraps `collections.Counter.update` method.
        """
        self.counts.update(*counter_args, **counter_kwargs)

    def lookup(self, words):
        """Look up one or more words in the vocabulary.

        If passed one word as a string, will return that word or `self.unk_label`.
        Otherwise will assume it was passed a sequence of words, will try to look
        each of them up and return a tuple of the looked up words.

        :param words: Word(s) to look up.
        :type words: Iterable(str) or str
        :rtype: tuple(str) or str
        :raises: TypeError for types other than strings or iterables

        >>> from nltk.lm import Vocabulary
        >>> vocab = Vocabulary(["a", "b", "c", "a", "b"], unk_cutoff=2)
        >>> vocab.lookup("a")
        'a'
        >>> vocab.lookup("aliens")
        '<UNK>'
        >>> vocab.lookup(["a", "b", "c", ["x", "b"]])
        ('a', 'b', '<UNK>', ('<UNK>', 'b'))
        """
        return _dispatched_lookup(words, self)

    def __getitem__(self, item):
        return self._cutoff if item == self.unk_label else self.counts[item]

    def __contains__(self, item):
        """Only consider items with counts greater than or equal to the cutoff
        as being in the vocabulary."""
        return self[item] >= self.cutoff

    def __iter__(self):
        """Building on the membership check, define how to iterate over the
        vocabulary."""
        return chain(
            (item for item in self.counts if item in self),
            [self.unk_label] if self.counts else [],
        )

    def __len__(self):
        """Computing the size of the vocabulary reflects the cutoff."""
        return sum(1 for _ in self)

    def __eq__(self, other):
        return (
            self.unk_label == other.unk_label
            and self.cutoff == other.cutoff
            and self.counts == other.counts
        )

    def __str__(self):
        return "<{0} with cutoff={1} unk_label='{2}' and {3} items>".format(
            self.__class__.__name__, self.cutoff, self.unk_label, len(self)
        )
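

if __name__ == "__main__":
    # A minimal usage sketch exercising the Vocabulary class defined above:
    # build a vocabulary with a cutoff of 2, then check membership, lookup,
    # size, and count updates. The sample words are illustrative only.
    words = ["a", "c", "-", "d", "c", "a", "b", "r", "a", "c", "d"]
    vocab = Vocabulary(words, unk_cutoff=2)

    print(vocab)                          # <Vocabulary with cutoff=2 unk_label='<UNK>' and 4 items>
    print("c" in vocab, "b" in vocab)     # True False -- 'b' occurs only once
    print(vocab.lookup(["a", "b", "z"]))  # ('a', '<UNK>', '<UNK>')
    print(len(vocab), sorted(vocab))      # 4 ['<UNK>', 'a', 'c', 'd']

    # Counts can be updated in place; membership follows the new counts.
    vocab.update(["b", "b"])
    print("b" in vocab)                   # True -- count is now 3 >= cutoff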