concurrent.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. """
  2. Thin wrappers around `concurrent.futures`.
  3. """
  4. from __future__ import absolute_import
  5. from tqdm import TqdmWarning
  6. from tqdm.auto import tqdm as tqdm_auto
  7. try:
  8. from operator import length_hint
  9. except ImportError:
  10. def length_hint(it, default=0):
  11. """Returns `len(it)`, falling back to `default`"""
  12. try:
  13. return len(it)
  14. except TypeError:
  15. return default
# Resolve `cpu_count` from the first source that provides it:
# `os`, then `multiprocessing`, then a constant stub.
try:
    from os import cpu_count
except ImportError:
    try:
        from multiprocessing import cpu_count
    except ImportError:
        def cpu_count():
            """Last-resort stand-in: always reports 4 CPUs."""
            return 4
  24. import sys
# Module authorship, keyed by hosting site prefix.
__author__ = {"github.com/": ["casperdcl"]}
# Public names exported by `from <this module> import *`.
__all__ = ['thread_map', 'process_map']
  27. def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
  28. """
  29. Implementation of `thread_map` and `process_map`.
  30. Parameters
  31. ----------
  32. tqdm_class : [default: tqdm.auto.tqdm].
  33. max_workers : [default: min(32, cpu_count() + 4)].
  34. chunksize : [default: 1].
  35. """
  36. kwargs = tqdm_kwargs.copy()
  37. if "total" not in kwargs:
  38. kwargs["total"] = len(iterables[0])
  39. tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
  40. max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
  41. chunksize = kwargs.pop("chunksize", 1)
  42. pool_kwargs = dict(max_workers=max_workers)
  43. sys_version = sys.version_info[:2]
  44. if sys_version >= (3, 7):
  45. # share lock in case workers are already using `tqdm`
  46. pool_kwargs.update(
  47. initializer=tqdm_class.set_lock, initargs=(tqdm_class.get_lock(),))
  48. map_args = {}
  49. if not (3, 0) < sys_version < (3, 5):
  50. map_args.update(chunksize=chunksize)
  51. with PoolExecutor(**pool_kwargs) as ex:
  52. return list(tqdm_class(
  53. ex.map(fn, *iterables, **map_args), **kwargs))
  54. def thread_map(fn, *iterables, **tqdm_kwargs):
  55. """
  56. Equivalent of `list(map(fn, *iterables))`
  57. driven by `concurrent.futures.ThreadPoolExecutor`.
  58. Parameters
  59. ----------
  60. tqdm_class : optional
  61. `tqdm` class to use for bars [default: tqdm.auto.tqdm].
  62. max_workers : int, optional
  63. Maximum number of workers to spawn; passed to
  64. `concurrent.futures.ThreadPoolExecutor.__init__`.
  65. [default: max(32, cpu_count() + 4)].
  66. """
  67. from concurrent.futures import ThreadPoolExecutor
  68. return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
  69. def process_map(fn, *iterables, **tqdm_kwargs):
  70. """
  71. Equivalent of `list(map(fn, *iterables))`
  72. driven by `concurrent.futures.ProcessPoolExecutor`.
  73. Parameters
  74. ----------
  75. tqdm_class : optional
  76. `tqdm` class to use for bars [default: tqdm.auto.tqdm].
  77. max_workers : int, optional
  78. Maximum number of workers to spawn; passed to
  79. `concurrent.futures.ProcessPoolExecutor.__init__`.
  80. [default: min(32, cpu_count() + 4)].
  81. chunksize : int, optional
  82. Size of chunks sent to worker processes; passed to
  83. `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
  84. """
  85. from concurrent.futures import ProcessPoolExecutor
  86. if iterables and "chunksize" not in tqdm_kwargs:
  87. # default `chunksize=1` has poor performance for large iterables
  88. # (most time spent dispatching items to workers).
  89. longest_iterable_len = max(map(length_hint, iterables))
  90. if longest_iterable_len > 1000:
  91. from warnings import warn
  92. warn("Iterable length %d > 1000 but `chunksize` is not set."
  93. " This may seriously degrade multiprocess performance."
  94. " Set `chunksize=1` or more." % longest_iterable_len,
  95. TqdmWarning, stacklevel=2)
  96. return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)