# -*- coding: utf-8 -*-
# test_hmm.py -- tests for nltk.tag.hmm (nltk.test.test_hmm)
from nltk.tag import hmm
  3. def _wikipedia_example_hmm():
  4. # Example from wikipedia
  5. # (http://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm)
  6. states = ['rain', 'no rain']
  7. symbols = ['umbrella', 'no umbrella']
  8. A = [[0.7, 0.3], [0.3, 0.7]] # transition probabilities
  9. B = [[0.9, 0.1], [0.2, 0.8]] # emission probabilities
  10. pi = [0.5, 0.5] # initial probabilities
  11. seq = ['umbrella', 'umbrella', 'no umbrella', 'umbrella', 'umbrella']
  12. seq = list(zip(seq, [None] * len(seq)))
  13. model = hmm._create_hmm_tagger(states, symbols, A, B, pi)
  14. return model, states, symbols, seq
  15. def test_forward_probability():
  16. from numpy.testing import assert_array_almost_equal
  17. # example from p. 385, Huang et al
  18. model, states, symbols = hmm._market_hmm_example()
  19. seq = [('up', None), ('up', None)]
  20. expected = [[0.35, 0.02, 0.09], [0.1792, 0.0085, 0.0357]]
  21. fp = 2 ** model._forward_probability(seq)
  22. assert_array_almost_equal(fp, expected)
  23. def test_forward_probability2():
  24. from numpy.testing import assert_array_almost_equal
  25. model, states, symbols, seq = _wikipedia_example_hmm()
  26. fp = 2 ** model._forward_probability(seq)
  27. # examples in wikipedia are normalized
  28. fp = (fp.T / fp.sum(axis=1)).T
  29. wikipedia_results = [
  30. [0.8182, 0.1818],
  31. [0.8834, 0.1166],
  32. [0.1907, 0.8093],
  33. [0.7308, 0.2692],
  34. [0.8673, 0.1327],
  35. ]
  36. assert_array_almost_equal(wikipedia_results, fp, 4)
  37. def test_backward_probability():
  38. from numpy.testing import assert_array_almost_equal
  39. model, states, symbols, seq = _wikipedia_example_hmm()
  40. bp = 2 ** model._backward_probability(seq)
  41. # examples in wikipedia are normalized
  42. bp = (bp.T / bp.sum(axis=1)).T
  43. wikipedia_results = [
  44. # Forward-backward algorithm doesn't need b0_5,
  45. # so .backward_probability doesn't compute it.
  46. # [0.6469, 0.3531],
  47. [0.5923, 0.4077],
  48. [0.3763, 0.6237],
  49. [0.6533, 0.3467],
  50. [0.6273, 0.3727],
  51. [0.5, 0.5],
  52. ]
  53. assert_array_almost_equal(wikipedia_results, bp, 4)
  54. def setup_module(module):
  55. from nose import SkipTest
  56. try:
  57. import numpy
  58. except ImportError:
  59. raise SkipTest("numpy is required for nltk.test.test_hmm")