transitionparser.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. # Natural Language Toolkit: Arc-Standard and Arc-eager Transition Based Parsers
  2. #
  3. # Author: Long Duong <longdt219@gmail.com>
  4. #
  5. # Copyright (C) 2001-2020 NLTK Project
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. import tempfile
  9. import pickle
  10. from os import remove
  11. from copy import deepcopy
  12. from operator import itemgetter
  13. try:
  14. from numpy import array
  15. from scipy import sparse
  16. from sklearn.datasets import load_svmlight_file
  17. from sklearn import svm
  18. except ImportError:
  19. pass
  20. from nltk.parse import ParserI, DependencyGraph, DependencyEvaluator
  21. class Configuration(object):
  22. """
  23. Class for holding configuration which is the partial analysis of the input sentence.
  24. The transition based parser aims at finding set of operators that transfer the initial
  25. configuration to the terminal configuration.
  26. The configuration includes:
  27. - Stack: for storing partially proceeded words
  28. - Buffer: for storing remaining input words
  29. - Set of arcs: for storing partially built dependency tree
  30. This class also provides a method to represent a configuration as list of features.
  31. """
  32. def __init__(self, dep_graph):
  33. """
  34. :param dep_graph: the representation of an input in the form of dependency graph.
  35. :type dep_graph: DependencyGraph where the dependencies are not specified.
  36. """
  37. # dep_graph.nodes contain list of token for a sentence
  38. self.stack = [0] # The root element
  39. self.buffer = list(range(1, len(dep_graph.nodes))) # The rest is in the buffer
  40. self.arcs = [] # empty set of arc
  41. self._tokens = dep_graph.nodes
  42. self._max_address = len(self.buffer)
  43. def __str__(self):
  44. return (
  45. "Stack : "
  46. + str(self.stack)
  47. + " Buffer : "
  48. + str(self.buffer)
  49. + " Arcs : "
  50. + str(self.arcs)
  51. )
  52. def _check_informative(self, feat, flag=False):
  53. """
  54. Check whether a feature is informative
  55. The flag control whether "_" is informative or not
  56. """
  57. if feat is None:
  58. return False
  59. if feat == "":
  60. return False
  61. if flag is False:
  62. if feat == "_":
  63. return False
  64. return True
  65. def extract_features(self):
  66. """
  67. Extract the set of features for the current configuration. Implement standard features as describe in
  68. Table 3.2 (page 31) in Dependency Parsing book by Sandra Kubler, Ryan McDonal, Joakim Nivre.
  69. Please note that these features are very basic.
  70. :return: list(str)
  71. """
  72. result = []
  73. # Todo : can come up with more complicated features set for better
  74. # performance.
  75. if len(self.stack) > 0:
  76. # Stack 0
  77. stack_idx0 = self.stack[len(self.stack) - 1]
  78. token = self._tokens[stack_idx0]
  79. if self._check_informative(token["word"], True):
  80. result.append("STK_0_FORM_" + token["word"])
  81. if "lemma" in token and self._check_informative(token["lemma"]):
  82. result.append("STK_0_LEMMA_" + token["lemma"])
  83. if self._check_informative(token["tag"]):
  84. result.append("STK_0_POS_" + token["tag"])
  85. if "feats" in token and self._check_informative(token["feats"]):
  86. feats = token["feats"].split("|")
  87. for feat in feats:
  88. result.append("STK_0_FEATS_" + feat)
  89. # Stack 1
  90. if len(self.stack) > 1:
  91. stack_idx1 = self.stack[len(self.stack) - 2]
  92. token = self._tokens[stack_idx1]
  93. if self._check_informative(token["tag"]):
  94. result.append("STK_1_POS_" + token["tag"])
  95. # Left most, right most dependency of stack[0]
  96. left_most = 1000000
  97. right_most = -1
  98. dep_left_most = ""
  99. dep_right_most = ""
  100. for (wi, r, wj) in self.arcs:
  101. if wi == stack_idx0:
  102. if (wj > wi) and (wj > right_most):
  103. right_most = wj
  104. dep_right_most = r
  105. if (wj < wi) and (wj < left_most):
  106. left_most = wj
  107. dep_left_most = r
  108. if self._check_informative(dep_left_most):
  109. result.append("STK_0_LDEP_" + dep_left_most)
  110. if self._check_informative(dep_right_most):
  111. result.append("STK_0_RDEP_" + dep_right_most)
  112. # Check Buffered 0
  113. if len(self.buffer) > 0:
  114. # Buffer 0
  115. buffer_idx0 = self.buffer[0]
  116. token = self._tokens[buffer_idx0]
  117. if self._check_informative(token["word"], True):
  118. result.append("BUF_0_FORM_" + token["word"])
  119. if "lemma" in token and self._check_informative(token["lemma"]):
  120. result.append("BUF_0_LEMMA_" + token["lemma"])
  121. if self._check_informative(token["tag"]):
  122. result.append("BUF_0_POS_" + token["tag"])
  123. if "feats" in token and self._check_informative(token["feats"]):
  124. feats = token["feats"].split("|")
  125. for feat in feats:
  126. result.append("BUF_0_FEATS_" + feat)
  127. # Buffer 1
  128. if len(self.buffer) > 1:
  129. buffer_idx1 = self.buffer[1]
  130. token = self._tokens[buffer_idx1]
  131. if self._check_informative(token["word"], True):
  132. result.append("BUF_1_FORM_" + token["word"])
  133. if self._check_informative(token["tag"]):
  134. result.append("BUF_1_POS_" + token["tag"])
  135. if len(self.buffer) > 2:
  136. buffer_idx2 = self.buffer[2]
  137. token = self._tokens[buffer_idx2]
  138. if self._check_informative(token["tag"]):
  139. result.append("BUF_2_POS_" + token["tag"])
  140. if len(self.buffer) > 3:
  141. buffer_idx3 = self.buffer[3]
  142. token = self._tokens[buffer_idx3]
  143. if self._check_informative(token["tag"]):
  144. result.append("BUF_3_POS_" + token["tag"])
  145. # Left most, right most dependency of stack[0]
  146. left_most = 1000000
  147. right_most = -1
  148. dep_left_most = ""
  149. dep_right_most = ""
  150. for (wi, r, wj) in self.arcs:
  151. if wi == buffer_idx0:
  152. if (wj > wi) and (wj > right_most):
  153. right_most = wj
  154. dep_right_most = r
  155. if (wj < wi) and (wj < left_most):
  156. left_most = wj
  157. dep_left_most = r
  158. if self._check_informative(dep_left_most):
  159. result.append("BUF_0_LDEP_" + dep_left_most)
  160. if self._check_informative(dep_right_most):
  161. result.append("BUF_0_RDEP_" + dep_right_most)
  162. return result
  163. class Transition(object):
  164. """
  165. This class defines a set of transition which is applied to a configuration to get another configuration
  166. Note that for different parsing algorithm, the transition is different.
  167. """
  168. # Define set of transitions
  169. LEFT_ARC = "LEFTARC"
  170. RIGHT_ARC = "RIGHTARC"
  171. SHIFT = "SHIFT"
  172. REDUCE = "REDUCE"
  173. def __init__(self, alg_option):
  174. """
  175. :param alg_option: the algorithm option of this parser. Currently support `arc-standard` and `arc-eager` algorithm
  176. :type alg_option: str
  177. """
  178. self._algo = alg_option
  179. if alg_option not in [
  180. TransitionParser.ARC_STANDARD,
  181. TransitionParser.ARC_EAGER,
  182. ]:
  183. raise ValueError(
  184. " Currently we only support %s and %s "
  185. % (TransitionParser.ARC_STANDARD, TransitionParser.ARC_EAGER)
  186. )
  187. def left_arc(self, conf, relation):
  188. """
  189. Note that the algorithm for left-arc is quite similar except for precondition for both arc-standard and arc-eager
  190. :param configuration: is the current configuration
  191. :return : A new configuration or -1 if the pre-condition is not satisfied
  192. """
  193. if (len(conf.buffer) <= 0) or (len(conf.stack) <= 0):
  194. return -1
  195. if conf.buffer[0] == 0:
  196. # here is the Root element
  197. return -1
  198. idx_wi = conf.stack[len(conf.stack) - 1]
  199. flag = True
  200. if self._algo == TransitionParser.ARC_EAGER:
  201. for (idx_parent, r, idx_child) in conf.arcs:
  202. if idx_child == idx_wi:
  203. flag = False
  204. if flag:
  205. conf.stack.pop()
  206. idx_wj = conf.buffer[0]
  207. conf.arcs.append((idx_wj, relation, idx_wi))
  208. else:
  209. return -1
  210. def right_arc(self, conf, relation):
  211. """
  212. Note that the algorithm for right-arc is DIFFERENT for arc-standard and arc-eager
  213. :param configuration: is the current configuration
  214. :return : A new configuration or -1 if the pre-condition is not satisfied
  215. """
  216. if (len(conf.buffer) <= 0) or (len(conf.stack) <= 0):
  217. return -1
  218. if self._algo == TransitionParser.ARC_STANDARD:
  219. idx_wi = conf.stack.pop()
  220. idx_wj = conf.buffer[0]
  221. conf.buffer[0] = idx_wi
  222. conf.arcs.append((idx_wi, relation, idx_wj))
  223. else: # arc-eager
  224. idx_wi = conf.stack[len(conf.stack) - 1]
  225. idx_wj = conf.buffer.pop(0)
  226. conf.stack.append(idx_wj)
  227. conf.arcs.append((idx_wi, relation, idx_wj))
  228. def reduce(self, conf):
  229. """
  230. Note that the algorithm for reduce is only available for arc-eager
  231. :param configuration: is the current configuration
  232. :return : A new configuration or -1 if the pre-condition is not satisfied
  233. """
  234. if self._algo != TransitionParser.ARC_EAGER:
  235. return -1
  236. if len(conf.stack) <= 0:
  237. return -1
  238. idx_wi = conf.stack[len(conf.stack) - 1]
  239. flag = False
  240. for (idx_parent, r, idx_child) in conf.arcs:
  241. if idx_child == idx_wi:
  242. flag = True
  243. if flag:
  244. conf.stack.pop() # reduce it
  245. else:
  246. return -1
  247. def shift(self, conf):
  248. """
  249. Note that the algorithm for shift is the SAME for arc-standard and arc-eager
  250. :param configuration: is the current configuration
  251. :return : A new configuration or -1 if the pre-condition is not satisfied
  252. """
  253. if len(conf.buffer) <= 0:
  254. return -1
  255. idx_wi = conf.buffer.pop(0)
  256. conf.stack.append(idx_wi)
  257. class TransitionParser(ParserI):
  258. """
  259. Class for transition based parser. Implement 2 algorithms which are "arc-standard" and "arc-eager"
  260. """
  261. ARC_STANDARD = "arc-standard"
  262. ARC_EAGER = "arc-eager"
  263. def __init__(self, algorithm):
  264. """
  265. :param algorithm: the algorithm option of this parser. Currently support `arc-standard` and `arc-eager` algorithm
  266. :type algorithm: str
  267. """
  268. if not (algorithm in [self.ARC_STANDARD, self.ARC_EAGER]):
  269. raise ValueError(
  270. " Currently we only support %s and %s "
  271. % (self.ARC_STANDARD, self.ARC_EAGER)
  272. )
  273. self._algorithm = algorithm
  274. self._dictionary = {}
  275. self._transition = {}
  276. self._match_transition = {}
  277. def _get_dep_relation(self, idx_parent, idx_child, depgraph):
  278. p_node = depgraph.nodes[idx_parent]
  279. c_node = depgraph.nodes[idx_child]
  280. if c_node["word"] is None:
  281. return None # Root word
  282. if c_node["head"] == p_node["address"]:
  283. return c_node["rel"]
  284. else:
  285. return None
  286. def _convert_to_binary_features(self, features):
  287. """
  288. :param features: list of feature string which is needed to convert to binary features
  289. :type features: list(str)
  290. :return : string of binary features in libsvm format which is 'featureID:value' pairs
  291. """
  292. unsorted_result = []
  293. for feature in features:
  294. self._dictionary.setdefault(feature, len(self._dictionary))
  295. unsorted_result.append(self._dictionary[feature])
  296. # Default value of each feature is 1.0
  297. return " ".join(
  298. str(featureID) + ":1.0" for featureID in sorted(unsorted_result)
  299. )
  300. def _is_projective(self, depgraph):
  301. arc_list = []
  302. for key in depgraph.nodes:
  303. node = depgraph.nodes[key]
  304. if "head" in node:
  305. childIdx = node["address"]
  306. parentIdx = node["head"]
  307. if parentIdx is not None:
  308. arc_list.append((parentIdx, childIdx))
  309. for (parentIdx, childIdx) in arc_list:
  310. # Ensure that childIdx < parentIdx
  311. if childIdx > parentIdx:
  312. temp = childIdx
  313. childIdx = parentIdx
  314. parentIdx = temp
  315. for k in range(childIdx + 1, parentIdx):
  316. for m in range(len(depgraph.nodes)):
  317. if (m < childIdx) or (m > parentIdx):
  318. if (k, m) in arc_list:
  319. return False
  320. if (m, k) in arc_list:
  321. return False
  322. return True
  323. def _write_to_file(self, key, binary_features, input_file):
  324. """
  325. write the binary features to input file and update the transition dictionary
  326. """
  327. self._transition.setdefault(key, len(self._transition) + 1)
  328. self._match_transition[self._transition[key]] = key
  329. input_str = str(self._transition[key]) + " " + binary_features + "\n"
  330. input_file.write(input_str.encode("utf-8"))
  331. def _create_training_examples_arc_std(self, depgraphs, input_file):
  332. """
  333. Create the training example in the libsvm format and write it to the input_file.
  334. Reference : Page 32, Chapter 3. Dependency Parsing by Sandra Kubler, Ryan McDonal and Joakim Nivre (2009)
  335. """
  336. operation = Transition(self.ARC_STANDARD)
  337. count_proj = 0
  338. training_seq = []
  339. for depgraph in depgraphs:
  340. if not self._is_projective(depgraph):
  341. continue
  342. count_proj += 1
  343. conf = Configuration(depgraph)
  344. while len(conf.buffer) > 0:
  345. b0 = conf.buffer[0]
  346. features = conf.extract_features()
  347. binary_features = self._convert_to_binary_features(features)
  348. if len(conf.stack) > 0:
  349. s0 = conf.stack[len(conf.stack) - 1]
  350. # Left-arc operation
  351. rel = self._get_dep_relation(b0, s0, depgraph)
  352. if rel is not None:
  353. key = Transition.LEFT_ARC + ":" + rel
  354. self._write_to_file(key, binary_features, input_file)
  355. operation.left_arc(conf, rel)
  356. training_seq.append(key)
  357. continue
  358. # Right-arc operation
  359. rel = self._get_dep_relation(s0, b0, depgraph)
  360. if rel is not None:
  361. precondition = True
  362. # Get the max-index of buffer
  363. maxID = conf._max_address
  364. for w in range(maxID + 1):
  365. if w != b0:
  366. relw = self._get_dep_relation(b0, w, depgraph)
  367. if relw is not None:
  368. if (b0, relw, w) not in conf.arcs:
  369. precondition = False
  370. if precondition:
  371. key = Transition.RIGHT_ARC + ":" + rel
  372. self._write_to_file(key, binary_features, input_file)
  373. operation.right_arc(conf, rel)
  374. training_seq.append(key)
  375. continue
  376. # Shift operation as the default
  377. key = Transition.SHIFT
  378. self._write_to_file(key, binary_features, input_file)
  379. operation.shift(conf)
  380. training_seq.append(key)
  381. print(" Number of training examples : " + str(len(depgraphs)))
  382. print(" Number of valid (projective) examples : " + str(count_proj))
  383. return training_seq
  384. def _create_training_examples_arc_eager(self, depgraphs, input_file):
  385. """
  386. Create the training example in the libsvm format and write it to the input_file.
  387. Reference : 'A Dynamic Oracle for Arc-Eager Dependency Parsing' by Joav Goldberg and Joakim Nivre
  388. """
  389. operation = Transition(self.ARC_EAGER)
  390. countProj = 0
  391. training_seq = []
  392. for depgraph in depgraphs:
  393. if not self._is_projective(depgraph):
  394. continue
  395. countProj += 1
  396. conf = Configuration(depgraph)
  397. while len(conf.buffer) > 0:
  398. b0 = conf.buffer[0]
  399. features = conf.extract_features()
  400. binary_features = self._convert_to_binary_features(features)
  401. if len(conf.stack) > 0:
  402. s0 = conf.stack[len(conf.stack) - 1]
  403. # Left-arc operation
  404. rel = self._get_dep_relation(b0, s0, depgraph)
  405. if rel is not None:
  406. key = Transition.LEFT_ARC + ":" + rel
  407. self._write_to_file(key, binary_features, input_file)
  408. operation.left_arc(conf, rel)
  409. training_seq.append(key)
  410. continue
  411. # Right-arc operation
  412. rel = self._get_dep_relation(s0, b0, depgraph)
  413. if rel is not None:
  414. key = Transition.RIGHT_ARC + ":" + rel
  415. self._write_to_file(key, binary_features, input_file)
  416. operation.right_arc(conf, rel)
  417. training_seq.append(key)
  418. continue
  419. # reduce operation
  420. flag = False
  421. for k in range(s0):
  422. if self._get_dep_relation(k, b0, depgraph) is not None:
  423. flag = True
  424. if self._get_dep_relation(b0, k, depgraph) is not None:
  425. flag = True
  426. if flag:
  427. key = Transition.REDUCE
  428. self._write_to_file(key, binary_features, input_file)
  429. operation.reduce(conf)
  430. training_seq.append(key)
  431. continue
  432. # Shift operation as the default
  433. key = Transition.SHIFT
  434. self._write_to_file(key, binary_features, input_file)
  435. operation.shift(conf)
  436. training_seq.append(key)
  437. print(" Number of training examples : " + str(len(depgraphs)))
  438. print(" Number of valid (projective) examples : " + str(countProj))
  439. return training_seq
  440. def train(self, depgraphs, modelfile, verbose=True):
  441. """
  442. :param depgraphs : list of DependencyGraph as the training data
  443. :type depgraphs : DependencyGraph
  444. :param modelfile : file name to save the trained model
  445. :type modelfile : str
  446. """
  447. try:
  448. input_file = tempfile.NamedTemporaryFile(
  449. prefix="transition_parse.train", dir=tempfile.gettempdir(), delete=False
  450. )
  451. if self._algorithm == self.ARC_STANDARD:
  452. self._create_training_examples_arc_std(depgraphs, input_file)
  453. else:
  454. self._create_training_examples_arc_eager(depgraphs, input_file)
  455. input_file.close()
  456. # Using the temporary file to train the libsvm classifier
  457. x_train, y_train = load_svmlight_file(input_file.name)
  458. # The parameter is set according to the paper:
  459. # Algorithms for Deterministic Incremental Dependency Parsing by Joakim Nivre
  460. # Todo : because of probability = True => very slow due to
  461. # cross-validation. Need to improve the speed here
  462. model = svm.SVC(
  463. kernel="poly",
  464. degree=2,
  465. coef0=0,
  466. gamma=0.2,
  467. C=0.5,
  468. verbose=verbose,
  469. probability=True,
  470. )
  471. model.fit(x_train, y_train)
  472. # Save the model to file name (as pickle)
  473. pickle.dump(model, open(modelfile, "wb"))
  474. finally:
  475. remove(input_file.name)
  476. def parse(self, depgraphs, modelFile):
  477. """
  478. :param depgraphs: the list of test sentence, each sentence is represented as a dependency graph where the 'head' information is dummy
  479. :type depgraphs: list(DependencyGraph)
  480. :param modelfile: the model file
  481. :type modelfile: str
  482. :return: list (DependencyGraph) with the 'head' and 'rel' information
  483. """
  484. result = []
  485. # First load the model
  486. model = pickle.load(open(modelFile, "rb"))
  487. operation = Transition(self._algorithm)
  488. for depgraph in depgraphs:
  489. conf = Configuration(depgraph)
  490. while len(conf.buffer) > 0:
  491. features = conf.extract_features()
  492. col = []
  493. row = []
  494. data = []
  495. for feature in features:
  496. if feature in self._dictionary:
  497. col.append(self._dictionary[feature])
  498. row.append(0)
  499. data.append(1.0)
  500. np_col = array(sorted(col)) # NB : index must be sorted
  501. np_row = array(row)
  502. np_data = array(data)
  503. x_test = sparse.csr_matrix(
  504. (np_data, (np_row, np_col)), shape=(1, len(self._dictionary))
  505. )
  506. # It's best to use decision function as follow BUT it's not supported yet for sparse SVM
  507. # Using decision funcion to build the votes array
  508. # dec_func = model.decision_function(x_test)[0]
  509. # votes = {}
  510. # k = 0
  511. # for i in range(len(model.classes_)):
  512. # for j in range(i+1, len(model.classes_)):
  513. # #if dec_func[k] > 0:
  514. # votes.setdefault(i,0)
  515. # votes[i] +=1
  516. # else:
  517. # votes.setdefault(j,0)
  518. # votes[j] +=1
  519. # k +=1
  520. # Sort votes according to the values
  521. # sorted_votes = sorted(votes.items(), key=itemgetter(1), reverse=True)
  522. # We will use predict_proba instead of decision_function
  523. prob_dict = {}
  524. pred_prob = model.predict_proba(x_test)[0]
  525. for i in range(len(pred_prob)):
  526. prob_dict[i] = pred_prob[i]
  527. sorted_Prob = sorted(prob_dict.items(), key=itemgetter(1), reverse=True)
  528. # Note that SHIFT is always a valid operation
  529. for (y_pred_idx, confidence) in sorted_Prob:
  530. # y_pred = model.predict(x_test)[0]
  531. # From the prediction match to the operation
  532. y_pred = model.classes_[y_pred_idx]
  533. if y_pred in self._match_transition:
  534. strTransition = self._match_transition[y_pred]
  535. baseTransition = strTransition.split(":")[0]
  536. if baseTransition == Transition.LEFT_ARC:
  537. if (
  538. operation.left_arc(conf, strTransition.split(":")[1])
  539. != -1
  540. ):
  541. break
  542. elif baseTransition == Transition.RIGHT_ARC:
  543. if (
  544. operation.right_arc(conf, strTransition.split(":")[1])
  545. != -1
  546. ):
  547. break
  548. elif baseTransition == Transition.REDUCE:
  549. if operation.reduce(conf) != -1:
  550. break
  551. elif baseTransition == Transition.SHIFT:
  552. if operation.shift(conf) != -1:
  553. break
  554. else:
  555. raise ValueError(
  556. "The predicted transition is not recognized, expected errors"
  557. )
  558. # Finish with operations build the dependency graph from Conf.arcs
  559. new_depgraph = deepcopy(depgraph)
  560. for key in new_depgraph.nodes:
  561. node = new_depgraph.nodes[key]
  562. node["rel"] = ""
  563. # With the default, all the token depend on the Root
  564. node["head"] = 0
  565. for (head, rel, child) in conf.arcs:
  566. c_node = new_depgraph.nodes[child]
  567. c_node["head"] = head
  568. c_node["rel"] = rel
  569. result.append(new_depgraph)
  570. return result
