utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. """
  2. General helpers required for `tqdm.std`.
  3. """
  4. from functools import wraps
  5. import os
  6. from platform import system as _curos
  7. import re
  8. import subprocess
  9. from warnings import warn
  10. CUR_OS = _curos()
  11. IS_WIN = CUR_OS in ['Windows', 'cli']
  12. IS_NIX = (not IS_WIN) and any(
  13. CUR_OS.startswith(i) for i in
  14. ['CYGWIN', 'MSYS', 'Linux', 'Darwin', 'SunOS',
  15. 'FreeBSD', 'NetBSD', 'OpenBSD'])
  16. RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")
  17. # Py2/3 compat. Empty conditional to avoid coverage
  18. if True: # pragma: no cover
  19. try:
  20. _range = xrange
  21. except NameError:
  22. _range = range
  23. try:
  24. _unich = unichr
  25. except NameError:
  26. _unich = chr
  27. try:
  28. _unicode = unicode
  29. except NameError:
  30. _unicode = str
  31. try:
  32. if IS_WIN:
  33. import colorama
  34. else:
  35. raise ImportError
  36. except ImportError:
  37. colorama = None
  38. else:
  39. try:
  40. colorama.init(strip=False)
  41. except TypeError:
  42. colorama.init()
  43. try:
  44. from weakref import WeakSet
  45. except ImportError:
  46. WeakSet = set
  47. try:
  48. _basestring = basestring
  49. except NameError:
  50. _basestring = str
  51. try: # py>=2.7,>=3.1
  52. from collections import OrderedDict as _OrderedDict
  53. except ImportError:
  54. try: # older Python versions with backported ordereddict lib
  55. from ordereddict import OrderedDict as _OrderedDict
  56. except ImportError: # older Python versions without ordereddict lib
  57. # Py2.6,3.0 compat, from PEP 372
  58. from collections import MutableMapping
  59. class _OrderedDict(dict, MutableMapping):
  60. # Methods with direct access to underlying attributes
  61. def __init__(self, *args, **kwds):
  62. if len(args) > 1:
  63. raise TypeError('expected at 1 argument, got %d',
  64. len(args))
  65. if not hasattr(self, '_keys'):
  66. self._keys = []
  67. self.update(*args, **kwds)
  68. def clear(self):
  69. del self._keys[:]
  70. dict.clear(self)
  71. def __setitem__(self, key, value):
  72. if key not in self:
  73. self._keys.append(key)
  74. dict.__setitem__(self, key, value)
  75. def __delitem__(self, key):
  76. dict.__delitem__(self, key)
  77. self._keys.remove(key)
  78. def __iter__(self):
  79. return iter(self._keys)
  80. def __reversed__(self):
  81. return reversed(self._keys)
  82. def popitem(self):
  83. if not self:
  84. raise KeyError
  85. key = self._keys.pop()
  86. value = dict.pop(self, key)
  87. return key, value
  88. def __reduce__(self):
  89. items = [[k, self[k]] for k in self]
  90. inst_dict = vars(self).copy()
  91. inst_dict.pop('_keys', None)
  92. return self.__class__, (items,), inst_dict
  93. # Methods with indirect access via the above methods
  94. setdefault = MutableMapping.setdefault
  95. update = MutableMapping.update
  96. pop = MutableMapping.pop
  97. keys = MutableMapping.keys
  98. values = MutableMapping.values
  99. items = MutableMapping.items
  100. def __repr__(self):
  101. pairs = ', '.join(map('%r: %r'.__mod__, self.items()))
  102. return '%s({%s})' % (self.__class__.__name__, pairs)
  103. def copy(self):
  104. return self.__class__(self)
  105. @classmethod
  106. def fromkeys(cls, iterable, value=None):
  107. d = cls()
  108. for key in iterable:
  109. d[key] = value
  110. return d
  111. class FormatReplace(object):
  112. """
  113. >>> a = FormatReplace('something')
  114. >>> "{:5d}".format(a)
  115. 'something'
  116. """
  117. def __init__(self, replace=''):
  118. self.replace = replace
  119. self.format_called = 0
  120. def __format__(self, _):
  121. self.format_called += 1
  122. return self.replace
  123. class Comparable(object):
  124. """Assumes child has self._comparable attr/@property"""
  125. def __lt__(self, other):
  126. return self._comparable < other._comparable
  127. def __le__(self, other):
  128. return (self < other) or (self == other)
  129. def __eq__(self, other):
  130. return self._comparable == other._comparable
  131. def __ne__(self, other):
  132. return not self == other
  133. def __gt__(self, other):
  134. return not self <= other
  135. def __ge__(self, other):
  136. return not self < other
  137. class ObjectWrapper(object):
  138. def __getattr__(self, name):
  139. return getattr(self._wrapped, name)
  140. def __setattr__(self, name, value):
  141. return setattr(self._wrapped, name, value)
  142. def wrapper_getattr(self, name):
  143. """Actual `self.getattr` rather than self._wrapped.getattr"""
  144. try:
  145. return object.__getattr__(self, name)
  146. except AttributeError: # py2
  147. return getattr(self, name)
  148. def wrapper_setattr(self, name, value):
  149. """Actual `self.setattr` rather than self._wrapped.setattr"""
  150. return object.__setattr__(self, name, value)
  151. def __init__(self, wrapped):
  152. """
  153. Thin wrapper around a given object
  154. """
  155. self.wrapper_setattr('_wrapped', wrapped)
  156. class SimpleTextIOWrapper(ObjectWrapper):
  157. """
  158. Change only `.write()` of the wrapped object by encoding the passed
  159. value and passing the result to the wrapped object's `.write()` method.
  160. """
  161. # pylint: disable=too-few-public-methods
  162. def __init__(self, wrapped, encoding):
  163. super(SimpleTextIOWrapper, self).__init__(wrapped)
  164. self.wrapper_setattr('encoding', encoding)
  165. def write(self, s):
  166. """
  167. Encode `s` and pass to the wrapped object's `.write()` method.
  168. """
  169. return self._wrapped.write(s.encode(self.wrapper_getattr('encoding')))
  170. def __eq__(self, other):
  171. return self._wrapped == getattr(other, '_wrapped', other)
  172. class CallbackIOWrapper(ObjectWrapper):
  173. def __init__(self, callback, stream, method="read"):
  174. """
  175. Wrap a given `file`-like object's `read()` or `write()` to report
  176. lengths to the given `callback`
  177. """
  178. super(CallbackIOWrapper, self).__init__(stream)
  179. func = getattr(stream, method)
  180. if method == "write":
  181. @wraps(func)
  182. def write(data, *args, **kwargs):
  183. res = func(data, *args, **kwargs)
  184. callback(len(data))
  185. return res
  186. self.wrapper_setattr('write', write)
  187. elif method == "read":
  188. @wraps(func)
  189. def read(*args, **kwargs):
  190. data = func(*args, **kwargs)
  191. callback(len(data))
  192. return data
  193. self.wrapper_setattr('read', read)
  194. else:
  195. raise KeyError("Can only wrap read/write methods")
  196. def _is_utf(encoding):
  197. try:
  198. u'\u2588\u2589'.encode(encoding)
  199. except UnicodeEncodeError: # pragma: no cover
  200. return False
  201. except Exception: # pragma: no cover
  202. try:
  203. return encoding.lower().startswith('utf-') or ('U8' == encoding)
  204. except:
  205. return False
  206. else:
  207. return True
  208. def _supports_unicode(fp):
  209. try:
  210. return _is_utf(fp.encoding)
  211. except AttributeError:
  212. return False
  213. def _is_ascii(s):
  214. if isinstance(s, str):
  215. for c in s:
  216. if ord(c) > 255:
  217. return False
  218. return True
  219. return _supports_unicode(s)
  220. def _screen_shape_wrapper(): # pragma: no cover
  221. """
  222. Return a function which returns console dimensions (width, height).
  223. Supported: linux, osx, windows, cygwin.
  224. """
  225. _screen_shape = None
  226. if IS_WIN:
  227. _screen_shape = _screen_shape_windows
  228. if _screen_shape is None:
  229. _screen_shape = _screen_shape_tput
  230. if IS_NIX:
  231. _screen_shape = _screen_shape_linux
  232. return _screen_shape
  233. def _screen_shape_windows(fp): # pragma: no cover
  234. try:
  235. from ctypes import windll, create_string_buffer
  236. import struct
  237. from sys import stdin, stdout
  238. io_handle = -12 # assume stderr
  239. if fp == stdin:
  240. io_handle = -10
  241. elif fp == stdout:
  242. io_handle = -11
  243. h = windll.kernel32.GetStdHandle(io_handle)
  244. csbi = create_string_buffer(22)
  245. res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
  246. if res:
  247. (_bufx, _bufy, _curx, _cury, _wattr, left, top, right, bottom,
  248. _maxx, _maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
  249. return right - left, bottom - top # +1
  250. except:
  251. pass
  252. return None, None
  253. def _screen_shape_tput(*_): # pragma: no cover
  254. """cygwin xterm (windows)"""
  255. try:
  256. import shlex
  257. return [int(subprocess.check_call(shlex.split('tput ' + i))) - 1
  258. for i in ('cols', 'lines')]
  259. except:
  260. pass
  261. return None, None
  262. def _screen_shape_linux(fp): # pragma: no cover
  263. try:
  264. from termios import TIOCGWINSZ
  265. from fcntl import ioctl
  266. from array import array
  267. except ImportError:
  268. return None
  269. else:
  270. try:
  271. rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
  272. return cols, rows
  273. except:
  274. try:
  275. return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
  276. except KeyError:
  277. return None, None
  278. def _environ_cols_wrapper(): # pragma: no cover
  279. """
  280. Return a function which returns console width.
  281. Supported: linux, osx, windows, cygwin.
  282. """
  283. warn("Use `_screen_shape_wrapper()(file)[0]` instead of"
  284. " `_environ_cols_wrapper()(file)`", DeprecationWarning, stacklevel=2)
  285. shape = _screen_shape_wrapper()
  286. if not shape:
  287. return None
  288. @wraps(shape)
  289. def inner(fp):
  290. return shape(fp)[0]
  291. return inner
  292. def _term_move_up(): # pragma: no cover
  293. return '' if (os.name == 'nt') and (colorama is None) else '\x1b[A'
  294. try:
  295. # TODO consider using wcswidth third-party package for 0-width characters
  296. from unicodedata import east_asian_width
  297. except ImportError:
  298. _text_width = len
  299. else:
  300. def _text_width(s):
  301. return sum(
  302. 2 if east_asian_width(ch) in 'FW' else 1 for ch in _unicode(s))
  303. def disp_len(data):
  304. """
  305. Returns the real on-screen length of a string which may contain
  306. ANSI control codes and wide chars.
  307. """
  308. return _text_width(RE_ANSI.sub('', data))
  309. def disp_trim(data, length):
  310. """
  311. Trim a string which may contain ANSI control characters.
  312. """
  313. if len(data) == disp_len(data):
  314. return data[:length]
  315. ansi_present = bool(RE_ANSI.search(data))
  316. while disp_len(data) > length: # carefully delete one char at a time
  317. data = data[:-1]
  318. if ansi_present and bool(RE_ANSI.search(data)):
  319. # assume ANSI reset is required
  320. return data if data.endswith("\033[0m") else data + "\033[0m"
  321. return data