weka.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Natural Language Toolkit: Interface to Weka Classsifiers
  2. #
  3. # Copyright (C) 2001-2020 NLTK Project
  4. # Author: Edward Loper <edloper@gmail.com>
  5. # URL: <http://nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. """
  8. Classifiers that make use of the external 'Weka' package.
  9. """
  10. import time
  11. import tempfile
  12. import os
  13. import subprocess
  14. import re
  15. import zipfile
  16. from sys import stdin
  17. from nltk.probability import DictionaryProbDist
  18. from nltk.internals import java, config_java
  19. from nltk.classify.api import ClassifierI
# Classpath for weka.jar; None until set by config_weka() (either
# explicitly or by a successful search of the directories below).
_weka_classpath = None

# Default directories searched for weka.jar.  config_weka() prepends
# the WEKAHOME environment variable, when set, to this list.
_weka_search = [
    ".",
    "/usr/share/weka",
    "/usr/local/share/weka",
    "/usr/lib/weka",
    "/usr/local/lib/weka",
]
  28. def config_weka(classpath=None):
  29. global _weka_classpath
  30. # Make sure java's configured first.
  31. config_java()
  32. if classpath is not None:
  33. _weka_classpath = classpath
  34. if _weka_classpath is None:
  35. searchpath = _weka_search
  36. if "WEKAHOME" in os.environ:
  37. searchpath.insert(0, os.environ["WEKAHOME"])
  38. for path in searchpath:
  39. if os.path.exists(os.path.join(path, "weka.jar")):
  40. _weka_classpath = os.path.join(path, "weka.jar")
  41. version = _check_weka_version(_weka_classpath)
  42. if version:
  43. print(
  44. ("[Found Weka: %s (version %s)]" % (_weka_classpath, version))
  45. )
  46. else:
  47. print("[Found Weka: %s]" % _weka_classpath)
  48. _check_weka_version(_weka_classpath)
  49. if _weka_classpath is None:
  50. raise LookupError(
  51. "Unable to find weka.jar! Use config_weka() "
  52. "or set the WEKAHOME environment variable. "
  53. "For more information about Weka, please see "
  54. "http://www.cs.waikato.ac.nz/ml/weka/"
  55. )
  56. def _check_weka_version(jar):
  57. try:
  58. zf = zipfile.ZipFile(jar)
  59. except (SystemExit, KeyboardInterrupt):
  60. raise
  61. except:
  62. return None
  63. try:
  64. try:
  65. return zf.read("weka/core/version.txt")
  66. except KeyError:
  67. return None
  68. finally:
  69. zf.close()
  70. class WekaClassifier(ClassifierI):
  71. def __init__(self, formatter, model_filename):
  72. self._formatter = formatter
  73. self._model = model_filename
  74. def prob_classify_many(self, featuresets):
  75. return self._classify_many(featuresets, ["-p", "0", "-distribution"])
  76. def classify_many(self, featuresets):
  77. return self._classify_many(featuresets, ["-p", "0"])
  78. def _classify_many(self, featuresets, options):
  79. # Make sure we can find java & weka.
  80. config_weka()
  81. temp_dir = tempfile.mkdtemp()
  82. try:
  83. # Write the test data file.
  84. test_filename = os.path.join(temp_dir, "test.arff")
  85. self._formatter.write(test_filename, featuresets)
  86. # Call weka to classify the data.
  87. cmd = [
  88. "weka.classifiers.bayes.NaiveBayes",
  89. "-l",
  90. self._model,
  91. "-T",
  92. test_filename,
  93. ] + options
  94. (stdout, stderr) = java(
  95. cmd,
  96. classpath=_weka_classpath,
  97. stdout=subprocess.PIPE,
  98. stderr=subprocess.PIPE,
  99. )
  100. # Check if something went wrong:
  101. if stderr and not stdout:
  102. if "Illegal options: -distribution" in stderr:
  103. raise ValueError(
  104. "The installed version of weka does "
  105. "not support probability distribution "
  106. "output."
  107. )
  108. else:
  109. raise ValueError("Weka failed to generate output:\n%s" % stderr)
  110. # Parse weka's output.
  111. return self.parse_weka_output(stdout.decode(stdin.encoding).split("\n"))
  112. finally:
  113. for f in os.listdir(temp_dir):
  114. os.remove(os.path.join(temp_dir, f))
  115. os.rmdir(temp_dir)
  116. def parse_weka_distribution(self, s):
  117. probs = [float(v) for v in re.split("[*,]+", s) if v.strip()]
  118. probs = dict(zip(self._formatter.labels(), probs))
  119. return DictionaryProbDist(probs)
  120. def parse_weka_output(self, lines):
  121. # Strip unwanted text from stdout
  122. for i, line in enumerate(lines):
  123. if line.strip().startswith("inst#"):
  124. lines = lines[i:]
  125. break
  126. if lines[0].split() == ["inst#", "actual", "predicted", "error", "prediction"]:
  127. return [line.split()[2].split(":")[1] for line in lines[1:] if line.strip()]
  128. elif lines[0].split() == [
  129. "inst#",
  130. "actual",
  131. "predicted",
  132. "error",
  133. "distribution",
  134. ]:
  135. return [
  136. self.parse_weka_distribution(line.split()[-1])
  137. for line in lines[1:]
  138. if line.strip()
  139. ]
  140. # is this safe:?
  141. elif re.match(r"^0 \w+ [01]\.[0-9]* \?\s*$", lines[0]):
  142. return [line.split()[1] for line in lines if line.strip()]
  143. else:
  144. for line in lines[:10]:
  145. print(line)
  146. raise ValueError(
  147. "Unhandled output format -- your version "
  148. "of weka may not be supported.\n"
  149. " Header: %s" % lines[0]
  150. )
  151. # [xx] full list of classifiers (some may be abstract?):
  152. # ADTree, AODE, BayesNet, ComplementNaiveBayes, ConjunctiveRule,
  153. # DecisionStump, DecisionTable, HyperPipes, IB1, IBk, Id3, J48,
  154. # JRip, KStar, LBR, LeastMedSq, LinearRegression, LMT, Logistic,
  155. # LogisticBase, M5Base, MultilayerPerceptron,
  156. # MultipleClassifiersCombiner, NaiveBayes, NaiveBayesMultinomial,
  157. # NaiveBayesSimple, NBTree, NNge, OneR, PaceRegression, PART,
  158. # PreConstructedLinearModel, Prism, RandomForest,
  159. # RandomizableClassifier, RandomTree, RBFNetwork, REPTree, Ridor,
  160. # RuleNode, SimpleLinearRegression, SimpleLogistic,
  161. # SingleClassifierEnhancer, SMO, SMOreg, UserClassifier, VFI,
  162. # VotedPerceptron, Winnow, ZeroR
  163. _CLASSIFIER_CLASS = {
  164. "naivebayes": "weka.classifiers.bayes.NaiveBayes",
  165. "C4.5": "weka.classifiers.trees.J48",
  166. "log_regression": "weka.classifiers.functions.Logistic",
  167. "svm": "weka.classifiers.functions.SMO",
  168. "kstar": "weka.classifiers.lazy.KStar",
  169. "ripper": "weka.classifiers.rules.JRip",
  170. }
  171. @classmethod
  172. def train(
  173. cls,
  174. model_filename,
  175. featuresets,
  176. classifier="naivebayes",
  177. options=[],
  178. quiet=True,
  179. ):
  180. # Make sure we can find java & weka.
  181. config_weka()
  182. # Build an ARFF formatter.
  183. formatter = ARFF_Formatter.from_train(featuresets)
  184. temp_dir = tempfile.mkdtemp()
  185. try:
  186. # Write the training data file.
  187. train_filename = os.path.join(temp_dir, "train.arff")
  188. formatter.write(train_filename, featuresets)
  189. if classifier in cls._CLASSIFIER_CLASS:
  190. javaclass = cls._CLASSIFIER_CLASS[classifier]
  191. elif classifier in cls._CLASSIFIER_CLASS.values():
  192. javaclass = classifier
  193. else:
  194. raise ValueError("Unknown classifier %s" % classifier)
  195. # Train the weka model.
  196. cmd = [javaclass, "-d", model_filename, "-t", train_filename]
  197. cmd += list(options)
  198. if quiet:
  199. stdout = subprocess.PIPE
  200. else:
  201. stdout = None
  202. java(cmd, classpath=_weka_classpath, stdout=stdout)
  203. # Return the new classifier.
  204. return WekaClassifier(formatter, model_filename)
  205. finally:
  206. for f in os.listdir(temp_dir):
  207. os.remove(os.path.join(temp_dir, f))
  208. os.rmdir(temp_dir)
  209. class ARFF_Formatter:
  210. """
  211. Converts featuresets and labeled featuresets to ARFF-formatted
  212. strings, appropriate for input into Weka.
  213. Features and classes can be specified manually in the constructor, or may
  214. be determined from data using ``from_train``.
  215. """
  216. def __init__(self, labels, features):
  217. """
  218. :param labels: A list of all class labels that can be generated.
  219. :param features: A list of feature specifications, where
  220. each feature specification is a tuple (fname, ftype);
  221. and ftype is an ARFF type string such as NUMERIC or
  222. STRING.
  223. """
  224. self._labels = labels
  225. self._features = features
  226. def format(self, tokens):
  227. """Returns a string representation of ARFF output for the given data."""
  228. return self.header_section() + self.data_section(tokens)
  229. def labels(self):
  230. """Returns the list of classes."""
  231. return list(self._labels)
  232. def write(self, outfile, tokens):
  233. """Writes ARFF data to a file for the given data."""
  234. if not hasattr(outfile, "write"):
  235. outfile = open(outfile, "w")
  236. outfile.write(self.format(tokens))
  237. outfile.close()
  238. @staticmethod
  239. def from_train(tokens):
  240. """
  241. Constructs an ARFF_Formatter instance with class labels and feature
  242. types determined from the given data. Handles boolean, numeric and
  243. string (note: not nominal) types.
  244. """
  245. # Find the set of all attested labels.
  246. labels = set(label for (tok, label) in tokens)
  247. # Determine the types of all features.
  248. features = {}
  249. for tok, label in tokens:
  250. for (fname, fval) in tok.items():
  251. if issubclass(type(fval), bool):
  252. ftype = "{True, False}"
  253. elif issubclass(type(fval), (int, float, bool)):
  254. ftype = "NUMERIC"
  255. elif issubclass(type(fval), str):
  256. ftype = "STRING"
  257. elif fval is None:
  258. continue # can't tell the type.
  259. else:
  260. raise ValueError("Unsupported value type %r" % ftype)
  261. if features.get(fname, ftype) != ftype:
  262. raise ValueError("Inconsistent type for %s" % fname)
  263. features[fname] = ftype
  264. features = sorted(features.items())
  265. return ARFF_Formatter(labels, features)
  266. def header_section(self):
  267. """Returns an ARFF header as a string."""
  268. # Header comment.
  269. s = (
  270. "% Weka ARFF file\n"
  271. + "% Generated automatically by NLTK\n"
  272. + "%% %s\n\n" % time.ctime()
  273. )
  274. # Relation name
  275. s += "@RELATION rel\n\n"
  276. # Input attribute specifications
  277. for fname, ftype in self._features:
  278. s += "@ATTRIBUTE %-30r %s\n" % (fname, ftype)
  279. # Label attribute specification
  280. s += "@ATTRIBUTE %-30r {%s}\n" % ("-label-", ",".join(self._labels))
  281. return s
  282. def data_section(self, tokens, labeled=None):
  283. """
  284. Returns the ARFF data section for the given data.
  285. :param tokens: a list of featuresets (dicts) or labelled featuresets
  286. which are tuples (featureset, label).
  287. :param labeled: Indicates whether the given tokens are labeled
  288. or not. If None, then the tokens will be assumed to be
  289. labeled if the first token's value is a tuple or list.
  290. """
  291. # Check if the tokens are labeled or unlabeled. If unlabeled,
  292. # then use 'None'
  293. if labeled is None:
  294. labeled = tokens and isinstance(tokens[0], (tuple, list))
  295. if not labeled:
  296. tokens = [(tok, None) for tok in tokens]
  297. # Data section
  298. s = "\n@DATA\n"
  299. for (tok, label) in tokens:
  300. for fname, ftype in self._features:
  301. s += "%s," % self._fmt_arff_val(tok.get(fname))
  302. s += "%s\n" % self._fmt_arff_val(label)
  303. return s
  304. def _fmt_arff_val(self, fval):
  305. if fval is None:
  306. return "?"
  307. elif isinstance(fval, (bool, int)):
  308. return "%s" % fval
  309. elif isinstance(fval, float):
  310. return "%r" % fval
  311. else:
  312. return "%r" % fval
if __name__ == "__main__":
    # Demo: train a names-gender classifier using Weka and report accuracy
    # via NLTK's standard demo harness.
    from nltk.classify.util import names_demo, binary_names_demo_features

    def make_classifier(featuresets):
        # Train a C4.5 (weka J48 decision tree) model; the model file is
        # written to /tmp/name.model.
        return WekaClassifier.train("/tmp/name.model", featuresets, "C4.5")

    classifier = names_demo(make_classifier, binary_names_demo_features)