test_numpy_pickle.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. """Test the numpy pickler as a replacement of the standard pickler."""
  2. import copy
  3. import os
  4. import random
  5. import re
  6. import io
  7. import warnings
  8. import gzip
  9. import zlib
  10. import bz2
  11. import pickle
  12. import socket
  13. from contextlib import closing
  14. import mmap
  15. try:
  16. import lzma
  17. except ImportError:
  18. lzma = None
  19. import pytest
  20. from joblib.test.common import np, with_numpy, with_lz4, without_lz4
  21. from joblib.test.common import with_memory_profiler, memory_used
  22. from joblib.testing import parametrize, raises, SkipTest, warns
  23. # numpy_pickle is not a drop-in replacement of pickle, as it takes
  24. # filenames instead of open files as arguments.
  25. from joblib import numpy_pickle, register_compressor
  26. from joblib.test import data
  27. from joblib.numpy_pickle_utils import _IO_BUFFER_SIZE
  28. from joblib.numpy_pickle_utils import _detect_compressor
  29. from joblib.compressor import (_COMPRESSORS, _LZ4_PREFIX, CompressorWrapper,
  30. LZ4_NOT_INSTALLED_ERROR, BinaryZlibFile)
###############################################################################
# Define a list of standard types.
# Borrowed from dill, initial author: Micheal McKerns:
# http://dev.danse.us/trac/pathos/browser/dill/dill_test2.py
typelist = []

# testing types
_none = None
typelist.append(_none)
_type = type
typelist.append(_type)
_bool = bool(1)
typelist.append(_bool)
_int = int(1)
typelist.append(_int)
_float = float(1)
typelist.append(_float)
_complex = complex(1)
typelist.append(_complex)
_string = str(1)
typelist.append(_string)
_tuple = ()
typelist.append(_tuple)
_list = []
typelist.append(_list)
_dict = {}
typelist.append(_dict)
_builtin = len
typelist.append(_builtin)


# A generator function, an old-style and a new-style class are included to
# exercise pickling of callables and class objects themselves.
def _function(x):
    yield x


class _class:
    def _method(self):
        pass


class _newclass(object):
    def _method(self):
        pass


typelist.append(_function)
typelist.append(_class)
typelist.append(_newclass)  # <type 'type'>
_instance = _class()
typelist.append(_instance)
_object = _newclass()
typelist.append(_object)  # <type 'class'>
  74. ###############################################################################
  75. # Tests
  76. @parametrize('compress', [0, 1])
  77. @parametrize('member', typelist)
  78. def test_standard_types(tmpdir, compress, member):
  79. # Test pickling and saving with standard types.
  80. filename = tmpdir.join('test.pkl').strpath
  81. numpy_pickle.dump(member, filename, compress=compress)
  82. _member = numpy_pickle.load(filename)
  83. # We compare the pickled instance to the reloaded one only if it
  84. # can be compared to a copied one
  85. if member == copy.deepcopy(member):
  86. assert member == _member
  87. def test_value_error():
  88. # Test inverting the input arguments to dump
  89. with raises(ValueError):
  90. numpy_pickle.dump('foo', dict())
  91. @parametrize('wrong_compress', [-1, 10, dict()])
  92. def test_compress_level_error(wrong_compress):
  93. # Verify that passing an invalid compress argument raises an error.
  94. exception_msg = ('Non valid compress level given: '
  95. '"{0}"'.format(wrong_compress))
  96. with raises(ValueError) as excinfo:
  97. numpy_pickle.dump('dummy', 'foo', compress=wrong_compress)
  98. excinfo.match(exception_msg)
  99. @with_numpy
  100. @parametrize('compress', [False, True, 0, 3, 'zlib'])
  101. def test_numpy_persistence(tmpdir, compress):
  102. filename = tmpdir.join('test.pkl').strpath
  103. rnd = np.random.RandomState(0)
  104. a = rnd.random_sample((10, 2))
  105. # We use 'a.T' to have a non C-contiguous array.
  106. for index, obj in enumerate(((a,), (a.T,), (a, a), [a, a, a])):
  107. filenames = numpy_pickle.dump(obj, filename, compress=compress)
  108. # All is cached in one file
  109. assert len(filenames) == 1
  110. # Check that only one file was created
  111. assert filenames[0] == filename
  112. # Check that this file does exist
  113. assert os.path.exists(filenames[0])
  114. # Unpickle the object
  115. obj_ = numpy_pickle.load(filename)
  116. # Check that the items are indeed arrays
  117. for item in obj_:
  118. assert isinstance(item, np.ndarray)
  119. # And finally, check that all the values are equal.
  120. np.testing.assert_array_equal(np.array(obj), np.array(obj_))
  121. # Now test with array subclasses
  122. for obj in (np.matrix(np.zeros(10)),
  123. np.memmap(filename + 'mmap',
  124. mode='w+', shape=4, dtype=np.float)):
  125. filenames = numpy_pickle.dump(obj, filename, compress=compress)
  126. # All is cached in one file
  127. assert len(filenames) == 1
  128. obj_ = numpy_pickle.load(filename)
  129. if (type(obj) is not np.memmap and
  130. hasattr(obj, '__array_prepare__')):
  131. # We don't reconstruct memmaps
  132. assert isinstance(obj_, type(obj))
  133. np.testing.assert_array_equal(obj_, obj)
  134. # Test with an object containing multiple numpy arrays
  135. obj = ComplexTestObject()
  136. filenames = numpy_pickle.dump(obj, filename, compress=compress)
  137. # All is cached in one file
  138. assert len(filenames) == 1
  139. obj_loaded = numpy_pickle.load(filename)
  140. assert isinstance(obj_loaded, type(obj))
  141. np.testing.assert_array_equal(obj_loaded.array_float, obj.array_float)
  142. np.testing.assert_array_equal(obj_loaded.array_int, obj.array_int)
  143. np.testing.assert_array_equal(obj_loaded.array_obj, obj.array_obj)
  144. @with_numpy
  145. def test_numpy_persistence_bufferred_array_compression(tmpdir):
  146. big_array = np.ones((_IO_BUFFER_SIZE + 100), dtype=np.uint8)
  147. filename = tmpdir.join('test.pkl').strpath
  148. numpy_pickle.dump(big_array, filename, compress=True)
  149. arr_reloaded = numpy_pickle.load(filename)
  150. np.testing.assert_array_equal(big_array, arr_reloaded)
