pipeline.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from __future__ import unicode_literals
  2. import logging
  3. from builtins import str
  4. import six
  5. from lunr.exceptions import BaseLunrException
  6. from lunr.token import Token
  7. log = logging.getLogger(__name__)
  8. class Pipeline:
  9. """lunr.Pipelines maintain a list of functions to be applied to all tokens
  10. in documents entering the search index and queries ran agains the index.
  11. """
  12. registered_functions = {}
  13. def __init__(self):
  14. self._stack = []
  15. def __len__(self):
  16. return len(self._stack)
  17. def __repr__(self):
  18. return '<Pipeline stack="{}">'.format(",".join(fn.label for fn in self._stack))
  19. # TODO: add iterator methods?
  20. @classmethod
  21. def register_function(cls, fn, label):
  22. """Register a function with the pipeline."""
  23. if label in cls.registered_functions:
  24. log.warning("Overwriting existing registered function %s", label)
  25. fn.label = label
  26. cls.registered_functions[fn.label] = fn
  27. @classmethod
  28. def load(cls, serialised):
  29. """Loads a previously serialised pipeline."""
  30. pipeline = cls()
  31. for fn_name in serialised:
  32. try:
  33. fn = cls.registered_functions[fn_name]
  34. except KeyError:
  35. raise BaseLunrException(
  36. "Cannot load unregistered function ".format(fn_name)
  37. )
  38. else:
  39. pipeline.add(fn)
  40. return pipeline
  41. def add(self, *args):
  42. """Adds new functions to the end of the pipeline.
  43. Functions must accept three arguments:
  44. - Token: A lunr.Token object which will be updated
  45. - i: The index of the token in the set
  46. - tokens: A list of tokens representing the set
  47. """
  48. for fn in args:
  49. self.warn_if_function_not_registered(fn)
  50. self._stack.append(fn)
  51. def warn_if_function_not_registered(self, fn):
  52. try:
  53. return fn.label in self.registered_functions
  54. except AttributeError:
  55. log.warning(
  56. 'Function "{}" is not registered with pipeline. '
  57. "This may cause problems when serialising the index.".format(
  58. getattr(fn, "label", fn)
  59. )
  60. )
  61. def after(self, existing_fn, new_fn):
  62. """Adds a single function after a function that already exists in the
  63. pipeline."""
  64. self.warn_if_function_not_registered(new_fn)
  65. try:
  66. index = self._stack.index(existing_fn)
  67. self._stack.insert(index + 1, new_fn)
  68. except ValueError as e:
  69. six.raise_from(BaseLunrException("Cannot find existing_fn"), e)
  70. def before(self, existing_fn, new_fn):
  71. """Adds a single function before a function that already exists in the
  72. pipeline.
  73. """
  74. self.warn_if_function_not_registered(new_fn)
  75. try:
  76. index = self._stack.index(existing_fn)
  77. self._stack.insert(index, new_fn)
  78. except ValueError as e:
  79. six.raise_from(BaseLunrException("Cannot find existing_fn"), e)
  80. def remove(self, fn):
  81. """Removes a function from the pipeline."""
  82. try:
  83. self._stack.remove(fn)
  84. except ValueError:
  85. pass
  86. def run(self, tokens):
  87. """Runs the current list of functions that make up the pipeline against
  88. the passed tokens."""
  89. for fn in self._stack:
  90. results = []
  91. for i, token in enumerate(tokens):
  92. # JS ignores additional arguments to the functions but we
  93. # force pipeline functions to declare (token, i, tokens)
  94. # or *args
  95. result = fn(token, i, tokens)
  96. if not result:
  97. continue
  98. if isinstance(result, (list, tuple)): # simulate Array.concat
  99. results.extend(result)
  100. else:
  101. results.append(result)
  102. tokens = results
  103. return tokens
  104. def run_string(self, string, metadata=None):
  105. """Convenience method for passing a string through a pipeline and
  106. getting strings out. This method takes care of wrapping the passed
  107. string in a token and mapping the resulting tokens back to strings."""
  108. token = Token(string, metadata)
  109. return [str(tkn) for tkn in self.run([token])]
  110. def reset(self):
  111. self._stack = []
  112. def serialize(self):
  113. return [fn.label for fn in self._stack]