func_inspect.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. """
  2. My own variation on function-specific inspect-like features.
  3. """
  4. # Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
  5. # Copyright (c) 2009 Gael Varoquaux
  6. # License: BSD Style, 3 clauses.
  7. import inspect
  8. import warnings
  9. import re
  10. import os
  11. import collections
  12. from itertools import islice
  13. from tokenize import open as open_py_source
  14. from .logger import pformat
  15. full_argspec_fields = ('args varargs varkw defaults kwonlyargs '
  16. 'kwonlydefaults annotations')
  17. full_argspec_type = collections.namedtuple('FullArgSpec', full_argspec_fields)
  18. def get_func_code(func):
  19. """ Attempts to retrieve a reliable function code hash.
  20. The reason we don't use inspect.getsource is that it caches the
  21. source, whereas we want this to be modified on the fly when the
  22. function is modified.
  23. Returns
  24. -------
  25. func_code: string
  26. The function code
  27. source_file: string
  28. The path to the file in which the function is defined.
  29. first_line: int
  30. The first line of the code in the source file.
  31. Notes
  32. ------
  33. This function does a bit more magic than inspect, and is thus
  34. more robust.
  35. """
  36. source_file = None
  37. try:
  38. code = func.__code__
  39. source_file = code.co_filename
  40. if not os.path.exists(source_file):
  41. # Use inspect for lambda functions and functions defined in an
  42. # interactive shell, or in doctests
  43. source_code = ''.join(inspect.getsourcelines(func)[0])
  44. line_no = 1
  45. if source_file.startswith('<doctest '):
  46. source_file, line_no = re.match(
  47. r'\<doctest (.*\.rst)\[(.*)\]\>', source_file).groups()
  48. line_no = int(line_no)
  49. source_file = '<doctest %s>' % source_file
  50. return source_code, source_file, line_no
  51. # Try to retrieve the source code.
  52. with open_py_source(source_file) as source_file_obj:
  53. first_line = code.co_firstlineno
  54. # All the lines after the function definition:
  55. source_lines = list(islice(source_file_obj, first_line - 1, None))
  56. return ''.join(inspect.getblock(source_lines)), source_file, first_line
  57. except:
  58. # If the source code fails, we use the hash. This is fragile and
  59. # might change from one session to another.
  60. if hasattr(func, '__code__'):
  61. # Python 3.X
  62. return str(func.__code__.__hash__()), source_file, -1
  63. else:
  64. # Weird objects like numpy ufunc don't have __code__
  65. # This is fragile, as quite often the id of the object is
  66. # in the repr, so it might not persist across sessions,
  67. # however it will work for ufuncs.
  68. return repr(func), source_file, -1
  69. def _clean_win_chars(string):
  70. """Windows cannot encode some characters in filename."""
  71. import urllib
  72. if hasattr(urllib, 'quote'):
  73. quote = urllib.quote
  74. else:
  75. # In Python 3, quote is elsewhere
  76. import urllib.parse
  77. quote = urllib.parse.quote
  78. for char in ('<', '>', '!', ':', '\\'):
  79. string = string.replace(char, quote(char))
  80. return string
  81. def get_func_name(func, resolv_alias=True, win_characters=True):
  82. """ Return the function import path (as a list of module names), and
  83. a name for the function.
  84. Parameters
  85. ----------
  86. func: callable
  87. The func to inspect
  88. resolv_alias: boolean, optional
  89. If true, possible local aliases are indicated.
  90. win_characters: boolean, optional
  91. If true, substitute special characters using urllib.quote
  92. This is useful in Windows, as it cannot encode some filenames
  93. """
  94. if hasattr(func, '__module__'):
  95. module = func.__module__
  96. else:
  97. try:
  98. module = inspect.getmodule(func)
  99. except TypeError:
  100. if hasattr(func, '__class__'):
  101. module = func.__class__.__module__
  102. else:
  103. module = 'unknown'
  104. if module is None:
  105. # Happens in doctests, eg
  106. module = ''
  107. if module == '__main__':
  108. try:
  109. filename = os.path.abspath(inspect.getsourcefile(func))
  110. except:
  111. filename = None
  112. if filename is not None:
  113. # mangling of full path to filename
  114. parts = filename.split(os.sep)
  115. if parts[-1].startswith('<ipython-input'):
  116. # function is defined in an IPython session. The filename
  117. # will change with every new kernel instance. This hack
  118. # always returns the same filename
  119. parts[-1] = '__ipython-input__'
  120. filename = '-'.join(parts)
  121. if filename.endswith('.py'):
  122. filename = filename[:-3]
  123. module = module + '-' + filename
  124. module = module.split('.')
  125. if hasattr(func, 'func_name'):
  126. name = func.func_name
  127. elif hasattr(func, '__name__'):
  128. name = func.__name__
  129. else:
  130. name = 'unknown'
  131. # Hack to detect functions not defined at the module-level
  132. if resolv_alias:
  133. # TODO: Maybe add a warning here?
  134. if hasattr(func, 'func_globals') and name in func.func_globals:
  135. if not func.func_globals[name] is func:
  136. name = '%s-alias' % name
  137. if inspect.ismethod(func):
  138. # We need to add the name of the class
  139. if hasattr(func, 'im_class'):
  140. klass = func.im_class
  141. module.append(klass.__name__)
  142. if os.name == 'nt' and win_characters:
  143. # Stupid windows can't encode certain characters in filenames
  144. name = _clean_win_chars(name)
  145. module = [_clean_win_chars(s) for s in module]
  146. return module, name
  147. def _signature_str(function_name, arg_spec):
  148. """Helper function to output a function signature"""
  149. arg_spec_str = inspect.formatargspec(*arg_spec)
  150. return '{}{}'.format(function_name, arg_spec_str)
  151. def _function_called_str(function_name, args, kwargs):
  152. """Helper function to output a function call"""
  153. template_str = '{0}({1}, {2})'
  154. args_str = repr(args)[1:-1]
  155. kwargs_str = ', '.join('%s=%s' % (k, v)
  156. for k, v in kwargs.items())
  157. return template_str.format(function_name, args_str,
  158. kwargs_str)
  159. def filter_args(func, ignore_lst, args=(), kwargs=dict()):
  160. """ Filters the given args and kwargs using a list of arguments to
  161. ignore, and a function specification.
  162. Parameters
  163. ----------
  164. func: callable
  165. Function giving the argument specification
  166. ignore_lst: list of strings
  167. List of arguments to ignore (either a name of an argument
  168. in the function spec, or '*', or '**')
  169. *args: list
  170. Positional arguments passed to the function.
  171. **kwargs: dict
  172. Keyword arguments passed to the function
  173. Returns
  174. -------
  175. filtered_args: list
  176. List of filtered positional and keyword arguments.
  177. """
  178. args = list(args)
  179. if isinstance(ignore_lst, str):
  180. # Catch a common mistake
  181. raise ValueError(
  182. 'ignore_lst must be a list of parameters to ignore '
  183. '%s (type %s) was given' % (ignore_lst, type(ignore_lst)))
  184. # Special case for functools.partial objects
  185. if (not inspect.ismethod(func) and not inspect.isfunction(func)):
  186. if ignore_lst:
  187. warnings.warn('Cannot inspect object %s, ignore list will '
  188. 'not work.' % func, stacklevel=2)
  189. return {'*': args, '**': kwargs}
  190. arg_spec = inspect.getfullargspec(func)
  191. arg_names = arg_spec.args + arg_spec.kwonlyargs
  192. arg_defaults = arg_spec.defaults or ()
  193. if arg_spec.kwonlydefaults:
  194. arg_defaults = arg_defaults + tuple(arg_spec.kwonlydefaults[k]
  195. for k in arg_spec.kwonlyargs
  196. if k in arg_spec.kwonlydefaults)
  197. arg_varargs = arg_spec.varargs
  198. arg_varkw = arg_spec.varkw
  199. if inspect.ismethod(func):
  200. # First argument is 'self', it has been removed by Python
  201. # we need to add it back:
  202. args = [func.__self__, ] + args
  203. # XXX: Maybe I need an inspect.isbuiltin to detect C-level methods, such
  204. # as on ndarrays.
  205. _, name = get_func_name(func, resolv_alias=False)
  206. arg_dict = dict()
  207. arg_position = -1
  208. for arg_position, arg_name in enumerate(arg_names):
  209. if arg_position < len(args):
  210. # Positional argument or keyword argument given as positional
  211. if arg_name not in arg_spec.kwonlyargs:
  212. arg_dict[arg_name] = args[arg_position]
  213. else:
  214. raise ValueError(
  215. "Keyword-only parameter '%s' was passed as "
  216. 'positional parameter for %s:\n'
  217. ' %s was called.'
  218. % (arg_name,
  219. _signature_str(name, arg_spec),
  220. _function_called_str(name, args, kwargs))
  221. )
  222. else:
  223. position = arg_position - len(arg_names)
  224. if arg_name in kwargs:
  225. arg_dict[arg_name] = kwargs[arg_name]
  226. else:
  227. try:
  228. arg_dict[arg_name] = arg_defaults[position]
  229. except (IndexError, KeyError) as e:
  230. # Missing argument
  231. raise ValueError(
  232. 'Wrong number of arguments for %s:\n'
  233. ' %s was called.'
  234. % (_signature_str(name, arg_spec),
  235. _function_called_str(name, args, kwargs))
  236. ) from e
  237. varkwargs = dict()
  238. for arg_name, arg_value in sorted(kwargs.items()):
  239. if arg_name in arg_dict:
  240. arg_dict[arg_name] = arg_value
  241. elif arg_varkw is not None:
  242. varkwargs[arg_name] = arg_value
  243. else:
  244. raise TypeError("Ignore list for %s() contains an unexpected "
  245. "keyword argument '%s'" % (name, arg_name))
  246. if arg_varkw is not None:
  247. arg_dict['**'] = varkwargs
  248. if arg_varargs is not None:
  249. varargs = args[arg_position + 1:]
  250. arg_dict['*'] = varargs
  251. # Now remove the arguments to be ignored
  252. for item in ignore_lst:
  253. if item in arg_dict:
  254. arg_dict.pop(item)
  255. else:
  256. raise ValueError("Ignore list: argument '%s' is not defined for "
  257. "function %s"
  258. % (item,
  259. _signature_str(name, arg_spec))
  260. )
  261. # XXX: Return a sorted list of pairs?
  262. return arg_dict
  263. def _format_arg(arg):
  264. formatted_arg = pformat(arg, indent=2)
  265. if len(formatted_arg) > 1500:
  266. formatted_arg = '%s...' % formatted_arg[:700]
  267. return formatted_arg
  268. def format_signature(func, *args, **kwargs):
  269. # XXX: Should this use inspect.formatargvalues/formatargspec?
  270. module, name = get_func_name(func)
  271. module = [m for m in module if m]
  272. if module:
  273. module.append(name)
  274. module_path = '.'.join(module)
  275. else:
  276. module_path = name
  277. arg_str = list()
  278. previous_length = 0
  279. for arg in args:
  280. formatted_arg = _format_arg(arg)
  281. if previous_length > 80:
  282. formatted_arg = '\n%s' % formatted_arg
  283. previous_length = len(formatted_arg)
  284. arg_str.append(formatted_arg)
  285. arg_str.extend(['%s=%s' % (v, _format_arg(i)) for v, i in kwargs.items()])
  286. arg_str = ', '.join(arg_str)
  287. signature = '%s(%s)' % (name, arg_str)
  288. return module_path, signature
  289. def format_call(func, args, kwargs, object_name="Memory"):
  290. """ Returns a nicely formatted statement displaying the function
  291. call with the given arguments.
  292. """
  293. path, signature = format_signature(func, *args, **kwargs)
  294. msg = '%s\n[%s] Calling %s...\n%s' % (80 * '_', object_name,
  295. path, signature)
  296. return msg
  297. # XXX: Not using logging framework
  298. # self.debug(msg)