evaluate.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Natural Language Toolkit: evaluation of dependency parser
  2. #
  3. # Author: Long Duong <longdt219@gmail.com>
  4. #
  5. # Copyright (C) 2001-2020 NLTK Project
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. import unicodedata
  9. class DependencyEvaluator(object):
  10. """
  11. Class for measuring labelled and unlabelled attachment score for
  12. dependency parsing. Note that the evaluation ignores punctuation.
  13. >>> from nltk.parse import DependencyGraph, DependencyEvaluator
  14. >>> gold_sent = DependencyGraph(\"""
  15. ... Pierre NNP 2 NMOD
  16. ... Vinken NNP 8 SUB
  17. ... , , 2 P
  18. ... 61 CD 5 NMOD
  19. ... years NNS 6 AMOD
  20. ... old JJ 2 NMOD
  21. ... , , 2 P
  22. ... will MD 0 ROOT
  23. ... join VB 8 VC
  24. ... the DT 11 NMOD
  25. ... board NN 9 OBJ
  26. ... as IN 9 VMOD
  27. ... a DT 15 NMOD
  28. ... nonexecutive JJ 15 NMOD
  29. ... director NN 12 PMOD
  30. ... Nov. NNP 9 VMOD
  31. ... 29 CD 16 NMOD
  32. ... . . 9 VMOD
  33. ... \""")
  34. >>> parsed_sent = DependencyGraph(\"""
  35. ... Pierre NNP 8 NMOD
  36. ... Vinken NNP 1 SUB
  37. ... , , 3 P
  38. ... 61 CD 6 NMOD
  39. ... years NNS 6 AMOD
  40. ... old JJ 2 NMOD
  41. ... , , 3 AMOD
  42. ... will MD 0 ROOT
  43. ... join VB 8 VC
  44. ... the DT 11 AMOD
  45. ... board NN 9 OBJECT
  46. ... as IN 9 NMOD
  47. ... a DT 15 NMOD
  48. ... nonexecutive JJ 15 NMOD
  49. ... director NN 12 PMOD
  50. ... Nov. NNP 9 VMOD
  51. ... 29 CD 16 NMOD
  52. ... . . 9 VMOD
  53. ... \""")
  54. >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
  55. >>> las, uas = de.eval()
  56. >>> las
  57. 0.6...
  58. >>> uas
  59. 0.8...
  60. >>> abs(uas - 0.8) < 0.00001
  61. True
  62. """
  63. def __init__(self, parsed_sents, gold_sents):
  64. """
  65. :param parsed_sents: the list of parsed_sents as the output of parser
  66. :type parsed_sents: list(DependencyGraph)
  67. """
  68. self._parsed_sents = parsed_sents
  69. self._gold_sents = gold_sents
  70. def _remove_punct(self, inStr):
  71. """
  72. Function to remove punctuation from Unicode string.
  73. :param input: the input string
  74. :return: Unicode string after remove all punctuation
  75. """
  76. punc_cat = set(["Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"])
  77. return "".join(x for x in inStr if unicodedata.category(x) not in punc_cat)
  78. def eval(self):
  79. """
  80. Return the Labeled Attachment Score (LAS) and Unlabeled Attachment Score (UAS)
  81. :return : tuple(float,float)
  82. """
  83. if len(self._parsed_sents) != len(self._gold_sents):
  84. raise ValueError(
  85. " Number of parsed sentence is different with number of gold sentence."
  86. )
  87. corr = 0
  88. corrL = 0
  89. total = 0
  90. for i in range(len(self._parsed_sents)):
  91. parsed_sent_nodes = self._parsed_sents[i].nodes
  92. gold_sent_nodes = self._gold_sents[i].nodes
  93. if len(parsed_sent_nodes) != len(gold_sent_nodes):
  94. raise ValueError("Sentences must have equal length.")
  95. for parsed_node_address, parsed_node in parsed_sent_nodes.items():
  96. gold_node = gold_sent_nodes[parsed_node_address]
  97. if parsed_node["word"] is None:
  98. continue
  99. if parsed_node["word"] != gold_node["word"]:
  100. raise ValueError("Sentence sequence is not matched.")
  101. # Ignore if word is punctuation by default
  102. # if (parsed_sent[j]["word"] in string.punctuation):
  103. if self._remove_punct(parsed_node["word"]) == "":
  104. continue
  105. total += 1
  106. if parsed_node["head"] == gold_node["head"]:
  107. corr += 1
  108. if parsed_node["rel"] == gold_node["rel"]:
  109. corrL += 1
  110. return corrL / total, corr / total