test_rte_classify.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # -*- coding: utf-8 -*-
  2. import unittest
  3. from nltk.corpus import rte as rte_corpus
  4. from nltk.classify.rte_classify import RTEFeatureExtractor, rte_features, rte_classifier
  5. expected_from_rte_feature_extration = """
  6. alwayson => True
  7. ne_hyp_extra => 0
  8. ne_overlap => 1
  9. neg_hyp => 0
  10. neg_txt => 0
  11. word_hyp_extra => 3
  12. word_overlap => 3
  13. alwayson => True
  14. ne_hyp_extra => 0
  15. ne_overlap => 1
  16. neg_hyp => 0
  17. neg_txt => 0
  18. word_hyp_extra => 2
  19. word_overlap => 1
  20. alwayson => True
  21. ne_hyp_extra => 1
  22. ne_overlap => 1
  23. neg_hyp => 0
  24. neg_txt => 0
  25. word_hyp_extra => 1
  26. word_overlap => 2
  27. alwayson => True
  28. ne_hyp_extra => 1
  29. ne_overlap => 0
  30. neg_hyp => 0
  31. neg_txt => 0
  32. word_hyp_extra => 6
  33. word_overlap => 2
  34. alwayson => True
  35. ne_hyp_extra => 1
  36. ne_overlap => 0
  37. neg_hyp => 0
  38. neg_txt => 0
  39. word_hyp_extra => 4
  40. word_overlap => 0
  41. alwayson => True
  42. ne_hyp_extra => 1
  43. ne_overlap => 0
  44. neg_hyp => 0
  45. neg_txt => 0
  46. word_hyp_extra => 3
  47. word_overlap => 1
  48. """
  49. class RTEClassifierTest(unittest.TestCase):
  50. # Test the feature extraction method.
  51. def test_rte_feature_extraction(self):
  52. pairs = rte_corpus.pairs(['rte1_dev.xml'])[:6]
  53. test_output = [
  54. "%-15s => %s" % (key, rte_features(pair)[key])
  55. for pair in pairs
  56. for key in sorted(rte_features(pair))
  57. ]
  58. expected_output = expected_from_rte_feature_extration.strip().split('\n')
  59. # Remove null strings.
  60. expected_output = list(filter(None, expected_output))
  61. self.assertEqual(test_output, expected_output)
  62. # Test the RTEFeatureExtractor object.
  63. def test_feature_extractor_object(self):
  64. rtepair = rte_corpus.pairs(['rte3_dev.xml'])[33]
  65. extractor = RTEFeatureExtractor(rtepair)
  66. self.assertEqual(extractor.hyp_words, {'member', 'China', 'SCO.'})
  67. self.assertEqual(extractor.overlap('word'), set())
  68. self.assertEqual(extractor.overlap('ne'), {'China'})
  69. self.assertEqual(extractor.hyp_extra('word'), {'member'})
  70. # Test the RTE classifier training.
  71. def test_rte_classification_without_megam(self):
  72. clf = rte_classifier('IIS')
  73. clf = rte_classifier('GIS')
  74. @unittest.skip("Skipping tests with dependencies on MEGAM")
  75. def test_rte_classification_with_megam(self):
  76. nltk.config_megam('/usr/local/bin/megam')
  77. clf = rte_classifier('megam')
  78. clf = rte_classifier('BFGS')