# Natural Language Toolkit: Chunk parsing API
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

"""
Named entity chunker
"""

import os, re, pickle
from xml.etree import ElementTree as ET

from nltk.tag import ClassifierBasedTagger, pos_tag

try:
    from nltk.classify import MaxentClassifier
except ImportError:
    pass

from nltk.tree import Tree
from nltk.tokenize import word_tokenize
from nltk.data import find
from nltk.chunk.api import ChunkParserI
from nltk.chunk.util import ChunkScore


class NEChunkParserTagger(ClassifierBasedTagger):
    """
    The IOB tagger used by the chunk parser.
    """

    def __init__(self, train):
        ClassifierBasedTagger.__init__(
            self, train=train, classifier_builder=self._classifier_builder
        )

    def _classifier_builder(self, train):
        # Train a maxent classifier (using the external megam optimizer)
        # on the IOB-tagged training data.
        return MaxentClassifier.train(
            train, algorithm="megam", gaussian_prior_sigma=1, trace=2
        )

    def _english_wordlist(self):
        # Cache the basic-English word list the first time it is needed.
        try:
            wl = self._en_wordlist
        except AttributeError:
            from nltk.corpus import words

            self._en_wordlist = set(words.words("en-basic"))
            wl = self._en_wordlist
        return wl

    def _feature_detector(self, tokens, index, history):
        # Collect the word, its simplified POS tag, and a two-token window
        # of context on each side, plus the IOB tags already assigned to
        # the preceding tokens (``history``).
        word = tokens[index][0]
        pos = simplify_pos(tokens[index][1])
        if index == 0:
            prevword = prevprevword = None
            prevpos = prevprevpos = None
            prevshape = prevtag = prevprevtag = None
        elif index == 1:
            prevword = tokens[index - 1][0].lower()
            prevprevword = None
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = None
            prevtag = history[index - 1][0]
            prevshape = prevprevtag = None
        else:
            prevword = tokens[index - 1][0].lower()
            prevprevword = tokens[index - 2][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = simplify_pos(tokens[index - 2][1])
            prevtag = history[index - 1]
            prevprevtag = history[index - 2]
            prevshape = shape(prevword)

        if index == len(tokens) - 1:
            nextword = nextnextword = None
            nextpos = nextnextpos = None
        elif index == len(tokens) - 2:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = None
            nextnextpos = None
        else:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = tokens[index + 2][0].lower()
            nextnextpos = tokens[index + 2][1].lower()

        # 89.6
        features = {
            "bias": True,
            "shape": shape(word),
            "wordlen": len(word),
            "prefix3": word[:3].lower(),
            "suffix3": word[-3:].lower(),
            "pos": pos,
            "word": word,
            "en-wordlist": (word in self._english_wordlist()),
            "prevtag": prevtag,
            "prevpos": prevpos,
            "nextpos": nextpos,
            "prevword": prevword,
            "nextword": nextword,
            "word+nextpos": "{0}+{1}".format(word.lower(), nextpos),
            "pos+prevtag": "{0}+{1}".format(pos, prevtag),
            "shape+prevtag": "{0}+{1}".format(prevshape, prevtag),
        }
        return features

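
# Illustrative sketch (not part of the original module): for the token
# "Washington" in a pos-tagged sentence such as
#   [("She", "PRP"), ("visited", "VBD"), ("Washington", "NNP"), ("yesterday", "NN")],
# _feature_detector(tokens, 2, history) returns a dict roughly along these lines,
# plus the history-dependent "prevtag", "pos+prevtag", and "shape+prevtag" features:
#
#   {"bias": True, "shape": "upcase", "wordlen": 10, "prefix3": "was",
#    "suffix3": "ton", "pos": "NNP", "word": "Washington", "en-wordlist": False,
#    "prevpos": "V", "nextpos": "nn", "prevword": "visited",
#    "nextword": "yesterday", "word+nextpos": "washington+nn", ...}

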
class NEChunkParser(ChunkParserI):
    """
    Expected input: list of pos-tagged words
    """

    def __init__(self, train):
        self._train(train)

    def parse(self, tokens):
        """
        Each token should be a pos-tagged word
        """
        tagged = self._tagger.tag(tokens)
        tree = self._tagged_to_parse(tagged)
        return tree

    def _train(self, corpus):
        # Convert to tagged sequence
        corpus = [self._parse_to_tagged(s) for s in corpus]
        self._tagger = NEChunkParserTagger(train=corpus)

    def _tagged_to_parse(self, tagged_tokens):
        """
        Convert a list of tagged tokens to a chunk-parse tree.
        """
        sent = Tree("S", [])
        for (tok, tag) in tagged_tokens:
            if tag == "O":
                sent.append(tok)
            elif tag.startswith("B-"):
                sent.append(Tree(tag[2:], [tok]))
            elif tag.startswith("I-"):
                if sent and isinstance(sent[-1], Tree) and sent[-1].label() == tag[2:]:
                    sent[-1].append(tok)
                else:
                    sent.append(Tree(tag[2:], [tok]))
        return sent

    @staticmethod
    def _parse_to_tagged(sent):
        """
        Convert a chunk-parse tree to a list of tagged tokens.
        """
        toks = []
        for child in sent:
            if isinstance(child, Tree):
                if len(child) == 0:
                    print("Warning -- empty chunk in sentence")
                    continue
                toks.append((child[0], "B-{0}".format(child.label())))
                for tok in child[1:]:
                    toks.append((tok, "I-{0}".format(child.label())))
            else:
                toks.append((child, "O"))
        return toks

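
# Illustrative sketch (not in the original module) of the IOB round trip
# performed by _parse_to_tagged and _tagged_to_parse on a toy tree:
#
#     >>> t = Tree("S", [("He", "PRP"), Tree("NE", [("New", "NNP"), ("York", "NNP")])])
#     >>> NEChunkParser._parse_to_tagged(t)
#     [(('He', 'PRP'), 'O'), (('New', 'NNP'), 'B-NE'), (('York', 'NNP'), 'I-NE')]
#
# Feeding that list back through _tagged_to_parse (on a parser instance)
# rebuilds the original tree.

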
def shape(word):
    if re.match(r"[0-9]+(\.[0-9]*)?|[0-9]*\.[0-9]+$", word, re.UNICODE):
        return "number"
    elif re.match(r"\W+$", word, re.UNICODE):
        return "punct"
    elif re.match(r"\w+$", word, re.UNICODE):
        if word.istitle():
            return "upcase"
        elif word.islower():
            return "downcase"
        else:
            return "mixedcase"
    else:
        return "other"


def simplify_pos(s):
    if s.startswith("V"):
        return "V"
    else:
        return s.split("-")[0]

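
# A few representative values (illustrative, not part of the original module):
#
#     >>> shape("1990"), shape("Flight"), shape("whale"), shape("iPhone"), shape(",")
#     ('number', 'upcase', 'downcase', 'mixedcase', 'punct')
#     >>> simplify_pos("VBD"), simplify_pos("NNP-TL")
#     ('V', 'NNP')

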
def postag_tree(tree):
    # Part-of-speech tagging.
    words = tree.leaves()
    tag_iter = (pos for (word, pos) in pos_tag(words))
    newtree = Tree("S", [])
    for child in tree:
        if isinstance(child, Tree):
            newtree.append(Tree(child.label(), []))
            for subchild in child:
                newtree[-1].append((subchild, next(tag_iter)))
        else:
            newtree.append((child, next(tag_iter)))
    return newtree

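
# Illustrative sketch (not in the original module): postag_tree takes a chunk
# tree over bare word strings and re-attaches POS tags from nltk.pos_tag while
# preserving the chunk structure. The exact tags depend on the tagger, but the
# result has this shape:
#
#     >>> postag_tree(Tree("S", ["He", Tree("NE", ["New", "York"])]))
#     Tree('S', [('He', 'PRP'), Tree('NE', [('New', 'NNP'), ('York', 'NNP')])])

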
def load_ace_data(roots, fmt="binary", skip_bnews=True):
    for root in roots:
        for root, dirs, files in os.walk(root):
            if root.endswith("bnews") and skip_bnews:
                continue
            for f in files:
                if f.endswith(".sgm"):
                    for sent in load_ace_file(os.path.join(root, f), fmt):
                        yield sent


def load_ace_file(textfile, fmt):
    print(" - {0}".format(os.path.split(textfile)[1]))
    annfile = textfile + ".tmx.rdc.xml"

    # Read the xml file, and get a list of entities
    entities = []
    with open(annfile, "r") as infile:
        xml = ET.parse(infile).getroot()
    for entity in xml.findall("document/entity"):
        typ = entity.find("entity_type").text
        for mention in entity.findall("entity_mention"):
            if mention.get("TYPE") != "NAME":
                continue  # only NEs
            s = int(mention.find("head/charseq/start").text)
            e = int(mention.find("head/charseq/end").text) + 1
            entities.append((s, e, typ))

    # Read the text file, and mark the entities.
    with open(textfile, "r") as infile:
        text = infile.read()

    # Strip XML tags, since they don't count towards the indices
    text = re.sub("<(?!/?TEXT)[^>]+>", "", text)

    # Blank out anything before/after <TEXT>
    def subfunc(m):
        return " " * (m.end() - m.start() - 6)

    text = re.sub(r"[\s\S]*<TEXT>", subfunc, text)
    text = re.sub(r"</TEXT>[\s\S]*", "", text)

    # Simplify quotes
    text = re.sub("``", ' "', text)
    text = re.sub("''", '" ', text)

    entity_types = set(typ for (s, e, typ) in entities)

    # Binary distinction (NE or not NE)
    if fmt == "binary":
        i = 0
        toks = Tree("S", [])
        for (s, e, typ) in sorted(entities):
            if s < i:
                s = i  # Overlapping!  Deal with this better?
            if e <= s:
                continue
            toks.extend(word_tokenize(text[i:s]))
            toks.append(Tree("NE", text[s:e].split()))
            i = e
        toks.extend(word_tokenize(text[i:]))
        yield toks

    # Multiclass distinction (NE type)
    elif fmt == "multiclass":
        i = 0
        toks = Tree("S", [])
        for (s, e, typ) in sorted(entities):
            if s < i:
                s = i  # Overlapping!  Deal with this better?
            if e <= s:
                continue
            toks.extend(word_tokenize(text[i:s]))
            toks.append(Tree(typ, text[s:e].split()))
            i = e
        toks.extend(word_tokenize(text[i:]))
        yield toks

    else:
        raise ValueError("bad fmt value")

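
# For reference (illustrative values, not taken from the corpus): each yielded
# sentence is a Tree("S", ...) whose children are word strings, with entity
# spans wrapped in a subtree. In "binary" format the subtree label is always
# "NE"; in "multiclass" format it is the ACE entity type, e.g. something like:
#
#     Tree("S", ["President", Tree("PERSON", ["Bill", "Clinton"]), "spoke", "."])

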
# This probably belongs in a more general-purpose location (as does
# the parse_to_tagged function).
def cmp_chunks(correct, guessed):
    # Print the correct and guessed IOB tags side by side, collapsing long
    # runs of matching "O" tags into a single ellipsis line.
    correct = NEChunkParser._parse_to_tagged(correct)
    guessed = NEChunkParser._parse_to_tagged(guessed)
    ellipsis = False
    for (w, ct), (w, gt) in zip(correct, guessed):
        if ct == gt == "O":
            if not ellipsis:
                print(" {:15} {:15} {}".format(ct, gt, w))
                print(" {:15} {:15} {}".format("...", "...", "..."))
                ellipsis = True
        else:
            ellipsis = False
            print(" {:15} {:15} {}".format(ct, gt, w))

def build_model(fmt="binary"):
    print("Loading training data...")
    train_paths = [
        find("corpora/ace_data/ace.dev"),
        find("corpora/ace_data/ace.heldout"),
        find("corpora/ace_data/bbn.dev"),
        find("corpora/ace_data/muc.dev"),
    ]
    train_trees = load_ace_data(train_paths, fmt)
    train_data = [postag_tree(t) for t in train_trees]
    print("Training...")
    cp = NEChunkParser(train_data)
    del train_data

    print("Loading eval data...")
    eval_paths = [find("corpora/ace_data/ace.eval")]
    eval_trees = load_ace_data(eval_paths, fmt)
    eval_data = [postag_tree(t) for t in eval_trees]

    print("Evaluating...")
    chunkscore = ChunkScore()
    for i, correct in enumerate(eval_data):
        guess = cp.parse(correct.leaves())
        chunkscore.score(correct, guess)
        if i < 3:
            cmp_chunks(correct, guess)
    print(chunkscore)

    outfilename = "/tmp/ne_chunker_{0}.pickle".format(fmt)
    print("Saving chunker to {0}...".format(outfilename))
    with open(outfilename, "wb") as outfile:
        pickle.dump(cp, outfile, -1)

    return cp

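
# Sketch of loading and using a chunker saved by build_model() (the path below
# matches the default output location above; a previously trained model and the
# ACE training data are assumptions, not something this module ships with):
#
#     import pickle
#     from nltk import pos_tag, word_tokenize
#     with open("/tmp/ne_chunker_binary.pickle", "rb") as f:
#         cp = pickle.load(f)
#     tree = cp.parse(pos_tag(word_tokenize("Mary moved from Boston to Seattle.")))
#     print(tree)
#
# Unpickling resolves the class by its module path, which is why the __main__
# block below re-imports build_model from nltk.chunk.named_entity.

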
if __name__ == "__main__":
    # Make sure that the pickled object has the right class name:
    from nltk.chunk.named_entity import build_model

    build_model("binary")
    build_model("multiclass")