- """Numpy pickle compatibility functions."""
- import pickle
- import os
- import zlib
- import inspect
- from io import BytesIO
- from .numpy_pickle_utils import _ZFILE_PREFIX
- from .numpy_pickle_utils import Unpickler
- def hex_str(an_int):
- """Convert an int to an hexadecimal string."""
- return '{:#x}'.format(an_int)
- def asbytes(s):
- if isinstance(s, bytes):
- return s
- return s.encode('latin1')
- _MAX_LEN = len(hex_str(2 ** 64))
- _CHUNK_SIZE = 64 * 1024
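
# For illustration (not part of the original module): the z-file header is
# _ZFILE_PREFIX followed by the payload length as a left-justified
# hexadecimal string. For a 14-byte payload, hex_str(14) == '0xe', so the
# header occupies len(_ZFILE_PREFIX) + _MAX_LEN bytes in total.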


def read_zfile(file_handle):
    """Read the z-file and return the content as bytes.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.
    """
    file_handle.seek(0)
    header_length = len(_ZFILE_PREFIX) + _MAX_LEN
    length = file_handle.read(header_length)
    length = length[len(_ZFILE_PREFIX):]
    length = int(length, 16)

    # With python2 and joblib version <= 0.8.4 the compressed pickle header
    # is one character wider, so we need to ignore an additional space if
    # present.
    # Note: the first byte of the zlib data is guaranteed not to be a
    # space according to
    # https://tools.ietf.org/html/rfc6713#section-2.1
    next_byte = file_handle.read(1)
    if next_byte != b' ':
        # The zlib compressed data has started and we need to go back
        # one byte
        file_handle.seek(header_length)

    # We use the known length of the data to tell zlib the size of the
    # buffer to allocate.
    data = zlib.decompress(file_handle.read(), 15, length)
    assert len(data) == length, (
        "Incorrect data length while decompressing %s. "
        "The file could be corrupted." % file_handle)
    return data


def write_zfile(file_handle, data, compress=1):
    """Write the data in the given file as a Z-file.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.
    """
    file_handle.write(_ZFILE_PREFIX)
    length = hex_str(len(data))
    # Store the length of the data
    file_handle.write(asbytes(length.ljust(_MAX_LEN)))
    file_handle.write(zlib.compress(asbytes(data), compress))
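

# A minimal sketch of the write/read round trip above, using an in-memory
# buffer. Illustration only; the helper name is not part of the original
# module.
def _zfile_roundtrip_example():
    buffer = BytesIO()
    write_zfile(buffer, b'some raw bytes', compress=3)
    # read_zfile seeks back to the start itself, so the buffer can be
    # passed as-is.
    assert read_zfile(buffer) == b'some raw bytes'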


###############################################################################
# Utility objects for persistence.


class NDArrayWrapper(object):
    """An object to be persisted instead of numpy arrays.

    The only thing this object does is carry the filename in which
    the array has been persisted, and the array subclass.
    """

    def __init__(self, filename, subclass, allow_mmap=True):
        """Constructor. Store the useful information for later."""
        self.filename = filename
        self.subclass = subclass
        self.allow_mmap = allow_mmap

    def read(self, unpickler):
        """Reconstruct the array."""
        filename = os.path.join(unpickler._dirname, self.filename)
        # Load the array from the disk.
        # Use getattr instead of self.allow_mmap to ensure backward
        # compatibility with NDArrayWrapper instances pickled with
        # joblib < 0.9.0.
        allow_mmap = getattr(self, 'allow_mmap', True)
        kwargs = {}
        if allow_mmap:
            kwargs['mmap_mode'] = unpickler.mmap_mode
        if "allow_pickle" in inspect.signature(unpickler.np.load).parameters:
            # Required in numpy 1.16.3 and later to acknowledge the
            # security risk.
            kwargs["allow_pickle"] = True
        array = unpickler.np.load(filename, **kwargs)

        # Reconstruct subclasses. This does not work with old
        # versions of numpy.
        if (hasattr(array, '__array_prepare__') and
                self.subclass not in (unpickler.np.ndarray,
                                      unpickler.np.memmap)):
            # We need to reconstruct another subclass.
            new_array = unpickler.np.core.multiarray._reconstruct(
                self.subclass, (0,), 'b')
            return new_array.__array_prepare__(array)
        else:
            return array
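

# A minimal sketch of the mechanism NDArrayWrapper.read relies on: the
# array lives in a companion .npy file written with np.save, and np.load
# maps it back, optionally as a memmap. Illustration only; the helper and
# the file name are hypothetical.
def _npy_load_example(directory):
    import numpy as np
    path = os.path.join(directory, 'example.npy')
    np.save(path, np.arange(10))
    # mmap_mode='r' returns a read-only np.memmap instead of an in-memory
    # array, which is what passing unpickler.mmap_mode enables above.
    loaded = np.load(path, mmap_mode='r')
    assert loaded[3] == 3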


class ZNDArrayWrapper(NDArrayWrapper):
    """An object to be persisted instead of numpy arrays.

    This object stores the z-file filename in which the data array has
    been persisted, along with the meta information needed to retrieve it.

    The reason we store the raw buffer data of the array together with the
    meta information, rather than the output of an array serialization
    routine (tobytes), is that it lets us rely fully on the strided memory
    model and thus avoid memory copies (a and a.T store equally fast). In
    addition, saving the heavy data separately avoids creating large
    temporary buffers when unpickling data with large arrays.
    """

    def __init__(self, filename, init_args, state):
        """Constructor. Store the useful information for later."""
        self.filename = filename
        self.state = state
        self.init_args = init_args

    def read(self, unpickler):
        """Reconstruct the array from the meta-information and the z-file."""
        # Here we are simply reproducing the unpickling mechanism for numpy
        # arrays.
        filename = os.path.join(unpickler._dirname, self.filename)
        array = unpickler.np.core.multiarray._reconstruct(*self.init_args)
        with open(filename, 'rb') as f:
            data = read_zfile(f)
        state = self.state + (data,)
        array.__setstate__(state)
        return array
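

# A minimal sketch of the numpy unpickling mechanism reproduced above.
# Illustration only; _reconstruct is the private helper numpy registers
# for its own pickle support, and the helper name here is hypothetical.
def _setstate_example():
    import numpy as np
    original = np.arange(6).reshape(2, 3)
    # Build an empty stub array, then restore it from an explicit state
    # tuple: (version, shape, dtype, is_fortran, raw_data).
    stub = np.core.multiarray._reconstruct(np.ndarray, (0,), 'b')
    stub.__setstate__((1, original.shape, original.dtype,
                       False, original.tobytes()))
    assert (stub == original).all()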


class ZipNumpyUnpickler(Unpickler):
    """A subclass of the Unpickler to unpickle our numpy pickles."""

    dispatch = Unpickler.dispatch.copy()

    def __init__(self, filename, file_handle, mmap_mode=None):
        """Constructor."""
        self._filename = os.path.basename(filename)
        self._dirname = os.path.dirname(filename)
        self.mmap_mode = mmap_mode
        self.file_handle = self._open_pickle(file_handle)
        Unpickler.__init__(self, self.file_handle)
        try:
            import numpy as np
        except ImportError:
            np = None
        self.np = np

    def _open_pickle(self, file_handle):
        return BytesIO(read_zfile(file_handle))

    def load_build(self):
        """Set the state of a newly created object.

        We capture it to replace our place-holder objects,
        NDArrayWrapper, by the array we are interested in. We
        replace them directly in the stack of the pickler.
        """
        Unpickler.load_build(self)
        if isinstance(self.stack[-1], NDArrayWrapper):
            if self.np is None:
                raise ImportError("Trying to unpickle an ndarray, "
                                  "but numpy didn't import correctly")
            nd_array_wrapper = self.stack.pop()
            array = nd_array_wrapper.read(self)
            self.stack.append(array)

    dispatch[pickle.BUILD[0]] = load_build
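
# For illustration: pickle.BUILD is the one-byte opcode b'b', so
# pickle.BUILD[0] is the integer 98 that keys load_build in the dispatch
# table the pure-Python Unpickler consults opcode by opcode (assuming
# Unpickler is pickle._Unpickler, as in joblib's numpy_pickle_utils).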


def load_compatibility(filename):
    """Reconstruct a Python object from a file persisted with joblib.dump.

    This function ensures compatibility with the old joblib persistence
    format (<= 0.9.3).

    Parameters
    ----------
    filename: string
        The name of the file from which to load the object

    Returns
    -------
    result: any Python object
        The object stored in the file.

    See Also
    --------
    joblib.dump : function to save an object

    Notes
    -----
    This function can load numpy array files saved separately during the
    dump.
    """
    with open(filename, 'rb') as file_handle:
        # We are careful to open the file handle early and keep it open to
        # avoid race conditions on renames. That said, if data is stored in
        # companion files, moving the directory will create a race when
        # joblib tries to access the companion files.
        unpickler = ZipNumpyUnpickler(filename, file_handle=file_handle)
        try:
            obj = unpickler.load()
        except UnicodeDecodeError as exc:
            # More user-friendly error message
            new_exc = ValueError(
                'You may be trying to read with '
                'python 3 a joblib pickle generated with python 2. '
                'This feature is not supported by joblib.')
            new_exc.__cause__ = exc
            raise new_exc
        finally:
            if hasattr(unpickler, 'file_handle'):
                unpickler.file_handle.close()

    return obj
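
# A minimal usage sketch (illustration only; 'old_dump.pkl' stands for a
# hypothetical file produced with joblib <= 0.9.3):
#
#     from joblib.numpy_pickle_compat import load_compatibility
#     obj = load_compatibility('old_dump.pkl')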