@with_numpy
def test_memmap_persistence(tmpdir):
    """Loading with mmap_mode returns memmap arrays with the right mode."""
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    filename = tmpdir.join('test1.pkl').strpath
    numpy_pickle.dump(a, filename)
    b = numpy_pickle.load(filename, mmap_mode='r')

    assert isinstance(b, np.memmap)

    # Test with an object containing multiple numpy arrays
    filename = tmpdir.join('test2.pkl').strpath
    obj = ComplexTestObject()
    numpy_pickle.dump(obj, filename)
    obj_loaded = numpy_pickle.load(filename, mmap_mode='r')
    assert isinstance(obj_loaded, type(obj))
    assert isinstance(obj_loaded.array_float, np.memmap)
    assert not obj_loaded.array_float.flags.writeable
    assert isinstance(obj_loaded.array_int, np.memmap)
    assert not obj_loaded.array_int.flags.writeable
    # Memory map not allowed for numpy object arrays
    assert not isinstance(obj_loaded.array_obj, np.memmap)
    np.testing.assert_array_equal(obj_loaded.array_float,
                                  obj.array_float)
    np.testing.assert_array_equal(obj_loaded.array_int,
                                  obj.array_int)
    np.testing.assert_array_equal(obj_loaded.array_obj,
                                  obj.array_obj)

    # Test we can write in memmapped arrays
    obj_loaded = numpy_pickle.load(filename, mmap_mode='r+')
    assert obj_loaded.array_float.flags.writeable
    obj_loaded.array_float[0:10] = 10.0
    assert obj_loaded.array_int.flags.writeable
    obj_loaded.array_int[0:10] = 10

    obj_reloaded = numpy_pickle.load(filename, mmap_mode='r')
    np.testing.assert_array_equal(obj_reloaded.array_float,
                                  obj_loaded.array_float)
    np.testing.assert_array_equal(obj_reloaded.array_int,
                                  obj_loaded.array_int)

    # Test w+ mode is caught and the mode has switched to r+
    # NOTE(review): the w+ load result is discarded; the assertions below
    # re-check the arrays from the earlier 'r+' load.
    numpy_pickle.load(filename, mmap_mode='w+')
    assert obj_loaded.array_int.flags.writeable
    assert obj_loaded.array_int.mode == 'r+'
    assert obj_loaded.array_float.flags.writeable
    assert obj_loaded.array_float.mode == 'r+'
  194. @with_numpy
  195. def test_memmap_persistence_mixed_dtypes(tmpdir):
  196. # loading datastructures that have sub-arrays with dtype=object
  197. # should not prevent memmapping on fixed size dtype sub-arrays.
  198. rnd = np.random.RandomState(0)
  199. a = rnd.random_sample(10)
  200. b = np.array([1, 'b'], dtype=object)
  201. construct = (a, b)
  202. filename = tmpdir.join('test.pkl').strpath
  203. numpy_pickle.dump(construct, filename)
  204. a_clone, b_clone = numpy_pickle.load(filename, mmap_mode='r')
  205. # the floating point array has been memory mapped
  206. assert isinstance(a_clone, np.memmap)
  207. # the object-dtype array has been loaded in memory
  208. assert not isinstance(b_clone, np.memmap)
  209. @with_numpy
  210. def test_masked_array_persistence(tmpdir):
  211. # The special-case picker fails, because saving masked_array
  212. # not implemented, but it just delegates to the standard pickler.
  213. rnd = np.random.RandomState(0)
  214. a = rnd.random_sample(10)
  215. a = np.ma.masked_greater(a, 0.5)
  216. filename = tmpdir.join('test.pkl').strpath
  217. numpy_pickle.dump(a, filename)
  218. b = numpy_pickle.load(filename, mmap_mode='r')
  219. assert isinstance(b, np.ma.masked_array)
