# -*- coding: utf-8 -*-
# Natural Language Toolkit: Transformation-based learning
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Marcus Uneson <marcus.uneson@gmail.com>
#   based on previous (nltk2) version by
#   Christopher Maloof, Edward Loper, Steven Bird
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

import os
import pickle
import random
import time

from nltk.corpus import treebank
from nltk.tbl import error_list, Template
from nltk.tag.brill import Word, Pos
from nltk.tag import BrillTaggerTrainer, RegexpTagger, UnigramTagger

def demo():
    """
    Run a demo with defaults. See source comments for details,
    or docstrings of any of the more specific demo_* functions.
    """
    postag()

def demo_repr_rule_format():
    """
    Exemplify repr(Rule) (see also str(Rule) and Rule.format("verbose"))
    """
    postag(ruleformat="repr")

def demo_str_rule_format():
    """
    Exemplify str(Rule) (see also repr(Rule) and Rule.format("verbose"))
    """
    postag(ruleformat="str")

def demo_verbose_rule_format():
    """
    Exemplify Rule.format("verbose")
    """
    postag(ruleformat="verbose")

def demo_multiposition_feature():
    """
    The feature(s) of a template take a list of positions
    relative to the current word where the feature should be
    looked for, conceptually joined by logical OR. For instance,
    Pos([-1, 1]), given a value V, will hold whenever V is found
    one step to the left and/or one step to the right.

    For contiguous ranges, a 2-arg form giving inclusive end
    points can also be used: Pos(-3, -1) is the same as the arg
    below.
    """
    postag(templates=[Template(Pos([-3, -2, -1]))])

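# A minimal sketch (not part of the original demo) illustrating the claim in
# the docstring above: the 2-arg contiguous form expands to the same sorted
# position tuple as the explicit list. This relies on Feature storing its
# positions as a sorted tuple, as nltk.tbl.feature.Feature does.
def _sketch_multiposition_equivalence():
    explicit = Pos([-3, -2, -1])
    contiguous = Pos(-3, -1)
    assert explicit.positions == contiguous.positions == (-3, -2, -1)
    print("Pos(-3, -1) covers positions", contiguous.positions)
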
def demo_multifeature_template():
    """
    Templates can have more than a single feature.
    """
    postag(templates=[Template(Word([0]), Pos([-2, -1]))])

def demo_template_statistics():
    """
    Show aggregate statistics per template. Little-used templates are
    candidates for deletion; much-used templates may possibly be refined.

    Deleting unused templates is mostly about saving time and/or space:
    training is basically O(T) in the number of templates T
    (also in terms of memory usage, which often will be the limiting factor).
    """
    postag(incremental_stats=True, template_stats=True)

def demo_generated_templates():
    """
    Template.expand and Feature.expand are class methods facilitating
    generating large amounts of templates. See their documentation for
    details.

    Note: training with 500 templates can easily fill all available
    memory, even on relatively small corpora
    """
    wordtpls = Word.expand([-1, 0, 1], [1, 2], excludezero=False)
    tagtpls = Pos.expand([-2, -1, 0, 1], [1, 2], excludezero=True)
    templates = list(Template.expand([wordtpls, tagtpls], combinations=(1, 3)))
    print(
        "Generated {0} templates for transformation-based learning".format(
            len(templates)
        )
    )
    postag(templates=templates, incremental_stats=True, template_stats=True)

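# A small sketch (an addition, not in the original demo) of what the expansion
# above produces: peek at the generated template count and the first few
# templates without running a full training pass.
def _sketch_generated_templates():
    wordtpls = Word.expand([-1, 0, 1], [1, 2], excludezero=False)
    tagtpls = Pos.expand([-2, -1, 0, 1], [1, 2], excludezero=True)
    templates = list(Template.expand([wordtpls, tagtpls], combinations=(1, 3)))
    print("{0} templates generated; the first three:".format(len(templates)))
    for template in templates[:3]:
        print("  ", template)
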
def demo_learning_curve():
    """
    Plot a learning curve -- the contribution to tagging accuracy of
    the individual rules.

    Note: requires matplotlib
    """
    postag(
        incremental_stats=True,
        separate_baseline_data=True,
        learning_curve_output="learningcurve.png",
    )

def demo_error_analysis():
    """
    Writes a file with context for each erroneous word after tagging testing data
    """
    postag(error_output="errors.txt")

def demo_serialize_tagger():
    """
    Serializes the learned tagger to a file in pickle format; reloads it
    and validates the process.
    """
    postag(serialize_output="tagger.pcl")

def demo_high_accuracy_rules():
    """
    Discard rules with low accuracy. This may hurt performance a bit,
    but will often produce rules which are more interesting for a human to read.
    """
    postag(num_sents=3000, min_acc=0.96, min_score=10)

def postag(
    templates=None,
    tagged_data=None,
    num_sents=1000,
    max_rules=300,
    min_score=3,
    min_acc=None,
    train=0.8,
    trace=3,
    randomize=False,
    ruleformat="str",
    incremental_stats=False,
    template_stats=False,
    error_output=None,
    serialize_output=None,
    learning_curve_output=None,
    learning_curve_take=300,
    baseline_backoff_tagger=None,
    separate_baseline_data=False,
    cache_baseline_tagger=None,
):
    """
    Brill Tagger Demonstration

    :param templates: the templates to use in rule learning (defaults to nltk.tag.brill.brill24())
    :type templates: list of Template

    :param tagged_data: the corpus of tagged sentences to train and test on (defaults to treebank.tagged_sents())
    :type tagged_data: list of list of (str, str)

    :param num_sents: how many sentences of training and testing data to use
    :type num_sents: C{int}

    :param max_rules: maximum number of rule instances to create
    :type max_rules: C{int}

    :param min_score: the minimum score for a rule in order for it to be considered
    :type min_score: C{int}

    :param min_acc: the minimum accuracy for a rule in order for it to be considered
    :type min_acc: C{float}

    :param train: the fraction of the corpus to be used for training (1=all)
    :type train: C{float}

    :param trace: the level of diagnostic tracing output to produce (0-4)
    :type trace: C{int}

    :param randomize: whether the training data should be a random subset of the corpus
    :type randomize: C{bool}

    :param ruleformat: rule output format, one of "str", "repr", "verbose"
    :type ruleformat: C{str}

    :param incremental_stats: if true, will tag incrementally and collect stats for each rule (rather slow)
    :type incremental_stats: C{bool}

    :param template_stats: if true, will print per-template statistics collected in training and (optionally) testing
    :type template_stats: C{bool}

    :param error_output: the file where errors will be saved
    :type error_output: C{string}

    :param serialize_output: the file where the learned tbl tagger will be saved
    :type serialize_output: C{string}

    :param learning_curve_output: filename of plot of learning curve(s) (train and also test, if available)
    :type learning_curve_output: C{string}

    :param learning_curve_take: how many rules plotted
    :type learning_curve_take: C{int}

    :param baseline_backoff_tagger: the backoff tagger used by the unigram baseline tagger (defaults to REGEXP_TAGGER)
    :type baseline_backoff_tagger: tagger

    :param separate_baseline_data: use a fraction of the training data exclusively for training baseline
    :type separate_baseline_data: C{bool}

    :param cache_baseline_tagger: cache baseline tagger to this file (only interesting as a temporary workaround to get
        deterministic output from the baseline unigram tagger between python versions)
    :type cache_baseline_tagger: C{string}

    Note on separate_baseline_data: if False, the training data is reused for
    both the baseline and the rule learner. This is fast and fine for a demo,
    but is likely to generalize worse on unseen data. It also cannot be
    sensibly used for learning curves on training data (the baseline will be
    artificially high).
    """
    # defaults
    baseline_backoff_tagger = baseline_backoff_tagger or REGEXP_TAGGER
    if templates is None:
        from nltk.tag.brill import describe_template_sets, brill24

        # some pre-built template sets taken from typical systems or publications are
        # available. Print a list with describe_template_sets()
        # for instance:
        templates = brill24()
    (training_data, baseline_data, gold_data, testing_data) = _demo_prepare_data(
        tagged_data, train, num_sents, randomize, separate_baseline_data
    )
    # creating (or reloading from cache) a baseline tagger (unigram tagger)
    # this is just a mechanism for getting deterministic output from the baseline between
    # python versions
    if cache_baseline_tagger:
        if not os.path.exists(cache_baseline_tagger):
            baseline_tagger = UnigramTagger(
                baseline_data, backoff=baseline_backoff_tagger
            )
            # pickle needs binary file modes ("wb"/"rb"), not text modes
            with open(cache_baseline_tagger, "wb") as print_rules:
                pickle.dump(baseline_tagger, print_rules)
            print(
                "Trained baseline tagger, pickled it to {0}".format(
                    cache_baseline_tagger
                )
            )
        with open(cache_baseline_tagger, "rb") as print_rules:
            baseline_tagger = pickle.load(print_rules)
        print("Reloaded pickled tagger from {0}".format(cache_baseline_tagger))
    else:
        baseline_tagger = UnigramTagger(baseline_data, backoff=baseline_backoff_tagger)
        print("Trained baseline tagger")
    if gold_data:
        print(
            "    Accuracy on test set: {0:0.4f}".format(
                baseline_tagger.evaluate(gold_data)
            )
        )
    # creating a Brill tagger
    tbrill = time.time()
    trainer = BrillTaggerTrainer(
        baseline_tagger, templates, trace, ruleformat=ruleformat
    )
    print("Training tbl tagger...")
    brill_tagger = trainer.train(training_data, max_rules, min_score, min_acc)
    print("Trained tbl tagger in {0:0.2f} seconds".format(time.time() - tbrill))
    if gold_data:
        print("    Accuracy on test set: %.4f" % brill_tagger.evaluate(gold_data))

    # printing the learned rules, if learned silently
    if trace == 1:
        print("\nLearned rules: ")
        for (ruleno, rule) in enumerate(brill_tagger.rules(), 1):
            print("{0:4d} {1:s}".format(ruleno, rule.format(ruleformat)))
    # printing template statistics (optionally including comparison with the training data)
    # note: if not separate_baseline_data, then baseline accuracy will be artificially high
    if incremental_stats:
        print(
            "Incrementally tagging the test data, collecting individual rule statistics"
        )
        (taggedtest, teststats) = brill_tagger.batch_tag_incremental(
            testing_data, gold_data
        )
        print("    Rule statistics collected")
        if not separate_baseline_data:
            print(
                "WARNING: separate_baseline_data=False; the baseline accuracy "
                "in the training statistics will be artificially high"
            )
        trainstats = brill_tagger.train_stats()
        if template_stats:
            brill_tagger.print_template_statistics(teststats)
        if learning_curve_output:
            _demo_plot(
                learning_curve_output, teststats, trainstats, take=learning_curve_take
            )
            print("Wrote plot of learning curve to {0}".format(learning_curve_output))
    else:
        print("Tagging the test data")
        taggedtest = brill_tagger.tag_sents(testing_data)
        if template_stats:
            brill_tagger.print_template_statistics()
    # writing error analysis to file
    if error_output is not None:
        # open in text mode with an explicit encoding; error_list returns strings,
        # so no manual .encode() is needed (or allowed) here
        with open(error_output, "w", encoding="utf-8") as f:
            f.write("Errors for Brill Tagger %r\n\n" % serialize_output)
            f.write("\n".join(error_list(gold_data, taggedtest)) + "\n")
        print("Wrote tagger errors including context to {0}".format(error_output))
    # serializing the tagger to a pickle file and reloading (just to see it works)
    if serialize_output is not None:
        taggedtest = brill_tagger.tag_sents(testing_data)
        with open(serialize_output, "wb") as print_rules:
            pickle.dump(brill_tagger, print_rules)
        print("Wrote pickled tagger to {0}".format(serialize_output))
        with open(serialize_output, "rb") as print_rules:
            brill_tagger_reloaded = pickle.load(print_rules)
        print("Reloaded pickled tagger from {0}".format(serialize_output))
        # tag with the *reloaded* tagger, so the comparison actually validates it
        taggedtest_reloaded = brill_tagger_reloaded.tag_sents(testing_data)
        if taggedtest == taggedtest_reloaded:
            print("Reloaded tagger tried on test set, results identical")
        else:
            print("PROBLEM: Reloaded tagger gave different results on test set")

def _demo_prepare_data(
    tagged_data, train, num_sents, randomize, separate_baseline_data
):
    # train is the proportion of data used in training; the rest is reserved
    # for testing.
    if tagged_data is None:
        print("Loading tagged data from treebank... ")
        tagged_data = treebank.tagged_sents()
    if num_sents is None or len(tagged_data) <= num_sents:
        num_sents = len(tagged_data)
    if randomize:
        random.seed(len(tagged_data))
        # corpus views are immutable sequences; copy to a list before shuffling
        tagged_data = list(tagged_data)
        random.shuffle(tagged_data)
    cutoff = int(num_sents * train)
    training_data = tagged_data[:cutoff]
    gold_data = tagged_data[cutoff:num_sents]
    testing_data = [[t[0] for t in sent] for sent in gold_data]
    if not separate_baseline_data:
        baseline_data = training_data
    else:
        bl_cutoff = len(training_data) // 3
        (baseline_data, training_data) = (
            training_data[:bl_cutoff],
            training_data[bl_cutoff:],
        )
    (trainseqs, traintokens) = corpus_size(training_data)
    (testseqs, testtokens) = corpus_size(testing_data)
    (bltrainseqs, bltraintokens) = corpus_size(baseline_data)
    print("Read testing data ({0:d} sents/{1:d} wds)".format(testseqs, testtokens))
    print("Read training data ({0:d} sents/{1:d} wds)".format(trainseqs, traintokens))
    print(
        "Read baseline data ({0:d} sents/{1:d} wds) {2:s}".format(
            bltrainseqs,
            bltraintokens,
            "" if separate_baseline_data else "[reused the training set]",
        )
    )
    return (training_data, baseline_data, gold_data, testing_data)

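# A worked example (an addition, illustrative only) of the split arithmetic in
# _demo_prepare_data: with num_sents=1000, train=0.8, and
# separate_baseline_data=True, the corpus is carved up as follows.
def _sketch_data_split():
    num_sents, train = 1000, 0.8
    cutoff = int(num_sents * train)       # 800 sentences for training
    gold_size = num_sents - cutoff        # 200 sentences held out as gold/test
    bl_cutoff = cutoff // 3               # 266 of the 800 go to the baseline
    rule_train = cutoff - bl_cutoff       # 534 remain for the rule learner
    print(cutoff, gold_size, bl_cutoff, rule_train)  # -> 800 200 266 534
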
def _demo_plot(learning_curve_output, teststats, trainstats=None, take=None):
    testcurve = [teststats["initialerrors"]]
    for rulescore in teststats["rulescores"]:
        testcurve.append(testcurve[-1] - rulescore)
    testcurve = [1 - x / teststats["tokencount"] for x in testcurve[:take]]

    traincurve = [trainstats["initialerrors"]]
    for rulescore in trainstats["rulescores"]:
        traincurve.append(traincurve[-1] - rulescore)
    traincurve = [1 - x / trainstats["tokencount"] for x in traincurve[:take]]

    import matplotlib.pyplot as plt

    r = list(range(len(testcurve)))
    plt.plot(r, testcurve, r, traincurve)
    plt.axis([None, None, None, 1.0])
    plt.savefig(learning_curve_output)

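# A tiny worked example (an addition; the values are made up, not real training
# output) of the accuracy-curve computation in _demo_plot: the errors remaining
# after each rule are converted to accuracies via 1 - errors/tokencount.
def _sketch_curve_arithmetic():
    stats = {"initialerrors": 100, "rulescores": [40, 10, 5], "tokencount": 1000}
    curve = [stats["initialerrors"]]
    for rulescore in stats["rulescores"]:
        curve.append(curve[-1] - rulescore)  # 100, 60, 50, 45 errors remain
    accuracies = [1 - x / stats["tokencount"] for x in curve]
    print(accuracies)  # -> approximately [0.9, 0.94, 0.95, 0.955]
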
NN_CD_TAGGER = RegexpTagger([(r"^-?[0-9]+(\.[0-9]+)?$", "CD"), (r".*", "NN")])

REGEXP_TAGGER = RegexpTagger(
    [
        (r"^-?[0-9]+(\.[0-9]+)?$", "CD"),  # cardinal numbers
        (r"(The|the|A|a|An|an)$", "AT"),  # articles
        (r".*able$", "JJ"),  # adjectives
        (r".*ness$", "NN"),  # nouns formed from adjectives
        (r".*ly$", "RB"),  # adverbs
        (r".*s$", "NNS"),  # plural nouns
        (r".*ing$", "VBG"),  # gerunds
        (r".*ed$", "VBD"),  # past tense verbs
        (r".*", "NN"),  # nouns (default)
    ]
)

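# A quick sanity check (an addition, not part of the original demo): the
# regexp baseline in action on a toy sentence. Patterns are tried in order,
# and the first match wins.
def _sketch_regexp_tagger():
    tokens = "The cats quickly jumped over 2 walls".split()
    print(REGEXP_TAGGER.tag(tokens))
    # expected, given the patterns above:
    # [('The', 'AT'), ('cats', 'NNS'), ('quickly', 'RB'), ('jumped', 'VBD'),
    #  ('over', 'NN'), ('2', 'CD'), ('walls', 'NNS')]
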
def corpus_size(seqs):
    return (len(seqs), sum(len(x) for x in seqs))


if __name__ == "__main__":
    demo_learning_curve()