test_classify.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for nltk.classify. See also: nltk/test/classify.doctest
  4. """
  5. from nose import SkipTest
  6. from nltk import classify
  7. TRAIN = [
  8. (dict(a=1, b=1, c=1), 'y'),
  9. (dict(a=1, b=1, c=1), 'x'),
  10. (dict(a=1, b=1, c=0), 'y'),
  11. (dict(a=0, b=1, c=1), 'x'),
  12. (dict(a=0, b=1, c=1), 'y'),
  13. (dict(a=0, b=0, c=1), 'y'),
  14. (dict(a=0, b=1, c=0), 'x'),
  15. (dict(a=0, b=0, c=0), 'x'),
  16. (dict(a=0, b=1, c=1), 'y'),
  17. ]
  18. TEST = [
  19. (dict(a=1, b=0, c=1)), # unseen
  20. (dict(a=1, b=0, c=0)), # unseen
  21. (dict(a=0, b=1, c=1)), # seen 3 times, labels=y,y,x
  22. (dict(a=0, b=1, c=0)), # seen 1 time, label=x
  23. ]
  24. RESULTS = [(0.16, 0.84), (0.46, 0.54), (0.41, 0.59), (0.76, 0.24)]
  25. def assert_classifier_correct(algorithm):
  26. try:
  27. classifier = classify.MaxentClassifier.train(
  28. TRAIN, algorithm, trace=0, max_iter=1000
  29. )
  30. except (LookupError, AttributeError) as e:
  31. raise SkipTest(str(e))
  32. for (px, py), featureset in zip(RESULTS, TEST):
  33. pdist = classifier.prob_classify(featureset)
  34. assert abs(pdist.prob('x') - px) < 1e-2, (pdist.prob('x'), px)
  35. assert abs(pdist.prob('y') - py) < 1e-2, (pdist.prob('y'), py)
  36. def test_megam():
  37. assert_classifier_correct('MEGAM')
  38. def test_tadm():
  39. assert_classifier_correct('TADM')