@with_numpy
def test_compress_mmap_mode_warning(tmpdir):
    """Loading a compressed pickle with mmap_mode warns and ignores it."""
    # Test the warning in case of compress + mmap_mode
    rnd = np.random.RandomState(0)
    a = rnd.random_sample(10)
    this_filename = tmpdir.join('test.pkl').strpath
    numpy_pickle.dump(a, this_filename, compress=1)
    with warns(UserWarning) as warninfo:
        numpy_pickle.load(this_filename, mmap_mode='r+')
    assert len(warninfo) == 1
    # The warning text must match the exact message emitted by joblib.
    assert (str(warninfo[0].message) ==
            'mmap_mode "%(mmap_mode)s" is not compatible with compressed '
            'file %(filename)s. "%(mmap_mode)s" flag will be ignored.' %
            {'filename': this_filename, 'mmap_mode': 'r+'})
@with_numpy
@parametrize('cache_size', [None, 0, 10])
def test_cache_size_warning(tmpdir, cache_size):
    """A DeprecationWarning is emitted iff 'cache_size' is passed to dump."""
    # Check deprecation warning raised when cache size is not None
    filename = tmpdir.join('test.pkl').strpath
    rnd = np.random.RandomState(0)
    a = rnd.random_sample((10, 2))

    warnings.simplefilter("always")
    # NOTE(review): pytest.warns(None) is deprecated in pytest >= 7;
    # consider warnings.catch_warnings(record=True) when upgrading.
    with warns(None) as warninfo:
        numpy_pickle.dump(a, filename, cache_size=cache_size)
    expected_nb_warnings = 1 if cache_size is not None else 0
    assert len(warninfo) == expected_nb_warnings
    for w in warninfo:
        assert w.category == DeprecationWarning
        assert (str(w.message) ==
                "Please do not set 'cache_size' in joblib.dump, this "
                "parameter has no effect and will be removed. You "
                "used 'cache_size={0}'".format(cache_size))
@with_numpy
@with_memory_profiler
@parametrize('compress', [True, False])
def test_memory_usage(tmpdir, compress):
    """Dump/load memory overhead stays within the expected buffer bounds."""
    # Verify memory stays within expected bounds.
    filename = tmpdir.join('test.pkl').strpath
    small_array = np.ones((10, 10))
    big_array = np.ones(shape=100 * int(1e6), dtype=np.uint8)
    small_matrix = np.matrix(small_array)
    big_matrix = np.matrix(big_array)

    for obj in (small_array, big_array, small_matrix, big_matrix):
        size = obj.nbytes / 1e6  # object size in MB
        obj_filename = filename + str(np.random.randint(0, 1000))
        mem_used = memory_used(numpy_pickle.dump,
                               obj, obj_filename, compress=compress)

        # The memory used to dump the object shouldn't exceed the buffer
        # size used to write array chunks (16MB).
        write_buf_size = _IO_BUFFER_SIZE + 16 * 1024 ** 2 / 1e6
        assert mem_used <= write_buf_size

        mem_used = memory_used(numpy_pickle.load, obj_filename)
        # memory used should be less than array size + buffer size used to
        # read the array chunk by chunk.
        read_buf_size = 32 + _IO_BUFFER_SIZE  # MiB
        assert mem_used < size + read_buf_size
@with_numpy
def test_compressed_pickle_dump_and_load(tmpdir):
    """A heterogeneous list survives a compressed dump/load round-trip.

    Dtypes are spelled with explicit endianness ('<'/'>') so the comparison
    is meaningful regardless of the host architecture.
    """
    expected_list = [np.arange(5, dtype=np.dtype('<i8')),
                     np.arange(5, dtype=np.dtype('>i8')),
                     np.arange(5, dtype=np.dtype('<f8')),
                     np.arange(5, dtype=np.dtype('>f8')),
                     np.array([1, 'abc', {'a': 1, 'b': 2}], dtype='O'),
                     np.arange(256, dtype=np.uint8).tobytes(),
                     # np.matrix is a subclass of np.ndarray, here we want
                     # to verify this type of object is correctly unpickled
                     # among versions.
                     np.matrix([0, 1, 2], dtype=np.dtype('<i8')),
                     np.matrix([0, 1, 2], dtype=np.dtype('>i8')),
                     u"C'est l'\xe9t\xe9 !"]

    fname = tmpdir.join('temp.pkl.gz').strpath

    dumped_filenames = numpy_pickle.dump(expected_list, fname, compress=1)
    assert len(dumped_filenames) == 1
    result_list = numpy_pickle.load(fname)
    for result, expected in zip(result_list, expected_list):
        if isinstance(expected, np.ndarray):
            assert result.dtype == expected.dtype
            np.testing.assert_equal(result, expected)
        else:
            assert result == expected
