from __future__ import unicode_literals

from collections import defaultdict
import json
import logging

from builtins import str, dict  # noqa
from past.builtins import basestring

from lunr.exceptions import BaseLunrException
from lunr.field_ref import FieldRef
from lunr.match_data import MatchData
from lunr.token_set import TokenSet
from lunr.token_set_builder import TokenSetBuilder
from lunr.pipeline import Pipeline
from lunr.query import Query, QueryPresence
from lunr.query_parser import QueryParser
from lunr.utils import CompleteSet
from lunr.vector import Vector

logger = logging.getLogger(__name__)


class Index:
    """An index contains the built index of all documents and provides a query
    interface to the index.

    Usually instances of lunr.Index will not be created using this
    constructor, instead lunr.Builder should be used to construct new
    indexes, or lunr.Index.load should be used to load previously built and
    serialized indexes.
    """

    def __init__(self, inverted_index, field_vectors, token_set, fields, pipeline):
        self.inverted_index = inverted_index
        self.field_vectors = field_vectors
        self.token_set = token_set
        self.fields = fields
        self.pipeline = pipeline

    def __eq__(self, other):
        # TODO: extend equality to other attributes
        return (
            self.inverted_index == other.inverted_index and self.fields == other.fields
        )

    def search(self, query_string):
        """Performs a search against the index using lunr query syntax.

        Results are returned sorted by their score, with the most relevant
        results first.

        For more programmatic querying use `lunr.Index.query`.

        Args:
            query_string (str): A string to parse into a Query.

        Returns:
            list: Results of executing the query.
        """
        query = self.create_query()
        # TODO: should QueryParser be a method of query? should it return one?
        parser = QueryParser(query_string, query)
        parser.parse()
        return self.query(query)
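
    # Illustrative usage sketch (not part of the original module): assuming an
    # index built with lunr.Builder over documents with hypothetical "title"
    # and "body" fields, `search` parses the lunr query syntax (including
    # `+`/`-` presence markers and `*` wildcards) and delegates to `query`:
    #
    #     results = idx.search("+plant green*")
    #     # -> [{"ref": "a", "score": 0.53, "match_data": MatchData(...)}, ...]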

    def create_query(self, fields=None):
        """Convenience method to create a Query with the Index's fields.

        Args:
            fields (iterable, optional): The fields to include in the Query,
                defaults to all of the Index's fields.

        Returns:
            Query: With the specified fields or all the fields in the Index.
        """
        if fields is None:
            return Query(self.fields)

        non_contained_fields = set(fields) - set(self.fields)
        if non_contained_fields:
            raise BaseLunrException(
                "Fields {} are not part of the index", non_contained_fields
            )

        return Query(fields)
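
    # Illustrative sketch (hypothetical field name): `create_query` scopes a
    # Query to a subset of the index's fields; clauses can then be added
    # before passing it to `query`:
    #
    #     q = idx.create_query(["title"])
    #     q.term("plant")
    #     results = idx.query(q)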

    def query(self, query=None, callback=None):
        """Performs a query against the index using the passed lunr.Query
        object.

        If performing programmatic queries against the index, this method is
        preferred over `lunr.Index.search` so as to avoid the additional query
        parsing overhead.

        Args:
            query (lunr.Query): A configured Query to perform the search
                against; use `create_query` to get a preconfigured object
                or use `callback` for convenience.
            callback (callable): An optional function taking a single Query
                object, the result of `create_query`, for further configuration.

        Returns:
            list: The matching documents, each a dict with `ref`, `score` and
                `match_data` keys, sorted by descending score.
        """
        if query is None:
            query = self.create_query()

        if callback is not None:
            callback(query)

        if len(query.clauses) == 0:
            logger.warning(
                "Attempting a query with no clauses. Please add clauses by "
                "either using the `callback` argument or using `create_query` "
                "to create a preconfigured Query, manually adding clauses and "
                "passing it as the `query` argument."
            )
            return []

        # For each query clause:
        # * process terms
        # * expand terms from token set
        # * find matching documents and metadata
        # * get document vectors
        # * score documents

        matching_fields = {}
        query_vectors = {field: Vector() for field in self.fields}
        term_field_cache = {}
        required_matches = {}
        prohibited_matches = defaultdict(set)

        for clause in query.clauses:
            # Unless the pipeline has been disabled for this term, which is
            # the case for terms with wildcards, we need to pass the clause
            # term through the search pipeline. A pipeline returns an array
            # of processed terms. Pipeline functions may expand the passed
            # term, which means we may end up performing multiple index lookups
            # for a single query term.
            if clause.use_pipeline:
                terms = self.pipeline.run_string(clause.term, {"fields": clause.fields})
            else:
                terms = [clause.term]

            clause_matches = CompleteSet()

            for term in terms:
                # Each term returned from the pipeline needs to use the same
                # query clause object, e.g. the same boost and/or edit distance.
                # The simplest way to do this is to reuse the clause object
                # but mutate its term property.
                clause.term = term

                # From the term in the clause we create a token set, which will
                # then be used to intersect the index's token set to get a list
                # of terms to look up in the inverted index.
                term_token_set = TokenSet.from_clause(clause)
                expanded_terms = self.token_set.intersect(term_token_set).to_list()

                # If a term marked as required does not exist in the TokenSet
                # it is impossible for the search to return any matches.
                # We set all the field-scoped required match sets to empty
                # and stop examining further clauses.
                if (
                    len(expanded_terms) == 0
                    and clause.presence == QueryPresence.REQUIRED
                ):
                    for field in clause.fields:
                        required_matches[field] = CompleteSet()
                    break

                for expanded_term in expanded_terms:
                    posting = self.inverted_index[expanded_term]
                    term_index = posting["_index"]

                    for field in clause.fields:
                        # For each field that this query term is scoped by
                        # (by default all fields are in scope) we need to get
                        # all the document refs that have this term in that
                        # field.
                        #
                        # The posting is the entry in the inverted index for
                        # the matching term from above.
                        field_posting = posting[field]
                        matching_document_refs = field_posting.keys()
                        term_field = expanded_term + "/" + field
                        matching_documents_set = set(matching_document_refs)

                        # If the presence of this term is required, ensure that
                        # the matching documents are added to the set of
                        # required matches for this clause.
                        if clause.presence == QueryPresence.REQUIRED:
                            clause_matches = clause_matches.union(
                                matching_documents_set
                            )

                            if field not in required_matches:
                                required_matches[field] = CompleteSet()

                        # If the presence of this term is prohibited,
                        # ensure that the matching documents are added to the
                        # set of prohibited matches for this field, creating
                        # that set if it does not exist yet.
                        elif clause.presence == QueryPresence.PROHIBITED:
                            prohibited_matches[field] = prohibited_matches[field].union(
                                matching_documents_set
                            )

                            # Prohibited matches should not be part of the
                            # query vector used for similarity scoring and no
                            # metadata should be extracted, so we continue
                            # to the next field.
                            continue

                        # The query field vector is populated using the
                        # term_index found for the term and a unit value with
                        # the appropriate boost.
                        #
                        # Using upsert because there could already be an entry
                        # in the vector for the term we are working with.
                        # In that case we just add the scores together.
                        query_vectors[field].upsert(
                            term_index, clause.boost, lambda a, b: a + b
                        )

                        # If we've already seen this term/field combo then
                        # we've already collected the matching documents and
                        # metadata; no need to go through all that again.
                        if term_field in term_field_cache:
                            continue

                        for matching_document_ref in matching_document_refs:
                            # All metadata for this term/field/document triple
                            # are extracted and collected into an instance
                            # of lunr.MatchData, ready to be returned in the
                            # query results.
                            matching_field_ref = FieldRef(matching_document_ref, field)
                            metadata = field_posting[str(matching_document_ref)]

                            if str(matching_field_ref) not in matching_fields:
                                matching_fields[str(matching_field_ref)] = MatchData(
                                    expanded_term, field, metadata
                                )
                            else:
                                matching_fields[str(matching_field_ref)].add(
                                    expanded_term, field, metadata
                                )

                        term_field_cache[term_field] = True

            # If the presence was required, we need to update the required
            # match field sets. We do this after all fields for the term
            # have collected their matches, because the clause term's presence
            # is required in _any_ of the fields, not _all_ of the fields.
            if clause.presence == QueryPresence.REQUIRED:
                for field in clause.fields:
                    required_matches[field] = required_matches[field].intersection(
                        clause_matches
                    )

        # We need to combine the field-scoped required and prohibited
        # matching documents into a global set of required and prohibited
        # matches.
        all_required_matches = CompleteSet()
        all_prohibited_matches = set()

        for field in self.fields:
            if field in required_matches:
                all_required_matches = all_required_matches.intersection(
                    required_matches[field]
                )
            if field in prohibited_matches:
                all_prohibited_matches = all_prohibited_matches.union(
                    prohibited_matches[field]
                )

        matching_field_refs = matching_fields.keys()
        results = []
        matches = {}

        # If the query is negated (contains only prohibited terms) we need to
        # get _all_ field refs currently existing in the index. This is only
        # done for negated queries, to avoid the cost of collecting every
        # field ref unnecessarily. Additionally, blank MatchData must be
        # created to correctly populate the results.
        if query.is_negated():
            matching_field_refs = list(self.field_vectors.keys())

            for matching_field_ref in matching_field_refs:
                field_ref = FieldRef.from_string(matching_field_ref)
                matching_fields[matching_field_ref] = MatchData()

        for matching_field_ref in matching_field_refs:
            # Currently we have document fields that match the query, but we
            # need to return documents. The match data and scores are combined
            # from multiple fields belonging to the same document.
            #
            # Scores are calculated by field, using the query vectors created
            # above, and combined into a final document score using addition.
            field_ref = FieldRef.from_string(matching_field_ref)
            doc_ref = field_ref.doc_ref

            if doc_ref not in all_required_matches or doc_ref in all_prohibited_matches:
                continue

            field_vector = self.field_vectors[matching_field_ref]
            score = query_vectors[field_ref.field_name].similarity(field_vector)

            try:
                doc_match = matches[doc_ref]
                doc_match["score"] += score
                doc_match["match_data"].combine(matching_fields[matching_field_ref])
            except KeyError:
                match = {
                    "ref": doc_ref,
                    "score": score,
                    "match_data": matching_fields[matching_field_ref],
                }
                matches[doc_ref] = match
                results.append(match)

        return sorted(results, key=lambda a: a["score"], reverse=True)
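
    # Programmatic querying sketch (hypothetical field names): the `callback`
    # argument receives the preconfigured Query so clauses can be added
    # without calling `create_query` explicitly:
    #
    #     def configure(q):
    #         q.term("plant", fields=["title"], boost=10)
    #         q.term("tree")
    #
    #     results = idx.query(callback=configure)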

    def serialize(self):
        from lunr import __TARGET_JS_VERSION__

        inverted_index = [
            [term, self.inverted_index[term]] for term in sorted(self.inverted_index)
        ]
        field_vectors = [
            [ref, vector.serialize()] for ref, vector in self.field_vectors.items()
        ]

        # CamelCased keys for compatibility with JS version
        return {
            "version": __TARGET_JS_VERSION__,
            "fields": self.fields,
            "fieldVectors": field_vectors,
            "invertedIndex": inverted_index,
            "pipeline": self.pipeline.serialize(),
        }
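
    # Serialization sketch: the returned dict is JSON-serializable and uses
    # the same camelCased keys as lunr.js, so the output can be consumed by
    # `Index.load` below or by the JavaScript library:
    #
    #     serialized = json.dumps(idx.serialize())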

    @classmethod
    def load(cls, serialized_index):
        """Load a serialized index"""
        from lunr import __TARGET_JS_VERSION__

        if isinstance(serialized_index, basestring):
            serialized_index = json.loads(serialized_index)

        if serialized_index["version"] != __TARGET_JS_VERSION__:
            logger.warning(
                "Version mismatch when loading serialized index. "
                "Current version of lunr {} does not match that of serialized "
                "index {}".format(__TARGET_JS_VERSION__, serialized_index["version"])
            )

        field_vectors = {
            ref: Vector(elements) for ref, elements in serialized_index["fieldVectors"]
        }

        tokenset_builder = TokenSetBuilder()
        inverted_index = {}

        for term, posting in serialized_index["invertedIndex"]:
            tokenset_builder.insert(term)
            inverted_index[term] = posting

        tokenset_builder.finish()

        return Index(
            fields=serialized_index["fields"],
            field_vectors=field_vectors,
            inverted_index=inverted_index,
            token_set=tokenset_builder.root,
            pipeline=Pipeline.load(serialized_index["pipeline"]),
        )
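

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module. The
    # documents, field names and refs below are hypothetical and only
    # illustrate the public API: build an index with the `lunr` convenience
    # function, search it, then round-trip it through serialize/load.
    from lunr import lunr

    documents = [
        {"id": "a", "title": "Green plants", "body": "Plants are green."},
        {"id": "b", "title": "Blue skies", "body": "The sky is blue."},
    ]
    idx = lunr("id", ["title", "body"], documents)

    for result in idx.search("green"):
        print(result["ref"], result["score"])

    serialized = json.dumps(idx.serialize())
    restored = Index.load(serialized)
    assert [r["ref"] for r in restored.search("green")] == [
        r["ref"] for r in idx.search("green")
    ]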