__init__.py

from __future__ import unicode_literals

from itertools import chain
from functools import partial

import lunr
from lunr.builder import Builder
from lunr.languages.trimmer import generate_trimmer
from lunr.languages.stemmer import nltk_stemmer, get_language_stemmer
from lunr.pipeline import Pipeline
from lunr.stop_word_filter import stop_word_filter, generate_stop_word_filter

# Map from ISO-639-1 codes to SnowballStemmer.languages.
# Languages supported by lunr.js but not by nltk: thai, japanese and turkish.
# Languages supported by nltk but not by lunr.js: arabic.
SUPPORTED_LANGUAGES = {
    "ar": "arabic",
    "da": "danish",
    "nl": "dutch",
    "en": "english",
    "fi": "finnish",
    "fr": "french",
    "de": "german",
    "hu": "hungarian",
    "it": "italian",
    "no": "norwegian",
    "pt": "portuguese",
    "ro": "romanian",
    "ru": "russian",
    "es": "spanish",
    "sv": "swedish",
}
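
# For illustration: each value is a language name accepted by nltk's
# SnowballStemmer, so looking up an ISO code gives the stemmer the work is
# delegated to, e.g. (assuming nltk is installed):
#
#     from nltk.stem.snowball import SnowballStemmer
#     stemmer = SnowballStemmer(SUPPORTED_LANGUAGES["es"])  # SnowballStemmer("spanish")
#     stemmer.stem("corriendo")  # -> "corr"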

try:  # pragma: no cover
    import nltk

    LANGUAGE_SUPPORT = True
except ImportError:  # pragma: no cover
    LANGUAGE_SUPPORT = False


def _get_stopwords_and_word_characters(language):
    nltk.download("stopwords")
    verbose_language = SUPPORTED_LANGUAGES[language]
    stopwords = nltk.corpus.stopwords.words(verbose_language)
    # TODO: search for a more exhaustive list of word characters
    word_characters = {c for word in stopwords for c in word}
    return stopwords, word_characters
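
# Rough sketch of what the helper returns (illustrative values, assuming the
# NLTK stopword corpus has been downloaded):
#
#     stopwords, chars = _get_stopwords_and_word_characters("fr")
#     # stopwords -> ["au", "aux", "avec", ...]
#     # chars     -> every character occurring in those words, including
#     #              accented letters, later fed to generate_trimmer()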

def get_nltk_builder(languages):
    """Returns a builder with stemmers for all languages added to it.

    Args:
        languages (list): A list of supported languages.
    """
    all_stemmers = []
    all_stopwords_filters = []
    all_word_characters = set()

    for language in languages:
        if language == "en":
            # use Lunr's defaults
            all_stemmers.append(lunr.stemmer.stemmer)
            all_stopwords_filters.append(stop_word_filter)
            all_word_characters.update({r"\w"})
        else:
            stopwords, word_characters = _get_stopwords_and_word_characters(language)
            all_stemmers.append(
                Pipeline.registered_functions["stemmer-{}".format(language)]
            )
            all_stopwords_filters.append(
                generate_stop_word_filter(stopwords, language=language)
            )
            all_word_characters.update(word_characters)

    builder = Builder()
    # one trimmer covering the union of word characters of all requested languages
    multi_trimmer = generate_trimmer("".join(sorted(all_word_characters)))
    Pipeline.register_function(
        multi_trimmer, "lunr-multi-trimmer-{}".format("-".join(languages))
    )
    builder.pipeline.reset()

    # index pipeline: trim, drop stop words, then stem for every language
    for fn in chain([multi_trimmer], all_stopwords_filters, all_stemmers):
        builder.pipeline.add(fn)
    # the search pipeline only needs the stemmers
    for fn in all_stemmers:
        builder.search_pipeline.add(fn)

    return builder
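
# Usage sketch (illustrative only; "documents" stands in for any iterable of
# dicts with "id" and "text" keys). lunr.lunr() builds an index this way when
# given a "languages" argument, and the builder can also be driven directly:
#
#     builder = get_nltk_builder(["es", "fr"])
#     builder.ref("id")
#     builder.field("text")
#     for doc in documents:
#         builder.add(doc)
#     index = builder.build()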

def register_languages():
    """Register stemmers for all supported languages so that serialized
    indexes referencing them can be deserialized."""
    for language in set(SUPPORTED_LANGUAGES) - {"en"}:
        language_stemmer = partial(nltk_stemmer, get_language_stemmer(language))
        Pipeline.register_function(language_stemmer, "stemmer-{}".format(language))
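
# After registration, Pipeline.registered_functions can resolve the names a
# serialized index may contain, e.g. Pipeline.registered_functions["stemmer-fr"]
# is the partial wrapping nltk's French stemmer.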

if LANGUAGE_SUPPORT:  # pragma: no cover
    # TODO: registering all possible stemmers feels unnecessary, but it makes
    # deserializing indexes that reference arbitrary language functions work.
    # Ideally the schema would provide the language(s) for the index and we
    # could register the stemmers as needed.
    register_languages()