def _check_pickle(filename, expected_list):
    """Helper function to test joblib pickle content.

    Note: currently only pickles containing an iterable are supported
    by this function.

    The writing python version is parsed out of the filename (e.g.
    '..._py27_...'); the protocol mapping decides whether the pickle is
    expected to load cleanly or to raise.
    """
    version_match = re.match(r'.+py(\d)(\d).+', filename)
    py_version_used_for_writing = int(version_match.group(1))

    py_version_to_default_pickle_protocol = {2: 2, 3: 3}
    # Reading always happens under python 3, hence the literal key 3.
    pickle_reading_protocol = py_version_to_default_pickle_protocol.get(3, 4)
    pickle_writing_protocol = py_version_to_default_pickle_protocol.get(
        py_version_used_for_writing, 4)
    if pickle_reading_protocol >= pickle_writing_protocol:
        try:
            with warns(None) as warninfo:
                warnings.simplefilter('always')
                warnings.filterwarnings(
                    'ignore', module='numpy',
                    message='The compiler package is deprecated')
                result_list = numpy_pickle.load(filename)
            filename_base = os.path.basename(filename)
            # Only pickles written by joblib < 0.10 trigger the
            # "please regenerate" deprecation warning.
            expected_nb_warnings = 1 if ("_0.9" in filename_base or
                                         "_0.8.4" in filename_base) else 0
            assert len(warninfo) == expected_nb_warnings
            for w in warninfo:
                assert w.category == DeprecationWarning
                assert (str(w.message) ==
                        "The file '{0}' has been generated with a joblib "
                        "version less than 0.10. Please regenerate this "
                        "pickle file.".format(filename))
            for result, expected in zip(result_list, expected_list):
                if isinstance(expected, np.ndarray):
                    assert result.dtype == expected.dtype
                    np.testing.assert_equal(result, expected)
                else:
                    assert result == expected
        except Exception as exc:
            # When trying to read with python 3 a pickle generated
            # with python 2 we expect a user-friendly error
            if py_version_used_for_writing == 2:
                assert isinstance(exc, ValueError)
                message = ('You may be trying to read with '
                           'python 3 a joblib pickle generated with python 2.')
                assert message in str(exc)
            elif filename.endswith('.lz4') and with_lz4.args[0]:
                # lz4 file but the lz4 package is not installed: a
                # dedicated error message is expected.
                assert isinstance(exc, ValueError)
                assert LZ4_NOT_INSTALLED_ERROR in str(exc)
            else:
                raise
    else:
        # Pickle protocol used for writing is too high. We expect a
        # "unsupported pickle protocol" error message
        try:
            numpy_pickle.load(filename)
            raise AssertionError('Numpy pickle loading should '
                                 'have raised a ValueError exception')
        except ValueError as e:
            message = 'unsupported pickle protocol: {0}'.format(
                pickle_writing_protocol)
            assert message in str(e.args)
  359. @with_numpy
  360. def test_joblib_pickle_across_python_versions():
  361. # We need to be specific about dtypes in particular endianness
  362. # because the pickles can be generated on one architecture and
  363. # the tests run on another one. See
  364. # https://github.com/joblib/joblib/issues/279.
  365. expected_list = [np.arange(5, dtype=np.dtype('<i8')),
  366. np.arange(5, dtype=np.dtype('<f8')),
  367. np.array([1, 'abc', {'a': 1, 'b': 2}], dtype='O'),
  368. np.arange(256, dtype=np.uint8).tobytes(),
  369. # np.matrix is a subclass of np.ndarray, here we want
  370. # to verify this type of object is correctly unpickled
  371. # among versions.
  372. np.matrix([0, 1, 2], dtype=np.dtype('<i8')),
  373. u"C'est l'\xe9t\xe9 !"]
  374. # Testing all the compressed and non compressed
  375. # pickles in joblib/test/data. These pickles were generated by
  376. # the joblib/test/data/create_numpy_pickle.py script for the
  377. # relevant python, joblib and numpy versions.
  378. test_data_dir = os.path.dirname(os.path.abspath(data.__file__))
  379. pickle_extensions = ('.pkl', '.gz', '.gzip', '.bz2', 'lz4')
  380. if lzma is not None:
  381. pickle_extensions += ('.xz', '.lzma')
  382. pickle_filenames = [os.path.join(test_data_dir, fn)
  383. for fn in os.listdir(test_data_dir)
  384. if any(fn.endswith(ext) for ext in pickle_extensions)]
  385. for fname in pickle_filenames:
  386. _check_pickle(fname, expected_list)
  387. @parametrize('compress_tuple', [('zlib', 3), ('gzip', 3)])
  388. def test_compress_tuple_argument(tmpdir, compress_tuple):
  389. # Verify the tuple is correctly taken into account.
  390. filename = tmpdir.join('test.pkl').strpath
  391. numpy_pickle.dump("dummy", filename,
  392. compress=compress_tuple)
  393. # Verify the file contains the right magic number
  394. with open(filename, 'rb') as f:
  395. assert _detect_compressor(f) == compress_tuple[0]
  396. @parametrize('compress_tuple,message',
  397. [(('zlib', 3, 'extra'), # wrong compress tuple
  398. 'Compress argument tuple should contain exactly 2 elements'),
  399. (('wrong', 3), # wrong compress method
  400. 'Non valid compression method given: "{}"'.format('wrong')),
  401. (('zlib', 'wrong'), # wrong compress level
  402. 'Non valid compress level given: "{}"'.format('wrong'))])
  403. def test_compress_tuple_argument_exception(tmpdir, compress_tuple, message):
  404. filename = tmpdir.join('test.pkl').strpath
  405. # Verify setting a wrong compress tuple raises a ValueError.
  406. with raises(ValueError) as excinfo:
  407. numpy_pickle.dump('dummy', filename, compress=compress_tuple)
  408. excinfo.match(message)
  409. @parametrize('compress_string', ['zlib', 'gzip'])
  410. def test_compress_string_argument(tmpdir, compress_string):
  411. # Verify the string is correctly taken into account.
  412. filename = tmpdir.join('test.pkl').strpath
  413. numpy_pickle.dump("dummy", filename,
  414. compress=compress_string)
  415. # Verify the file contains the right magic number
  416. with open(filename, 'rb') as f:
  417. assert _detect_compressor(f) == compress_string
  418. @with_numpy
  419. @parametrize('compress', [1, 3, 6])
  420. @parametrize('cmethod', _COMPRESSORS)
  421. def test_joblib_compression_formats(tmpdir, compress, cmethod):
  422. filename = tmpdir.join('test.pkl').strpath
  423. objects = (np.ones(shape=(100, 100), dtype='f8'),
  424. range(10),
  425. {'a': 1, 2: 'b'}, [], (), {}, 0, 1.0)
  426. if cmethod in ("lzma", "xz") and lzma is None:
  427. pytest.skip("lzma is support not available")
  428. elif cmethod == 'lz4' and with_lz4.args[0]:
  429. # Skip the test if lz4 is not installed. We here use the with_lz4
  430. # skipif fixture whose argument is True when lz4 is not installed
  431. pytest.skip("lz4 is not installed.")
  432. dump_filename = filename + "." + cmethod
  433. for obj in objects:
  434. numpy_pickle.dump(obj, dump_filename, compress=(cmethod, compress))
  435. # Verify the file contains the right magic number
  436. with open(dump_filename, 'rb') as f:
  437. assert _detect_compressor(f) == cmethod
  438. # Verify the reloaded object is correct
  439. obj_reloaded = numpy_pickle.load(dump_filename)
  440. assert isinstance(obj_reloaded, type(obj))
  441. if isinstance(obj, np.ndarray):
  442. np.testing.assert_array_equal(obj_reloaded, obj)
  443. else:
  444. assert obj_reloaded == obj
  445. def _gzip_file_decompress(source_filename, target_filename):
  446. """Decompress a gzip file."""
  447. with closing(gzip.GzipFile(source_filename, "rb")) as fo:
  448. buf = fo.read()
  449. with open(target_filename, "wb") as fo:
  450. fo.write(buf)
  451. def _zlib_file_decompress(source_filename, target_filename):
  452. """Decompress a zlib file."""
  453. with open(source_filename, 'rb') as fo:
  454. buf = zlib.decompress(fo.read())
  455. with open(target_filename, 'wb') as fo:
  456. fo.write(buf)
  457. @parametrize('extension,decompress',
  458. [('.z', _zlib_file_decompress),
  459. ('.gz', _gzip_file_decompress)])
  460. def test_load_externally_decompressed_files(tmpdir, extension, decompress):
  461. # Test that BinaryZlibFile generates valid gzip and zlib compressed files.
  462. obj = "a string to persist"
  463. filename_raw = tmpdir.join('test.pkl').strpath
  464. filename_compressed = filename_raw + extension
  465. # Use automatic extension detection to compress with the right method.
  466. numpy_pickle.dump(obj, filename_compressed)
  467. # Decompress with the corresponding method
  468. decompress(filename_compressed, filename_raw)
  469. # Test that the uncompressed pickle can be loaded and
  470. # that the result is correct.
  471. obj_reloaded = numpy_pickle.load(filename_raw)
  472. assert obj == obj_reloaded
  473. @parametrize('extension,cmethod',
  474. # valid compressor extensions
  475. [('.z', 'zlib'),
  476. ('.gz', 'gzip'),
  477. ('.bz2', 'bz2'),
  478. ('.lzma', 'lzma'),
  479. ('.xz', 'xz'),
  480. # invalid compressor extensions
  481. ('.pkl', 'not-compressed'),
  482. ('', 'not-compressed')])
  483. def test_compression_using_file_extension(tmpdir, extension, cmethod):
  484. if cmethod in ("lzma", "xz") and lzma is None:
  485. pytest.skip("lzma is missing")
  486. # test that compression method corresponds to the given filename extension.
  487. filename = tmpdir.join('test.pkl').strpath
  488. obj = "object to dump"
  489. dump_fname = filename + extension
  490. numpy_pickle.dump(obj, dump_fname)
  491. # Verify the file contains the right magic number
  492. with open(dump_fname, 'rb') as f:
  493. assert _detect_compressor(f) == cmethod
  494. # Verify the reloaded object is correct
  495. obj_reloaded = numpy_pickle.load(dump_fname)
  496. assert isinstance(obj_reloaded, type(obj))
  497. assert obj_reloaded == obj