def demo():
    """
    Doctest-driven demonstration of the transition-based parsers.

    The examples below exercise feature extraction, the individual
    transitions, the creation of training examples, training and parsing.

    >>> from nltk.parse import DependencyGraph, DependencyEvaluator
    >>> from nltk.parse.transitionparser import TransitionParser, Configuration, Transition
    >>> gold_sent = DependencyGraph(\"""
    ... Economic JJ 2 ATT
    ... news NN 3 SBJ
    ... has VBD 0 ROOT
    ... little JJ 5 ATT
    ... effect NN 3 OBJ
    ... on IN 5 ATT
    ... financial JJ 8 ATT
    ... markets NNS 6 PC
    ... . . 3 PU
    ... \""")

    >>> conf = Configuration(gold_sent)

    ###################### Check the Initial Feature ########################

    >>> print(', '.join(conf.extract_features()))
    STK_0_POS_TOP, BUF_0_FORM_Economic, BUF_0_LEMMA_Economic, BUF_0_POS_JJ, BUF_1_FORM_news, BUF_1_POS_NN, BUF_2_POS_VBD, BUF_3_POS_JJ

    ###################### Check The Transition #######################
    Check the Initialized Configuration

    >>> print(conf)
    Stack : [0] Buffer : [1, 2, 3, 4, 5, 6, 7, 8, 9] Arcs : []

    A. Do some transition checks for ARC-STANDARD

    >>> operation = Transition('arc-standard')
    >>> operation.shift(conf)
    >>> operation.left_arc(conf, "ATT")
    >>> operation.shift(conf)
    >>> operation.left_arc(conf,"SBJ")
    >>> operation.shift(conf)
    >>> operation.shift(conf)
    >>> operation.left_arc(conf, "ATT")
    >>> operation.shift(conf)
    >>> operation.shift(conf)
    >>> operation.shift(conf)
    >>> operation.left_arc(conf, "ATT")

    Middle Configuration and Features Check

    >>> print(conf)
    Stack : [0, 3, 5, 6] Buffer : [8, 9] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (5, 'ATT', 4), (8, 'ATT', 7)]

    >>> print(', '.join(conf.extract_features()))
    STK_0_FORM_on, STK_0_LEMMA_on, STK_0_POS_IN, STK_1_POS_NN, BUF_0_FORM_markets, BUF_0_LEMMA_markets, BUF_0_POS_NNS, BUF_1_FORM_., BUF_1_POS_., BUF_0_LDEP_ATT

    >>> operation.right_arc(conf, "PC")
    >>> operation.right_arc(conf, "ATT")
    >>> operation.right_arc(conf, "OBJ")
    >>> operation.shift(conf)
    >>> operation.right_arc(conf, "PU")
    >>> operation.right_arc(conf, "ROOT")
    >>> operation.shift(conf)

    Terminated Configuration Check

    >>> print(conf)
    Stack : [0] Buffer : [] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (5, 'ATT', 4), (8, 'ATT', 7), (6, 'PC', 8), (5, 'ATT', 6), (3, 'OBJ', 5), (3, 'PU', 9), (0, 'ROOT', 3)]

    B. Do some transition checks for ARC-EAGER

    >>> conf = Configuration(gold_sent)
    >>> operation = Transition('arc-eager')
    >>> operation.shift(conf)
    >>> operation.left_arc(conf,'ATT')
    >>> operation.shift(conf)
    >>> operation.left_arc(conf,'SBJ')
    >>> operation.right_arc(conf,'ROOT')
    >>> operation.shift(conf)
    >>> operation.left_arc(conf,'ATT')
    >>> operation.right_arc(conf,'OBJ')
    >>> operation.right_arc(conf,'ATT')
    >>> operation.shift(conf)
    >>> operation.left_arc(conf,'ATT')
    >>> operation.right_arc(conf,'PC')
    >>> operation.reduce(conf)
    >>> operation.reduce(conf)
    >>> operation.reduce(conf)
    >>> operation.right_arc(conf,'PU')
    >>> print(conf)
    Stack : [0, 3, 9] Buffer : [] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (0, 'ROOT', 3), (5, 'ATT', 4), (3, 'OBJ', 5), (5, 'ATT', 6), (8, 'ATT', 7), (6, 'PC', 8), (3, 'PU', 9)]

    ###################### Check The Training Function #######################

    A. Check the ARC-STANDARD training

    >>> import tempfile
    >>> import os
    >>> input_file = tempfile.NamedTemporaryFile(prefix='transition_parse.train', dir=tempfile.gettempdir(), delete=False)

    >>> parser_std = TransitionParser('arc-standard')
    >>> print(', '.join(parser_std._create_training_examples_arc_std([gold_sent], input_file)))
     Number of training examples : 1
     Number of valid (projective) examples : 1
    SHIFT, LEFTARC:ATT, SHIFT, LEFTARC:SBJ, SHIFT, SHIFT, LEFTARC:ATT, SHIFT, SHIFT, SHIFT, LEFTARC:ATT, RIGHTARC:PC, RIGHTARC:ATT, RIGHTARC:OBJ, SHIFT, RIGHTARC:PU, RIGHTARC:ROOT, SHIFT

    >>> parser_std.train([gold_sent],'temp.arcstd.model', verbose=False)
     Number of training examples : 1
     Number of valid (projective) examples : 1
    >>> remove(input_file.name)

    B. Check the ARC-EAGER training

    >>> input_file = tempfile.NamedTemporaryFile(prefix='transition_parse.train', dir=tempfile.gettempdir(),delete=False)
    >>> parser_eager = TransitionParser('arc-eager')
    >>> print(', '.join(parser_eager._create_training_examples_arc_eager([gold_sent], input_file)))
     Number of training examples : 1
     Number of valid (projective) examples : 1
    SHIFT, LEFTARC:ATT, SHIFT, LEFTARC:SBJ, RIGHTARC:ROOT, SHIFT, LEFTARC:ATT, RIGHTARC:OBJ, RIGHTARC:ATT, SHIFT, LEFTARC:ATT, RIGHTARC:PC, REDUCE, REDUCE, REDUCE, RIGHTARC:PU

    >>> parser_eager.train([gold_sent],'temp.arceager.model', verbose=False)
     Number of training examples : 1
     Number of valid (projective) examples : 1

    >>> remove(input_file.name)

    ###################### Check The Parsing Function ########################

    A. Check the ARC-STANDARD parser

    >>> result = parser_std.parse([gold_sent], 'temp.arcstd.model')
    >>> de = DependencyEvaluator(result, [gold_sent])
    >>> de.eval() >= (0, 0)
    True

    B. Check the ARC-EAGER parser

    >>> result = parser_eager.parse([gold_sent], 'temp.arceager.model')
    >>> de = DependencyEvaluator(result, [gold_sent])
    >>> de.eval() >= (0, 0)
    True

    Remove test temporary files

    >>> remove('temp.arceager.model')
    >>> remove('temp.arcstd.model')

    Note that result is very poor because of only one training example.
    """