tadm.py

# Natural Language Toolkit: Interface to TADM Classifier
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Joseph Frazee <jfrazee@mail.utexas.edu>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

import sys
import subprocess

from nltk.internals import find_binary

# numpy is optional at import time; it is only needed by parse_tadm_weights().
try:
    import numpy
except ImportError:
    pass

_tadm_bin = None


def config_tadm(bin=None):
    """Locate the ``tadm`` binary and remember its path for later calls."""
    global _tadm_bin
    _tadm_bin = find_binary(
        "tadm", bin, env_vars=["TADM"], binary_names=["tadm"], url="http://tadm.sf.net"
    )
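
# Usage sketch (the path below is hypothetical): point NLTK at a specific
# tadm executable instead of letting find_binary() locate one via the TADM
# environment variable.
#
#     config_tadm("/usr/local/bin/tadm")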


def write_tadm_file(train_toks, encoding, stream):
    """
    Generate an input file for ``tadm`` based on the given corpus of
    classified tokens.

    :type train_toks: list(tuple(dict, str))
    :param train_toks: Training data, represented as a list of
        pairs, the first member of which is a feature dictionary,
        and the second of which is a classification label.
    :type encoding: TadmEventMaxentFeatureEncoding
    :param encoding: A feature encoding, used to convert featuresets
        into feature vectors.
    :type stream: stream
    :param stream: The stream to which the ``tadm`` input file should be
        written.
    """
    # See the following for a file format description:
    #
    # http://sf.net/forum/forum.php?thread_id=1391502&forum_id=473054
    # http://sf.net/forum/forum.php?thread_id=1675097&forum_id=473054
    labels = encoding.labels()
    for featureset, label in train_toks:
        length_line = "%d\n" % len(labels)
        stream.write(length_line)
        for known_label in labels:
            v = encoding.encode(featureset, known_label)
            line = "%d %d %s\n" % (
                int(label == known_label),
                len(v),
                " ".join("%d %d" % u for u in v),
            )
            stream.write(line)
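
# Sketch of the output produced above: each training token contributes one
# line giving the number of labels, then one line per candidate label of the
# form "<1 if correct else 0> <number of feature pairs> <fid> <val> ...".
# The tiny featuresets below are illustrative, and the feature ids/values in
# the printed output depend entirely on the trained encoding.
#
#     import io
#     from nltk.classify.maxent import TadmEventMaxentFeatureEncoding
#
#     toks = [({"f0": 1, "f1": 1}, "A"), ({"f0": 1, "f2": 1}, "B")]
#     enc = TadmEventMaxentFeatureEncoding.train(toks)
#     out = io.StringIO()
#     write_tadm_file(toks, enc, out)
#     print(out.getvalue())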


def parse_tadm_weights(paramfile):
    """
    Given the parameter output produced by ``tadm`` when training a
    model, with one floating-point weight per line, return a ``numpy``
    array containing the corresponding weight vector.
    """
    weights = []
    for line in paramfile:
        weights.append(float(line.strip()))
    return numpy.array(weights, "d")
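
# Usage sketch, assuming a hypothetical "params.txt" written by tadm with one
# weight per line:
#
#     with open("params.txt") as paramfile:
#         weights = parse_tadm_weights(paramfile)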


def call_tadm(args):
    """
    Call the ``tadm`` binary with the given arguments.
    """
    if isinstance(args, str):
        raise TypeError("args should be a list of strings")
    if _tadm_bin is None:
        config_tadm()

    # Call tadm via a subprocess, letting its stdout stream to our own stdout
    # and capturing stderr so it can be reported on failure.
    cmd = [_tadm_bin] + args
    p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=subprocess.PIPE)
    (stdout, stderr) = p.communicate()

    # Check the return code.
    if p.returncode != 0:
        print()
        print(stderr.decode(errors="replace"))
        raise OSError("tadm command failed!")
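
# End-to-end usage sketch. The command-line options mirror those that NLTK's
# TadmMaxentClassifier passes to tadm, but treat the specific flags, method
# name, and file names here as assumptions; consult the TADM documentation
# for the authoritative option list.
#
#     with open("events.txt", "w") as events:
#         write_tadm_file(train_toks, encoding, events)
#     call_tadm(["-monitor", "-method", "tao_lmvm",
#                "-events_in", "events.txt", "-params_out", "params.txt"])
#     with open("params.txt") as params:
#         weights = parse_tadm_weights(params)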


def names_demo():
    from nltk.classify.util import names_demo
    from nltk.classify.maxent import TadmMaxentClassifier

    classifier = names_demo(TadmMaxentClassifier.train)


def encoding_demo():
    import sys
    from nltk.classify.maxent import TadmEventMaxentFeatureEncoding

    tokens = [
        ({"f0": 1, "f1": 1, "f3": 1}, "A"),
        ({"f0": 1, "f2": 1, "f4": 1}, "B"),
        ({"f0": 2, "f2": 1, "f3": 1, "f4": 1}, "A"),
    ]
    encoding = TadmEventMaxentFeatureEncoding.train(tokens)
    write_tadm_file(tokens, encoding, sys.stdout)
    print()
    for i in range(encoding.length()):
        print("%s --> %d" % (encoding.describe(i), i))
    print()


if __name__ == "__main__":
    encoding_demo()
    names_demo()