# -*- coding: utf-8 -*-
"""
Tests for NIST translation evaluation metric
"""

import io
import unittest

from nltk.data import find
from nltk.translate.nist_score import sentence_nist, corpus_nist


class TestNIST(unittest.TestCase):
    def test_corpus_nist(self):
        ref_file = find('models/wmt15_eval/ref.ru')
        hyp_file = find('models/wmt15_eval/google.ru')
        mteval_output_file = find('models/wmt15_eval/mteval-13a.output')
        # Read the NIST scores from the `mteval-13a.output` file.
        # The order of the scores corresponds to the n-gram order.
        with open(mteval_output_file, 'r') as mteval_fin:
            # The scores are on the 4th line from the end of the file.
            # The first and last items on that line are the metric and
            # system names, so the slice [1:-1] drops them.
            mteval_nist_scores = map(float, mteval_fin.readlines()[-4].split()[1:-1])

        with io.open(ref_file, 'r', encoding='utf8') as ref_fin:
            with io.open(hyp_file, 'r', encoding='utf8') as hyp_fin:
                # Whitespace-tokenize each line.
                # Note: str.split() with no arguments also strips the line.
                hypotheses = list(map(lambda x: x.split(), hyp_fin))
                # Note that the corpus_nist input is a list of lists of
                # references, i.e. one list of references per hypothesis.
                references = list(map(lambda x: [x.split()], ref_fin))

                # Without smoothing.
                for i, mteval_nist in zip(range(1, 10), mteval_nist_scores):
                    nltk_nist = corpus_nist(references, hypotheses, i)
                    # Check that the NIST score difference is less than 0.05.
                    assert abs(mteval_nist - nltk_nist) < 0.05
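
    def test_sentence_nist(self):
        # A minimal sentence-level sanity check: a sketch using toy
        # tokenized sentences (illustrative data, not from WMT15).
        # It only assumes that NIST scores are non-negative.
        hypothesis = ['the', 'cat', 'sat', 'on', 'the', 'mat']
        references = [['the', 'cat', 'is', 'on', 'the', 'mat'],
                      ['there', 'is', 'a', 'cat', 'on', 'the', 'mat']]
        score = sentence_nist(references, hypothesis, n=2)
        # NIST is an information-weighted n-gram precision, so the
        # score should never be negative.
        assert score >= 0


if __name__ == '__main__':
    unittest.main()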