__init__.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """
  2. Thin wrappers around common functions.
  3. Subpackages contain potentially unstable extensions.
  4. """
  5. from tqdm import tqdm
  6. from tqdm.auto import tqdm as tqdm_auto
  7. from tqdm.utils import ObjectWrapper
  8. from functools import wraps
  9. import sys
  10. __author__ = {"github.com/": ["casperdcl"]}
  11. __all__ = ['tenumerate', 'tzip', 'tmap']
  12. class DummyTqdmFile(ObjectWrapper):
  13. """Dummy file-like that will write to tqdm"""
  14. def write(self, x, nolock=False):
  15. # Avoid print() second call (useless \n)
  16. if len(x.rstrip()) > 0:
  17. tqdm.write(x, file=self._wrapped, nolock=nolock)
  18. def builtin_iterable(func):
  19. """Wraps `func()` output in a `list()` in py2"""
  20. if sys.version_info[:1] < (3,):
  21. @wraps(func)
  22. def inner(*args, **kwargs):
  23. return list(func(*args, **kwargs))
  24. return inner
  25. return func
  26. def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto,
  27. **tqdm_kwargs):
  28. """
  29. Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
  30. Parameters
  31. ----------
  32. tqdm_class : [default: tqdm.auto.tqdm].
  33. """
  34. try:
  35. import numpy as np
  36. except ImportError:
  37. pass
  38. else:
  39. if isinstance(iterable, np.ndarray):
  40. return tqdm_class(np.ndenumerate(iterable),
  41. total=total or iterable.size, **tqdm_kwargs)
  42. return enumerate(tqdm_class(iterable, total=total, **tqdm_kwargs), start)
  43. @builtin_iterable
  44. def tzip(iter1, *iter2plus, **tqdm_kwargs):
  45. """
  46. Equivalent of builtin `zip`.
  47. Parameters
  48. ----------
  49. tqdm_class : [default: tqdm.auto.tqdm].
  50. """
  51. kwargs = tqdm_kwargs.copy()
  52. tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
  53. for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
  54. yield i
  55. @builtin_iterable
  56. def tmap(function, *sequences, **tqdm_kwargs):
  57. """
  58. Equivalent of builtin `map`.
  59. Parameters
  60. ----------
  61. tqdm_class : [default: tqdm.auto.tqdm].
  62. """
  63. for i in tzip(*sequences, **tqdm_kwargs):
  64. yield function(*i)