test_senna.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
# -*- coding: utf-8 -*-
"""
Unit tests for Senna
"""
  5. from os import environ, path, sep
  6. import logging
  7. import unittest
  8. from nltk.classify import Senna
  9. from nltk.tag import SennaTagger, SennaChunkTagger, SennaNERTagger
  10. # Set Senna executable path for tests if it is not specified as an environment variable
  11. if 'SENNA' in environ:
  12. SENNA_EXECUTABLE_PATH = path.normpath(environ['SENNA']) + sep
  13. else:
  14. SENNA_EXECUTABLE_PATH = '/usr/share/senna-v3.0'
  15. senna_is_installed = path.exists(SENNA_EXECUTABLE_PATH)
  16. @unittest.skipUnless(senna_is_installed, "Requires Senna executable")
  17. class TestSennaPipeline(unittest.TestCase):
  18. """Unittest for nltk.classify.senna"""
  19. def test_senna_pipeline(self):
  20. """Senna pipeline interface"""
  21. pipeline = Senna(SENNA_EXECUTABLE_PATH, ['pos', 'chk', 'ner'])
  22. sent = 'Dusseldorf is an international business center'.split()
  23. result = [
  24. (token['word'], token['chk'], token['ner'], token['pos'])
  25. for token in pipeline.tag(sent)
  26. ]
  27. expected = [
  28. ('Dusseldorf', 'B-NP', 'B-LOC', 'NNP'),
  29. ('is', 'B-VP', 'O', 'VBZ'),
  30. ('an', 'B-NP', 'O', 'DT'),
  31. ('international', 'I-NP', 'O', 'JJ'),
  32. ('business', 'I-NP', 'O', 'NN'),
  33. ('center', 'I-NP', 'O', 'NN'),
  34. ]
  35. self.assertEqual(result, expected)
  36. @unittest.skipUnless(senna_is_installed, "Requires Senna executable")
  37. class TestSennaTagger(unittest.TestCase):
  38. """Unittest for nltk.tag.senna"""
  39. def test_senna_tagger(self):
  40. tagger = SennaTagger(SENNA_EXECUTABLE_PATH)
  41. result = tagger.tag('What is the airspeed of an unladen swallow ?'.split())
  42. expected = [
  43. ('What', 'WP'),
  44. ('is', 'VBZ'),
  45. ('the', 'DT'),
  46. ('airspeed', 'NN'),
  47. ('of', 'IN'),
  48. ('an', 'DT'),
  49. ('unladen', 'NN'),
  50. ('swallow', 'NN'),
  51. ('?', '.'),
  52. ]
  53. self.assertEqual(result, expected)
  54. def test_senna_chunk_tagger(self):
  55. chktagger = SennaChunkTagger(SENNA_EXECUTABLE_PATH)
  56. result_1 = chktagger.tag('What is the airspeed of an unladen swallow ?'.split())
  57. expected_1 = [
  58. ('What', 'B-NP'),
  59. ('is', 'B-VP'),
  60. ('the', 'B-NP'),
  61. ('airspeed', 'I-NP'),
  62. ('of', 'B-PP'),
  63. ('an', 'B-NP'),
  64. ('unladen', 'I-NP'),
  65. ('swallow', 'I-NP'),
  66. ('?', 'O'),
  67. ]
  68. result_2 = list(chktagger.bio_to_chunks(result_1, chunk_type='NP'))
  69. expected_2 = [
  70. ('What', '0'),
  71. ('the airspeed', '2-3'),
  72. ('an unladen swallow', '5-6-7'),
  73. ]
  74. self.assertEqual(result_1, expected_1)
  75. self.assertEqual(result_2, expected_2)
  76. def test_senna_ner_tagger(self):
  77. nertagger = SennaNERTagger(SENNA_EXECUTABLE_PATH)
  78. result_1 = nertagger.tag('Shakespeare theatre was in London .'.split())
  79. expected_1 = [
  80. ('Shakespeare', 'B-PER'),
  81. ('theatre', 'O'),
  82. ('was', 'O'),
  83. ('in', 'O'),
  84. ('London', 'B-LOC'),
  85. ('.', 'O'),
  86. ]
  87. result_2 = nertagger.tag('UN headquarters are in NY , USA .'.split())
  88. expected_2 = [
  89. ('UN', 'B-ORG'),
  90. ('headquarters', 'O'),
  91. ('are', 'O'),
  92. ('in', 'O'),
  93. ('NY', 'B-LOC'),
  94. (',', 'O'),
  95. ('USA', 'B-LOC'),
  96. ('.', 'O'),
  97. ]
  98. self.assertEqual(result_1, expected_1)
  99. self.assertEqual(result_2, expected_2)