vector.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import unicode_literals, division
  2. from math import sqrt
  3. from lunr.exceptions import BaseLunrException
  4. class Vector:
  5. """A vector is used to construct the vector space of documents and queries.
  6. These vectors support operations to determine the similarity between two
  7. documents or a document and a query.
  8. Normally no parameters are required for initializing a vector, but in the
  9. case of loading a previously dumped vector the raw elements can be provided
  10. to the constructor.
  11. For performance reasons vectors are implemented with a flat array, where an
  12. elements index is immediately followed by its value.
  13. E.g. [index, value, index, value].
  14. TODO: consider implemetation as 2-tuples.
  15. This allows the underlying array to be as sparse as possible and still
  16. offer decent performance when being used for vector calculations.
  17. """
  18. def __init__(self, elements=None):
  19. self._magnitude = 0
  20. self.elements = elements or []
  21. def __repr__(self):
  22. return "<Vector magnitude={}>".format(self.magnitude)
  23. def __iter__(self):
  24. return iter(self.elements)
  25. def position_for_index(self, index):
  26. """Calculates the position within the vector to insert a given index.
  27. This is used internally by insert and upsert. If there are duplicate
  28. indexes then the position is returned as if the value for that index
  29. were to be updated, but it is the callers responsibility to check
  30. whether there is a duplicate at that index
  31. """
  32. if not self.elements:
  33. return 0
  34. start = 0
  35. end = int(len(self.elements) / 2)
  36. slice_length = end - start
  37. pivot_point = int(slice_length / 2)
  38. pivot_index = self.elements[pivot_point * 2]
  39. while slice_length > 1:
  40. if pivot_index < index:
  41. start = pivot_point
  42. elif pivot_index > index:
  43. end = pivot_point
  44. else:
  45. break
  46. slice_length = end - start
  47. pivot_point = start + int(slice_length / 2)
  48. pivot_index = self.elements[pivot_point * 2]
  49. if pivot_index == index:
  50. return pivot_point * 2
  51. elif pivot_index > index:
  52. return pivot_point * 2
  53. else:
  54. return (pivot_point + 1) * 2
  55. def insert(self, insert_index, val):
  56. """Inserts an element at an index within the vector.
  57. Does not allow duplicates, will throw an error if there is already an
  58. entry for this index.
  59. """
  60. def prevent_duplicates(index, val):
  61. raise BaseLunrException("Duplicate index")
  62. self.upsert(insert_index, val, prevent_duplicates)
  63. def upsert(self, insert_index, val, fn=None):
  64. """Inserts or updates an existing index within the vector.
  65. Args:
  66. - insert_index (int): The index at which the element should be
  67. inserted.
  68. - val (int|float): The value to be inserted into the vector.
  69. - fn (callable, optional): An optional callable taking two
  70. arguments, the current value and the passed value to generate
  71. the final inserted value at the position in case of collision.
  72. """
  73. fn = fn or (lambda current, passed: passed)
  74. self._magnitude = 0
  75. position = self.position_for_index(insert_index)
  76. if position < len(self.elements) and self.elements[position] == insert_index:
  77. self.elements[position + 1] = fn(self.elements[position + 1], val)
  78. else:
  79. self.elements.insert(position, val)
  80. self.elements.insert(position, insert_index)
  81. def to_list(self):
  82. """Converts the vector to an array of the elements within the vector"""
  83. output = []
  84. for i in range(1, len(self.elements), 2):
  85. output.append(self.elements[i])
  86. return output
  87. def serialize(self):
  88. # TODO: the JS version forces rounding on the elements upon insertion
  89. # to ensure symmetry upon serialization
  90. return [round(element, 3) for element in self.elements]
  91. @property
  92. def magnitude(self):
  93. if not self._magnitude:
  94. sum_of_squares = 0
  95. for i in range(1, len(self.elements), 2):
  96. value = self.elements[i]
  97. sum_of_squares += value * value
  98. self._magnitude = sqrt(sum_of_squares)
  99. return self._magnitude
  100. def dot(self, other):
  101. """Calculates the dot product of this vector and another vector."""
  102. dot_product = 0
  103. a = self.elements
  104. b = other.elements
  105. a_len = len(a)
  106. b_len = len(b)
  107. i = j = 0
  108. while i < a_len and j < b_len:
  109. a_val = a[i]
  110. b_val = b[j]
  111. if a_val < b_val:
  112. i += 2
  113. elif a_val > b_val:
  114. j += 2
  115. else:
  116. dot_product += a[i + 1] * b[j + 1]
  117. i += 2
  118. j += 2
  119. return dot_product
  120. def similarity(self, other):
  121. """Calculates the cosine similarity between this vector and another
  122. vector."""
  123. if self.magnitude == 0 or other.magnitude == 0:
  124. return 0
  125. return self.dot(other) / self.magnitude