@with_numpy
def test_file_handle_persistence(tmpdir):
    """Dump/load works on already-open compressed and raw file handles."""
    objs = [np.random.random((10, 10)),
            "some data",
            np.matrix([0, 1, 2])]
    fobjs = [bz2.BZ2File, gzip.GzipFile]
    if lzma is not None:
        fobjs += [lzma.LZMAFile]
    filename = tmpdir.join('test.pkl').strpath

    for obj in objs:
        for fobj in fobjs:
            with fobj(filename, 'wb') as f:
                numpy_pickle.dump(obj, f)

            # using the same decompressor prevents from internally
            # decompress again.
            with fobj(filename, 'rb') as f:
                obj_reloaded = numpy_pickle.load(f)

            # when needed, the correct decompressor should be used when
            # passing a raw file handle.
            with open(filename, 'rb') as f:
                obj_reloaded_2 = numpy_pickle.load(f)

            if isinstance(obj, np.ndarray):
                np.testing.assert_array_equal(obj_reloaded, obj)
                np.testing.assert_array_equal(obj_reloaded_2, obj)
            else:
                assert obj_reloaded == obj
                assert obj_reloaded_2 == obj
  525. @with_numpy
  526. def test_in_memory_persistence():
  527. objs = [np.random.random((10, 10)),
  528. "some data",
  529. np.matrix([0, 1, 2])]
  530. for obj in objs:
  531. f = io.BytesIO()
  532. numpy_pickle.dump(obj, f)
  533. obj_reloaded = numpy_pickle.load(f)
  534. if isinstance(obj, np.ndarray):
  535. np.testing.assert_array_equal(obj_reloaded, obj)
  536. else:
  537. assert obj_reloaded == obj
  538. @with_numpy
  539. def test_file_handle_persistence_mmap(tmpdir):
  540. obj = np.random.random((10, 10))
  541. filename = tmpdir.join('test.pkl').strpath
  542. with open(filename, 'wb') as f:
  543. numpy_pickle.dump(obj, f)
  544. with open(filename, 'rb') as f:
  545. obj_reloaded = numpy_pickle.load(f, mmap_mode='r+')
  546. np.testing.assert_array_equal(obj_reloaded, obj)
