api.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Natural Language Toolkit: Clusterer Interfaces
  2. #
  3. # Copyright (C) 2001-2020 NLTK Project
  4. # Author: Trevor Cohn <tacohn@cs.mu.oz.au>
  5. # Porting: Steven Bird <stevenbird1@gmail.com>
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. from abc import ABCMeta, abstractmethod
  9. from nltk.probability import DictionaryProbDist
  10. class ClusterI(metaclass=ABCMeta):
  11. """
  12. Interface covering basic clustering functionality.
  13. """
  14. @abstractmethod
  15. def cluster(self, vectors, assign_clusters=False):
  16. """
  17. Assigns the vectors to clusters, learning the clustering parameters
  18. from the data. Returns a cluster identifier for each vector.
  19. """
  20. @abstractmethod
  21. def classify(self, token):
  22. """
  23. Classifies the token into a cluster, setting the token's CLUSTER
  24. parameter to that cluster identifier.
  25. """
  26. def likelihood(self, vector, label):
  27. """
  28. Returns the likelihood (a float) of the token having the
  29. corresponding cluster.
  30. """
  31. if self.classify(vector) == label:
  32. return 1.0
  33. else:
  34. return 0.0
  35. def classification_probdist(self, vector):
  36. """
  37. Classifies the token into a cluster, returning
  38. a probability distribution over the cluster identifiers.
  39. """
  40. likelihoods = {}
  41. sum = 0.0
  42. for cluster in self.cluster_names():
  43. likelihoods[cluster] = self.likelihood(vector, cluster)
  44. sum += likelihoods[cluster]
  45. for cluster in self.cluster_names():
  46. likelihoods[cluster] /= sum
  47. return DictionaryProbDist(likelihoods)
  48. @abstractmethod
  49. def num_clusters(self):
  50. """
  51. Returns the number of clusters.
  52. """
  53. def cluster_names(self):
  54. """
  55. Returns the names of the clusters.
  56. :rtype: list
  57. """
  58. return list(range(self.num_clusters()))
  59. def cluster_name(self, index):
  60. """
  61. Returns the names of the cluster at index.
  62. """
  63. return index