# Natural Language Toolkit: Decision Tree Classifiers
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

"""
A classifier model that decides which label to assign to a token on
the basis of a tree structure, where branches correspond to conditions
on feature values, and leaves correspond to label assignments.
"""
from collections import defaultdict

from nltk.probability import FreqDist, MLEProbDist, entropy
from nltk.classify.api import ClassifierI

class DecisionTreeClassifier(ClassifierI):
    def __init__(self, label, feature_name=None, decisions=None, default=None):
        """
        :param label: The most likely label for tokens that reach
            this node in the decision tree.  If this decision tree
            has no children, then this label will be assigned to
            any token that reaches this decision tree.
        :param feature_name: The name of the feature that this
            decision tree selects for.
        :param decisions: A dictionary mapping from feature values
            for the feature identified by ``feature_name`` to
            child decision trees.
        :param default: The child that will be used if the value of
            feature ``feature_name`` does not match any of the keys in
            ``decisions``.  This is used when constructing binary
            decision trees.
        """
        self._label = label
        self._fname = feature_name
        self._decisions = decisions
        self._default = default
    def labels(self):
        labels = [self._label]
        if self._decisions is not None:
            for dt in self._decisions.values():
                labels.extend(dt.labels())
        if self._default is not None:
            labels.extend(self._default.labels())
        return list(set(labels))
    def classify(self, featureset):
        # Decision leaf:
        if self._fname is None:
            return self._label

        # Decision tree:
        fval = featureset.get(self._fname)
        if fval in self._decisions:
            return self._decisions[fval].classify(featureset)
        elif self._default is not None:
            return self._default.classify(featureset)
        else:
            return self._label
    def error(self, labeled_featuresets):
        # Return the fraction of ``labeled_featuresets`` that this
        # classifier labels incorrectly.
        errors = 0
        for featureset, label in labeled_featuresets:
            if self.classify(featureset) != label:
                errors += 1
        return errors / len(labeled_featuresets)
    def pretty_format(self, width=70, prefix="", depth=4):
        """
        Return a string containing a pretty-printed version of this
        decision tree.  Each line in this string corresponds to a
        single decision tree node or leaf, and indentation is used to
        display the structure of the decision tree.
        """
        # [xx] display default!!
        if self._fname is None:
            n = width - len(prefix) - 15
            return "{0}{1} {2}\n".format(prefix, "." * n, self._label)
        s = ""
        for fval, result in sorted(
            self._decisions.items(),
            key=lambda item: (item[0] in [None, False, True], str(item[0]).lower()),
        ):
            hdr = "{0}{1}={2}? ".format(prefix, self._fname, fval)
            n = width - 15 - len(hdr)
            s += "{0}{1} {2}\n".format(hdr, "." * n, result._label)
            if result._fname is not None and depth > 1:
                s += result.pretty_format(width, prefix + "  ", depth - 1)
        if self._default is not None:
            n = width - len(prefix) - 21
            s += "{0}else: {1} {2}\n".format(prefix, "." * n, self._default._label)
            if self._default._fname is not None and depth > 1:
                s += self._default.pretty_format(width, prefix + "  ", depth - 1)
        return s
    def pseudocode(self, prefix="", depth=4):
        """
        Return a string representation of this decision tree that
        expresses the decisions it makes as a nested set of pseudocode
        if statements.
        """
        if self._fname is None:
            return "{0}return {1!r}\n".format(prefix, self._label)
        s = ""
        for fval, result in sorted(
            self._decisions.items(),
            key=lambda item: (item[0] in [None, False, True], str(item[0]).lower()),
        ):
            s += "{0}if {1} == {2!r}: ".format(prefix, self._fname, fval)
            if result._fname is not None and depth > 1:
                s += "\n" + result.pseudocode(prefix + "  ", depth - 1)
            else:
                s += "return {0!r}\n".format(result._label)
        if self._default is not None:
            if len(self._decisions) == 1:
                s += "{0}if {1} != {2!r}: ".format(
                    prefix, self._fname, list(self._decisions.keys())[0]
                )
            else:
                s += "{0}else: ".format(prefix)
            if self._default._fname is not None and depth > 1:
                s += "\n" + self._default.pseudocode(prefix + "  ", depth - 1)
            else:
                s += "return {0!r}\n".format(self._default._label)
        return s
    def __str__(self):
        return self.pretty_format()
    @staticmethod
    def train(
        labeled_featuresets,
        entropy_cutoff=0.05,
        depth_cutoff=100,
        support_cutoff=10,
        binary=False,
        feature_values=None,
        verbose=False,
    ):
        """
        :param binary: If true, then treat all feature/value pairs as
            individual binary features, rather than using a single n-way
            branch for each feature.
        :param entropy_cutoff: A node is only refined while the entropy
            of its label distribution exceeds this value.
        :param depth_cutoff: The maximum depth to which the tree may be
            refined.
        :param support_cutoff: A node reached by this many (or fewer)
            training instances is not refined further.
        """
        # Collect a list of all feature names.
        feature_names = set()
        for featureset, label in labeled_featuresets:
            for fname in featureset:
                feature_names.add(fname)

        # Collect a list of the values each feature can take.
        if feature_values is None and binary:
            feature_values = defaultdict(set)
            for featureset, label in labeled_featuresets:
                for fname, fval in featureset.items():
                    feature_values[fname].add(fval)

        # Start with a stump.
        if not binary:
            tree = DecisionTreeClassifier.best_stump(
                feature_names, labeled_featuresets, verbose
            )
        else:
            tree = DecisionTreeClassifier.best_binary_stump(
                feature_names, labeled_featuresets, feature_values, verbose
            )

        # Refine the stump.
        tree.refine(
            labeled_featuresets,
            entropy_cutoff,
            depth_cutoff - 1,
            support_cutoff,
            binary,
            feature_values,
            verbose,
        )

        # Return it.
        return tree
    @staticmethod
    def leaf(labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()
        return DecisionTreeClassifier(label)

    @staticmethod
    def stump(feature_name, labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()

        # Find the best label for each value.  (The loop variable is kept
        # distinct from ``label`` so the majority label computed above is
        # not clobbered by the last training example.)
        freqs = defaultdict(FreqDist)  # freq(label|value)
        for featureset, example_label in labeled_featuresets:
            feature_value = featureset.get(feature_name)
            freqs[feature_value][example_label] += 1

        decisions = dict(
            (val, DecisionTreeClassifier(freqs[val].max())) for val in freqs
        )
        return DecisionTreeClassifier(label, feature_name, decisions)
    def refine(
        self,
        labeled_featuresets,
        entropy_cutoff,
        depth_cutoff,
        support_cutoff,
        binary=False,
        feature_values=None,
        verbose=False,
    ):
        if len(labeled_featuresets) <= support_cutoff:
            return
        if self._fname is None:
            return
        if depth_cutoff <= 0:
            return
        for fval in self._decisions:
            fval_featuresets = [
                (featureset, label)
                for (featureset, label) in labeled_featuresets
                if featureset.get(self._fname) == fval
            ]
            label_freqs = FreqDist(label for (featureset, label) in fval_featuresets)
            if entropy(MLEProbDist(label_freqs)) > entropy_cutoff:
                self._decisions[fval] = DecisionTreeClassifier.train(
                    fval_featuresets,
                    entropy_cutoff,
                    depth_cutoff,
                    support_cutoff,
                    binary,
                    feature_values,
                    verbose,
                )
        if self._default is not None:
            default_featuresets = [
                (featureset, label)
                for (featureset, label) in labeled_featuresets
                if featureset.get(self._fname) not in self._decisions
            ]
            label_freqs = FreqDist(label for (featureset, label) in default_featuresets)
            if entropy(MLEProbDist(label_freqs)) > entropy_cutoff:
                self._default = DecisionTreeClassifier.train(
                    default_featuresets,
                    entropy_cutoff,
                    depth_cutoff,
                    support_cutoff,
                    binary,
                    feature_values,
                    verbose,
                )
    @staticmethod
    def best_stump(feature_names, labeled_featuresets, verbose=False):
        best_stump = DecisionTreeClassifier.leaf(labeled_featuresets)
        best_error = best_stump.error(labeled_featuresets)
        for fname in feature_names:
            stump = DecisionTreeClassifier.stump(fname, labeled_featuresets)
            stump_error = stump.error(labeled_featuresets)
            if stump_error < best_error:
                best_error = stump_error
                best_stump = stump
        if verbose:
            # ``_fname`` is None when the plain leaf beat every stump;
            # str() keeps the format call from failing in that case.
            print(
                "best stump for {:6d} toks uses {:20} err={:6.4f}".format(
                    len(labeled_featuresets), str(best_stump._fname), best_error
                )
            )
        return best_stump
    @staticmethod
    def binary_stump(feature_name, feature_value, labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()

        # Find the best label for each value.  (Again, keep the loop
        # variable distinct from ``label``.)
        pos_fdist = FreqDist()
        neg_fdist = FreqDist()
        for featureset, example_label in labeled_featuresets:
            if featureset.get(feature_name) == feature_value:
                pos_fdist[example_label] += 1
            else:
                neg_fdist[example_label] += 1

        decisions = {}
        # Fall back to a leaf predicting the overall majority label, so
        # that ``default`` is always a classifier, never a bare label.
        default = DecisionTreeClassifier(label)
        # But hopefully we have observations!
        if pos_fdist.N() > 0:
            decisions = {feature_value: DecisionTreeClassifier(pos_fdist.max())}
        if neg_fdist.N() > 0:
            default = DecisionTreeClassifier(neg_fdist.max())

        return DecisionTreeClassifier(label, feature_name, decisions, default)
    @staticmethod
    def best_binary_stump(
        feature_names, labeled_featuresets, feature_values, verbose=False
    ):
        best_stump = DecisionTreeClassifier.leaf(labeled_featuresets)
        best_error = best_stump.error(labeled_featuresets)
        for fname in feature_names:
            for fval in feature_values[fname]:
                stump = DecisionTreeClassifier.binary_stump(
                    fname, fval, labeled_featuresets
                )
                stump_error = stump.error(labeled_featuresets)
                if stump_error < best_error:
                    best_error = stump_error
                    best_stump = stump
        if verbose:
            if best_stump._decisions:
                descr = "{0}={1}".format(
                    best_stump._fname, list(best_stump._decisions.keys())[0]
                )
            else:
                descr = "(default)"
            print(
                "best stump for {:6d} toks uses {:20} err={:6.4f}".format(
                    len(labeled_featuresets), descr, best_error
                )
            )
        return best_stump
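

# The function below is an illustrative sketch, not part of the original
# NLTK module: it hand-builds a one-split tree to show how ``decisions``
# and ``default`` interact in ``classify`` and ``pseudocode``.  The feature
# name 'ends-in-vowel' and the labels are invented for the example.
def _toy_tree_example():
    leaf_female = DecisionTreeClassifier("female")
    leaf_male = DecisionTreeClassifier("male")
    # Branch on the hypothetical feature; any value other than True
    # falls through to the ``default`` child.
    tree = DecisionTreeClassifier(
        "male", "ends-in-vowel", {True: leaf_female}, leaf_male
    )
    assert tree.classify({"ends-in-vowel": True}) == "female"
    assert tree.classify({"ends-in-vowel": False}) == "male"  # via default
    print(tree.pseudocode())
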
##//////////////////////////////////////////////////////
##  Demo
##//////////////////////////////////////////////////////


def f(x):
    # Wrapper handed to ``names_demo`` so that the classifier is trained
    # in binary mode with verbose output.
    return DecisionTreeClassifier.train(x, binary=True, verbose=True)


def demo():
    from nltk.classify.util import names_demo, binary_names_demo_features

    classifier = names_demo(
        f, binary_names_demo_features  # DecisionTreeClassifier.train,
    )
    print(classifier.pretty_format(depth=7))
    print(classifier.pseudocode(depth=7))
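

# A minimal training sketch on hand-made data (assumed data, not part of
# the original module).  On five examples the default ``support_cutoff``
# stops ``refine`` immediately, so the result is simply the best
# single-feature stump found by ``best_stump``.
def _toy_train_example():
    data = [
        ({"last-letter": "a"}, "female"),
        ({"last-letter": "a"}, "female"),
        ({"last-letter": "k"}, "male"),
        ({"last-letter": "k"}, "male"),
        ({"last-letter": "o"}, "male"),
    ]
    tree = DecisionTreeClassifier.train(data)
    print(tree.pretty_format())
    # Expected: 'female', since 'a' endings were always female above.
    print(tree.classify({"last-letter": "a"}))
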

if __name__ == "__main__":
    demo()