@with_numpy
def test_file_handle_persistence_compressed_mmap(tmpdir):
    """mmap_mode is ignored (with a warning) on compressed file objects."""
    obj = np.random.random((10, 10))
    filename = tmpdir.join('test.pkl').strpath
    with open(filename, 'wb') as f:
        numpy_pickle.dump(obj, f, compress=('gzip', 3))
    with closing(gzip.GzipFile(filename, 'rb')) as f:
        with warns(UserWarning) as warninfo:
            numpy_pickle.load(f, mmap_mode='r+')
        assert len(warninfo) == 1
        # The warning text must match the exact message emitted by joblib.
        assert (str(warninfo[0].message) ==
                '"%(fileobj)r" is not a raw file, mmap_mode "%(mmap_mode)s" '
                'flag will be ignored.' % {'fileobj': f, 'mmap_mode': 'r+'})
@with_numpy
def test_file_handle_persistence_in_memory_mmap():
    """mmap_mode is ignored (with a warning) for in-memory buffers."""
    obj = np.random.random((10, 10))
    buf = io.BytesIO()
    numpy_pickle.dump(obj, buf)
    with warns(UserWarning) as warninfo:
        numpy_pickle.load(buf, mmap_mode='r+')
    assert len(warninfo) == 1
    # The warning text must match the exact message emitted by joblib.
    assert (str(warninfo[0].message) ==
            'In memory persistence is not compatible with mmap_mode '
            '"%(mmap_mode)s" flag passed. mmap_mode option will be '
            'ignored.' % {'mmap_mode': 'r+'})
@parametrize('data', [b'a little data as bytes.',
                      # More bytes
                      10000 * "{}".format(
                          random.randint(0, 1000) * 1000).encode('latin-1')],
             ids=["a little data as bytes.", "a large data as bytes."])
