create_numpy_pickle.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. """
  2. This script is used to generate test data for joblib/test/test_numpy_pickle.py
  3. """
  4. import sys
  5. import re
  6. # pytest needs to be able to import this module even when numpy is
  7. # not installed
  8. try:
  9. import numpy as np
  10. except ImportError:
  11. np = None
  12. import joblib
  def get_joblib_version(joblib_version=joblib.__version__):
      """Normalize a joblib version string to its numeric release part.

      The string is split on '.' and read component by component. The
      leading digits of each component are kept; reading stops at the
      first component that is not purely numeric, so that numbers which
      belong to a pre-release, dev or local suffix never end up in the
      result.

      Parameters
      ----------
      joblib_version : str
          Version string to normalize. Defaults to the version of the
          imported joblib package.

      Returns
      -------
      str
          The numeric components joined with '.'.

      >>> get_joblib_version('0.8.4')
      '0.8.4'
      >>> get_joblib_version('0.8.4b1')
      '0.8.4'
      >>> get_joblib_version('0.9.dev0')
      '0.9'
      >>> get_joblib_version('0.9.dev0+local.1')
      '0.9'
      """
      release = []
      for component in joblib_version.split('.'):
          leading_digits = re.match(r'\d+', component)
          if leading_digits is None:
              # A component such as 'dev0' starts the suffix: nothing
              # after it is part of the release number.
              break
          release.append(leading_digits.group(0))
          if leading_digits.end() < len(component):
              # A component such as '4b1' carries the suffix itself.
              break
      return '.'.join(release)
  def write_test_pickle(to_pickle, args):
      """Dump ``to_pickle`` with joblib into a file named after the setup.

      The file name is built as
      ``joblib_<joblib version><body>_pickle_py<python>_np<numpy><ext>``
      where ``<python>`` is the major and minor interpreter version,
      ``<numpy>`` the first two components of the numpy version, and
      ``<body>`` / ``<ext>`` depend on the compression options. The
      outcome (success or failure of the dump) is reported on stdout.

      Parameters
      ----------
      to_pickle : object
          Object passed to ``joblib.dump``.
      args : object
          Parsed command line options. The attributes ``compress``,
          ``method`` and ``cache_size`` are read.

      Raises
      ------
      ImportError
          If numpy is not installed (the module-level import fell back
          to ``np = None``).
      """
      if np is None:
          # The module stays importable without numpy, but the file name
          # needs the numpy version: fail with a clear message instead of
          # an AttributeError on None.
          raise ImportError("numpy is required to generate the test pickles.")

      kwargs = {}
      compress = args.compress
      method = args.method
      joblib_version = get_joblib_version()
      py_version = '{0[0]}{0[1]}'.format(sys.version_info)
      numpy_version = ''.join(np.__version__.split('.')[:2])

      # The game here is to generate the right filename according to the
      # options.
      body = '_compressed' if (compress and method == 'zlib') else ''
      if compress:
          if method == 'zlib':
              kwargs['compress'] = True
              extension = '.gz'
          else:
              kwargs['compress'] = (method, 3)
              extension = '.pkl.{}'.format(method)
          # NOTE(review): nesting reconstructed from syntax; the cache_size
          # option is only honoured together with --compress. Confirm.
          if args.cache_size:
              kwargs['cache_size'] = 0
              body += '_cache_size'
      else:
          extension = '.pkl'

      pickle_filename = 'joblib_{}{}_pickle_py{}_np{}{}'.format(
          joblib_version, body, py_version, numpy_version, extension)

      try:
          joblib.dump(to_pickle, pickle_filename, **kwargs)
      except Exception as e:
          # Deliberately best-effort: with old python versions (<= 3.3) we
          # can arrive here when dumping a compressed pickle with LzmaFile.
          print("Error: cannot generate file '{}' with arguments '{}'. "
                "Error was: {}".format(pickle_filename, kwargs, e))
      else:
          print("File '{}' generated successfully.".format(pickle_filename))
  if __name__ == '__main__':
      import argparse

      # Command line interface: three options that select which flavour
      # of pickle file gets written.
      parser = argparse.ArgumentParser(
          description="Joblib pickle data generator.")
      parser.add_argument(
          '--cache_size',
          action="store_true",
          help="Force creation of companion numpy files for pickled arrays.")
      parser.add_argument(
          '--compress',
          action="store_true",
          help="Generate compress pickles.")
      parser.add_argument(
          '--method',
          type=str,
          default='zlib',
          choices=['zlib', 'gzip', 'bz2', 'xz', 'lzma', 'lz4'],
          help="Set compression method.")

      # Dtypes are spelled out with an explicit byte order because the
      # pickles can be generated on one architecture and the tests run
      # on another one. See https://github.com/joblib/joblib/issues/279.
      to_pickle = [
          np.arange(5, dtype=np.dtype('<i8')),
          np.arange(5, dtype=np.dtype('<f8')),
          np.array([1, 'abc', {'a': 1, 'b': 2}], dtype='O'),
          # every possible byte value, as one byte string
          np.arange(256, dtype=np.uint8).tobytes(),
          np.matrix([0, 1, 2], dtype=np.dtype('<i8')),
          # unicode string containing non-ascii characters
          u"C'est l'\xe9t\xe9 !",
      ]

      # The payload is built first, the command line is parsed afterwards.
      args = parser.parse_args()
      write_test_pickle(to_pickle, args)