cloudpickle_wrapper.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import inspect
  2. from functools import partial
  3. try:
  4. from joblib.externals.cloudpickle import dumps, loads
  5. cloudpickle = True
  6. except ImportError:
  7. cloudpickle = False
  8. WRAP_CACHE = dict()
  9. class CloudpickledObjectWrapper(object):
  10. def __init__(self, obj, keep_wrapper=False):
  11. self._obj = obj
  12. self._keep_wrapper = keep_wrapper
  13. def __reduce__(self):
  14. _pickled_object = dumps(self._obj)
  15. if not self._keep_wrapper:
  16. return loads, (_pickled_object,)
  17. return _reconstruct_wrapper, (_pickled_object, self._keep_wrapper)
  18. def __getattr__(self, attr):
  19. # Ensure that the wrapped object can be used seemlessly as the
  20. # previous object.
  21. if attr not in ['_obj', '_keep_wrapper']:
  22. return getattr(self._obj, attr)
  23. return getattr(self, attr)
  24. # Make sure the wrapped object conserves the callable property
  25. class CallableObjectWrapper(CloudpickledObjectWrapper):
  26. def __call__(self, *args, **kwargs):
  27. return self._obj(*args, **kwargs)
  28. def _wrap_non_picklable_objects(obj, keep_wrapper):
  29. if callable(obj):
  30. return CallableObjectWrapper(obj, keep_wrapper=keep_wrapper)
  31. return CloudpickledObjectWrapper(obj, keep_wrapper=keep_wrapper)
  32. def _reconstruct_wrapper(_pickled_object, keep_wrapper):
  33. obj = loads(_pickled_object)
  34. return _wrap_non_picklable_objects(obj, keep_wrapper)
  35. def _wrap_objects_when_needed(obj):
  36. # Function to introspect an object and decide if it should be wrapped or
  37. # not.
  38. if not cloudpickle:
  39. return obj
  40. need_wrap = "__main__" in getattr(obj, "__module__", "")
  41. if isinstance(obj, partial):
  42. return partial(
  43. _wrap_objects_when_needed(obj.func),
  44. *[_wrap_objects_when_needed(a) for a in obj.args],
  45. **{k: _wrap_objects_when_needed(v)
  46. for k, v in obj.keywords.items()}
  47. )
  48. if callable(obj):
  49. # Need wrap if the object is a function defined in a local scope of
  50. # another function.
  51. func_code = getattr(obj, "__code__", "")
  52. need_wrap |= getattr(func_code, "co_flags", 0) & inspect.CO_NESTED
  53. # Need wrap if the obj is a lambda expression
  54. func_name = getattr(obj, "__name__", "")
  55. need_wrap |= "<lambda>" in func_name
  56. if not need_wrap:
  57. return obj
  58. wrapped_obj = WRAP_CACHE.get(obj)
  59. if wrapped_obj is None:
  60. wrapped_obj = _wrap_non_picklable_objects(obj, keep_wrapper=False)
  61. WRAP_CACHE[obj] = wrapped_obj
  62. return wrapped_obj
  63. def wrap_non_picklable_objects(obj, keep_wrapper=True):
  64. """Wrapper for non-picklable object to use cloudpickle to serialize them.
  65. Note that this wrapper tends to slow down the serialization process as it
  66. is done with cloudpickle which is typically slower compared to pickle. The
  67. proper way to solve serialization issues is to avoid defining functions and
  68. objects in the main scripts and to implement __reduce__ functions for
  69. complex classes.
  70. """
  71. if not cloudpickle:
  72. raise ImportError("could not from joblib.externals import cloudpickle. Please install "
  73. "cloudpickle to allow extended serialization. "
  74. "(`pip install cloudpickle`).")
  75. # If obj is a class, create a CloudpickledClassWrapper which instantiates
  76. # the object internally and wrap it directly in a CloudpickledObjectWrapper
  77. if inspect.isclass(obj):
  78. class CloudpickledClassWrapper(CloudpickledObjectWrapper):
  79. def __init__(self, *args, **kwargs):
  80. self._obj = obj(*args, **kwargs)
  81. self._keep_wrapper = keep_wrapper
  82. CloudpickledClassWrapper.__name__ = obj.__name__
  83. return CloudpickledClassWrapper
  84. # If obj is an instance of a class, just wrap it in a regular
  85. # CloudpickledObjectWrapper
  86. return _wrap_non_picklable_objects(obj, keep_wrapper=keep_wrapper)