@parametrize('compress_level', [1, 3, 9])
def test_binary_zlibfile(tmpdir, data, compress_level):
    """BinaryZlibFile supports file objects and filenames, read and write."""
    filename = tmpdir.join('test.pkl').strpath

    # Regular cases: wrap an already-open raw file object.
    with open(filename, 'wb') as f:
        with BinaryZlibFile(f, 'wb',
                            compresslevel=compress_level) as fz:
            assert fz.writable()
            fz.write(data)
            assert fz.fileno() == f.fileno()
            # A write-mode file must reject read/seek capability checks.
            with raises(io.UnsupportedOperation):
                fz._check_can_read()
            with raises(io.UnsupportedOperation):
                fz._check_can_seek()
        # Leaving the context closes the BinaryZlibFile wrapper.
        assert fz.closed
        with raises(ValueError):
            fz._check_not_closed()

    with open(filename, 'rb') as f:
        with BinaryZlibFile(f) as fz:
            assert fz.readable()
            assert fz.seekable()
            assert fz.fileno() == f.fileno()
            assert fz.read() == data
            # A read-mode file must reject write capability checks.
            with raises(io.UnsupportedOperation):
                fz._check_can_write()
            assert fz.seekable()
            fz.seek(0)
            assert fz.tell() == 0
        assert fz.closed

    # Test with a filename as input
    with BinaryZlibFile(filename, 'wb',
                        compresslevel=compress_level) as fz:
        assert fz.writable()
        fz.write(data)

    with BinaryZlibFile(filename, 'rb') as fz:
        assert fz.read() == data
        assert fz.seekable()

    # Test without context manager
    fz = BinaryZlibFile(filename, 'wb', compresslevel=compress_level)
    assert fz.writable()
    fz.write(data)
    fz.close()

    fz = BinaryZlibFile(filename, 'rb')
    assert fz.read() == data
    fz.close()
  622. @parametrize('bad_value', [-1, 10, 15, 'a', (), {}])
  623. def test_binary_zlibfile_bad_compression_levels(tmpdir, bad_value):
  624. filename = tmpdir.join('test.pkl').strpath
  625. with raises(ValueError) as excinfo:
  626. BinaryZlibFile(filename, 'wb', compresslevel=bad_value)
  627. pattern = re.escape("'compresslevel' must be an integer between 1 and 9. "
  628. "You provided 'compresslevel={}'".format(bad_value))
  629. excinfo.match(pattern)
  630. @parametrize('bad_mode', ['a', 'x', 'r', 'w', 1, 2])
  631. def test_binary_zlibfile_invalid_modes(tmpdir, bad_mode):
  632. filename = tmpdir.join('test.pkl').strpath
  633. with raises(ValueError) as excinfo:
  634. BinaryZlibFile(filename, bad_mode)
  635. excinfo.match("Invalid mode")
  636. @parametrize('bad_file', [1, (), {}])
  637. def test_binary_zlibfile_invalid_filename_type(bad_file):
  638. with raises(TypeError) as excinfo:
  639. BinaryZlibFile(bad_file, 'rb')
  640. excinfo.match("filename must be a str or bytes object, or a file")
  641. ###############################################################################
  642. # Test dumping array subclasses
  643. if np is not None:
  644. class SubArray(np.ndarray):
  645. def __reduce__(self):
  646. return _load_sub_array, (np.asarray(self), )
  647. def _load_sub_array(arr):
  648. d = SubArray(arr.shape)
  649. d[:] = arr
  650. return d
  651. class ComplexTestObject:
  652. """A complex object containing numpy arrays as attributes."""
  653. def __init__(self):
  654. self.array_float = np.arange(100, dtype='float64')
  655. self.array_int = np.ones(100, dtype='int32')
  656. self.array_obj = np.array(['a', 10, 20.0], dtype='object')
  657. @with_numpy
  658. def test_numpy_subclass(tmpdir):
  659. filename = tmpdir.join('test.pkl').strpath
  660. a = SubArray((10,))
  661. numpy_pickle.dump(a, filename)
  662. c = numpy_pickle.load(filename)
  663. assert isinstance(c, SubArray)
  664. np.testing.assert_array_equal(c, a)
  665. def test_pathlib(tmpdir):
  666. try:
  667. from pathlib import Path
  668. except ImportError:
  669. pass
  670. else:
  671. filename = tmpdir.join('test.pkl').strpath
  672. value = 123
  673. numpy_pickle.dump(value, Path(filename))
  674. assert numpy_pickle.load(filename) == value
  675. numpy_pickle.dump(value, filename)
  676. assert numpy_pickle.load(Path(filename)) == value
  677. @with_numpy
  678. def test_non_contiguous_array_pickling(tmpdir):
  679. filename = tmpdir.join('test.pkl').strpath
  680. for array in [ # Array that triggers a contiguousness issue with nditer,
  681. # see https://github.com/joblib/joblib/pull/352 and see
  682. # https://github.com/joblib/joblib/pull/353
  683. np.asfortranarray([[1, 2], [3, 4]])[1:],
  684. # Non contiguous array with works fine with nditer
  685. np.ones((10, 50, 20), order='F')[:, :1, :]]:
  686. assert not array.flags.c_contiguous
  687. assert not array.flags.f_contiguous
  688. numpy_pickle.dump(array, filename)
  689. array_reloaded = numpy_pickle.load(filename)
  690. np.testing.assert_array_equal(array_reloaded, array)
  691. @with_numpy
  692. def test_pickle_highest_protocol(tmpdir):
  693. # ensure persistence of a numpy array is valid even when using
  694. # the pickle HIGHEST_PROTOCOL.
  695. # see https://github.com/joblib/joblib/issues/362
  696. filename = tmpdir.join('test.pkl').strpath
  697. test_array = np.zeros(10)
  698. numpy_pickle.dump(test_array, filename, protocol=pickle.HIGHEST_PROTOCOL)
  699. array_reloaded = numpy_pickle.load(filename)
  700. np.testing.assert_array_equal(array_reloaded, test_array)
  701. @with_numpy
  702. def test_pickle_in_socket():
  703. # test that joblib can pickle in sockets
  704. test_array = np.arange(10)
  705. _ADDR = ("localhost", 12345)
  706. listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  707. listener.bind(_ADDR)
  708. listener.listen(1)
  709. client = socket.create_connection(_ADDR)
  710. server, client_addr = listener.accept()
  711. with server.makefile("wb") as sf:
  712. numpy_pickle.dump(test_array, sf)
  713. with client.makefile("rb") as cf:
  714. array_reloaded = numpy_pickle.load(cf)
  715. np.testing.assert_array_equal(array_reloaded, test_array)
  716. @with_numpy
  717. def test_load_memmap_with_big_offset(tmpdir):
  718. # Test that numpy memmap offset is set correctly if greater than
  719. # mmap.ALLOCATIONGRANULARITY, see
  720. # https://github.com/joblib/joblib/issues/451 and
  721. # https://github.com/numpy/numpy/pull/8443 for more details.
  722. fname = tmpdir.join('test.mmap').strpath
  723. size = mmap.ALLOCATIONGRANULARITY
  724. obj = [np.zeros(size, dtype='uint8'), np.ones(size, dtype='uint8')]
  725. numpy_pickle.dump(obj, fname)
  726. memmaps = numpy_pickle.load(fname, mmap_mode='r')
  727. assert isinstance(memmaps[1], np.memmap)
  728. assert memmaps[1].offset > size
  729. np.testing.assert_array_equal(obj, memmaps)
  730. def test_register_compressor(tmpdir):
  731. # Check that registering compressor file works.
  732. compressor_name = 'test-name'
  733. compressor_prefix = 'test-prefix'
  734. class BinaryCompressorTestFile(io.BufferedIOBase):
  735. pass
  736. class BinaryCompressorTestWrapper(CompressorWrapper):
  737. def __init__(self):
  738. CompressorWrapper.__init__(self, obj=BinaryCompressorTestFile,
  739. prefix=compressor_prefix)
  740. register_compressor(compressor_name, BinaryCompressorTestWrapper())
  741. assert (_COMPRESSORS[compressor_name].fileobj_factory ==
  742. BinaryCompressorTestFile)
  743. assert _COMPRESSORS[compressor_name].prefix == compressor_prefix
  744. # Remove this dummy compressor file from extra compressors because other
  745. # tests might fail because of this
  746. _COMPRESSORS.pop(compressor_name)
  747. @parametrize('invalid_name', [1, (), {}])
  748. def test_register_compressor_invalid_name(invalid_name):
  749. # Test that registering an invalid compressor name is not allowed.
  750. with raises(ValueError) as excinfo:
  751. register_compressor(invalid_name, None)
  752. excinfo.match("Compressor name should be a string")
  753. def test_register_compressor_invalid_fileobj():
  754. # Test that registering an invalid file object is not allowed.
  755. class InvalidFileObject():
  756. pass
  757. class InvalidFileObjectWrapper(CompressorWrapper):
  758. def __init__(self):
  759. CompressorWrapper.__init__(self, obj=InvalidFileObject,
  760. prefix=b'prefix')
  761. with raises(ValueError) as excinfo:
  762. register_compressor('invalid', InvalidFileObjectWrapper())
  763. excinfo.match("Compressor 'fileobj_factory' attribute should implement "
  764. "the file object interface")
  765. class AnotherZlibCompressorWrapper(CompressorWrapper):
  766. def __init__(self):
  767. CompressorWrapper.__init__(self, obj=BinaryZlibFile, prefix=b'prefix')
  768. class StandardLibGzipCompressorWrapper(CompressorWrapper):
  769. def __init__(self):
  770. CompressorWrapper.__init__(self, obj=gzip.GzipFile, prefix=b'prefix')
  771. def test_register_compressor_already_registered():
  772. # Test registration of existing compressor files.
  773. compressor_name = 'test-name'
  774. # register a test compressor
  775. register_compressor(compressor_name, AnotherZlibCompressorWrapper())
  776. with raises(ValueError) as excinfo:
  777. register_compressor(compressor_name,
  778. StandardLibGzipCompressorWrapper())
  779. excinfo.match("Compressor '{}' already registered."
  780. .format(compressor_name))
  781. register_compressor(compressor_name, StandardLibGzipCompressorWrapper(),
  782. force=True)
  783. assert compressor_name in _COMPRESSORS
  784. assert _COMPRESSORS[compressor_name].fileobj_factory == gzip.GzipFile
  785. # Remove this dummy compressor file from extra compressors because other
  786. # tests might fail because of this
  787. _COMPRESSORS.pop(compressor_name)
  788. @with_lz4
  789. def test_lz4_compression(tmpdir):
  790. # Check that lz4 can be used when dependency is available.
  791. import lz4.frame
  792. compressor = 'lz4'
  793. assert compressor in _COMPRESSORS
  794. assert _COMPRESSORS[compressor].fileobj_factory == lz4.frame.LZ4FrameFile
  795. fname = tmpdir.join('test.pkl').strpath
  796. data = 'test data'
  797. numpy_pickle.dump(data, fname, compress=compressor)
  798. with open(fname, 'rb') as f:
  799. assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
  800. assert numpy_pickle.load(fname) == data
  801. # Test that LZ4 is applied based on file extension
  802. numpy_pickle.dump(data, fname + '.lz4')
  803. with open(fname, 'rb') as f:
  804. assert f.read(len(_LZ4_PREFIX)) == _LZ4_PREFIX
  805. assert numpy_pickle.load(fname) == data
@without_lz4
def test_lz4_compression_without_lz4(tmpdir):
    # Check that lz4 cannot be used when dependency is not available.
    fname = tmpdir.join('test.nolz4').strpath
    data = 'test data'
    msg = LZ4_NOT_INSTALLED_ERROR
    # Requesting lz4 explicitly must raise an informative error.
    with raises(ValueError) as excinfo:
        numpy_pickle.dump(data, fname, compress='lz4')
    excinfo.match(msg)
    # The '.lz4' filename extension must be rejected the same way.
    with raises(ValueError) as excinfo:
        numpy_pickle.dump(data, fname + '.lz4')
    excinfo.match(msg)