test_json_serialization.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import unittest
  2. from nltk.corpus import brown
  3. from nltk.jsontags import JSONTaggedDecoder, JSONTaggedEncoder
  4. from nltk.tag import DefaultTagger, RegexpTagger, AffixTagger
  5. from nltk.tag import UnigramTagger, BigramTagger, TrigramTagger, NgramTagger
  6. from nltk.tag import PerceptronTagger
  7. from nltk.tag import BrillTaggerTrainer, BrillTagger
  8. from nltk.tag.brill import nltkdemo18
  9. class TestJSONSerialization(unittest.TestCase):
  10. def setUp(self):
  11. self.corpus = brown.tagged_sents()[:35]
  12. self.decoder = JSONTaggedDecoder()
  13. self.encoder = JSONTaggedEncoder()
  14. self.default_tagger = DefaultTagger("NN")
  15. def test_default_tagger(self):
  16. encoded = self.encoder.encode(self.default_tagger)
  17. decoded = self.decoder.decode(encoded)
  18. self.assertEqual(repr(self.default_tagger), repr(decoded))
  19. self.assertEqual(self.default_tagger._tag, decoded._tag)
  20. def test_regexp_tagger(self):
  21. tagger = RegexpTagger([(r".*", "NN")], backoff=self.default_tagger)
  22. encoded = self.encoder.encode(tagger)
  23. decoded = self.decoder.decode(encoded)
  24. self.assertEqual(repr(tagger), repr(decoded))
  25. self.assertEqual(repr(tagger.backoff), repr(decoded.backoff))
  26. self.assertEqual(tagger._regexps, decoded._regexps)
  27. def test_affix_tagger(self):
  28. tagger = AffixTagger(self.corpus, backoff=self.default_tagger)
  29. encoded = self.encoder.encode(tagger)
  30. decoded = self.decoder.decode(encoded)
  31. self.assertEqual(repr(tagger), repr(decoded))
  32. self.assertEqual(repr(tagger.backoff), repr(decoded.backoff))
  33. self.assertEqual(tagger._affix_length, decoded._affix_length)
  34. self.assertEqual(tagger._min_word_length, decoded._min_word_length)
  35. self.assertEqual(tagger._context_to_tag, decoded._context_to_tag)
  36. def test_ngram_taggers(self):
  37. unitagger = UnigramTagger(self.corpus, backoff=self.default_tagger)
  38. bitagger = BigramTagger(self.corpus, backoff=unitagger)
  39. tritagger = TrigramTagger(self.corpus, backoff=bitagger)
  40. ntagger = NgramTagger(4, self.corpus, backoff=tritagger)
  41. encoded = self.encoder.encode(ntagger)
  42. decoded = self.decoder.decode(encoded)
  43. self.assertEqual(repr(ntagger), repr(decoded))
  44. self.assertEqual(repr(tritagger), repr(decoded.backoff))
  45. self.assertEqual(repr(bitagger), repr(decoded.backoff.backoff))
  46. self.assertEqual(repr(unitagger), repr(decoded.backoff.backoff.backoff))
  47. self.assertEqual(repr(self.default_tagger),
  48. repr(decoded.backoff.backoff.backoff.backoff))
  49. def test_perceptron_tagger(self):
  50. tagger = PerceptronTagger(load=False)
  51. tagger.train(self.corpus)
  52. encoded = self.encoder.encode(tagger)
  53. decoded = self.decoder.decode(encoded)
  54. self.assertEqual(tagger.model.weights, decoded.model.weights)
  55. self.assertEqual(tagger.tagdict, decoded.tagdict)
  56. self.assertEqual(tagger.classes, decoded.classes)
  57. def test_brill_tagger(self):
  58. trainer = BrillTaggerTrainer(self.default_tagger, nltkdemo18(),
  59. deterministic=True)
  60. tagger = trainer.train(self.corpus, max_rules=30)
  61. encoded = self.encoder.encode(tagger)
  62. decoded = self.decoder.decode(encoded)
  63. self.assertEqual(repr(tagger._initial_tagger),
  64. repr(decoded._initial_tagger))
  65. self.assertEqual(tagger._rules, decoded._rules)
  66. self.assertEqual(tagger._training_stats, decoded._training_stats)