from __future__ import unicode_literals, division from math import sqrt from lunr.exceptions import BaseLunrException class Vector: """A vector is used to construct the vector space of documents and queries. These vectors support operations to determine the similarity between two documents or a document and a query. Normally no parameters are required for initializing a vector, but in the case of loading a previously dumped vector the raw elements can be provided to the constructor. For performance reasons vectors are implemented with a flat array, where an elements index is immediately followed by its value. E.g. [index, value, index, value]. TODO: consider implemetation as 2-tuples. This allows the underlying array to be as sparse as possible and still offer decent performance when being used for vector calculations. """ def __init__(self, elements=None): self._magnitude = 0 self.elements = elements or [] def __repr__(self): return "".format(self.magnitude) def __iter__(self): return iter(self.elements) def position_for_index(self, index): """Calculates the position within the vector to insert a given index. This is used internally by insert and upsert. If there are duplicate indexes then the position is returned as if the value for that index were to be updated, but it is the callers responsibility to check whether there is a duplicate at that index """ if not self.elements: return 0 start = 0 end = int(len(self.elements) / 2) slice_length = end - start pivot_point = int(slice_length / 2) pivot_index = self.elements[pivot_point * 2] while slice_length > 1: if pivot_index < index: start = pivot_point elif pivot_index > index: end = pivot_point else: break slice_length = end - start pivot_point = start + int(slice_length / 2) pivot_index = self.elements[pivot_point * 2] if pivot_index == index: return pivot_point * 2 elif pivot_index > index: return pivot_point * 2 else: return (pivot_point + 1) * 2 def insert(self, insert_index, val): """Inserts an element at an index within the vector. Does not allow duplicates, will throw an error if there is already an entry for this index. """ def prevent_duplicates(index, val): raise BaseLunrException("Duplicate index") self.upsert(insert_index, val, prevent_duplicates) def upsert(self, insert_index, val, fn=None): """Inserts or updates an existing index within the vector. Args: - insert_index (int): The index at which the element should be inserted. - val (int|float): The value to be inserted into the vector. - fn (callable, optional): An optional callable taking two arguments, the current value and the passed value to generate the final inserted value at the position in case of collision. """ fn = fn or (lambda current, passed: passed) self._magnitude = 0 position = self.position_for_index(insert_index) if position < len(self.elements) and self.elements[position] == insert_index: self.elements[position + 1] = fn(self.elements[position + 1], val) else: self.elements.insert(position, val) self.elements.insert(position, insert_index) def to_list(self): """Converts the vector to an array of the elements within the vector""" output = [] for i in range(1, len(self.elements), 2): output.append(self.elements[i]) return output def serialize(self): # TODO: the JS version forces rounding on the elements upon insertion # to ensure symmetry upon serialization return [round(element, 3) for element in self.elements] @property def magnitude(self): if not self._magnitude: sum_of_squares = 0 for i in range(1, len(self.elements), 2): value = self.elements[i] sum_of_squares += value * value self._magnitude = sqrt(sum_of_squares) return self._magnitude def dot(self, other): """Calculates the dot product of this vector and another vector.""" dot_product = 0 a = self.elements b = other.elements a_len = len(a) b_len = len(b) i = j = 0 while i < a_len and j < b_len: a_val = a[i] b_val = b[j] if a_val < b_val: i += 2 elif a_val > b_val: j += 2 else: dot_product += a[i + 1] * b[j + 1] i += 2 j += 2 return dot_product def similarity(self, other): """Calculates the cosine similarity between this vector and another vector.""" if self.magnitude == 0 or other.magnitude == 0: return 0 return self.dot(other) / self.magnitude