test_memory.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200
  1. """
  2. Test the memory module.
  3. """
  4. # Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
  5. # Copyright (c) 2009 Gael Varoquaux
  6. # License: BSD Style, 3 clauses.
  7. import gc
  8. import shutil
  9. import os
  10. import os.path
  11. import pathlib
  12. import pickle
  13. import sys
  14. import time
  15. import datetime
  16. import pytest
  17. from joblib.memory import Memory
  18. from joblib.memory import MemorizedFunc, NotMemorizedFunc
  19. from joblib.memory import MemorizedResult, NotMemorizedResult
  20. from joblib.memory import _FUNCTION_HASHES
  21. from joblib.memory import register_store_backend, _STORE_BACKENDS
  22. from joblib.memory import _build_func_identifier, _store_backend_factory
  23. from joblib.memory import JobLibCollisionWarning
  24. from joblib.parallel import Parallel, delayed
  25. from joblib._store_backends import StoreBackendBase, FileSystemStoreBackend
  26. from joblib.test.common import with_numpy, np
  27. from joblib.test.common import with_multiprocessing
  28. from joblib.testing import parametrize, raises, warns
  29. from joblib.hashing import hash
  30. ###############################################################################
  31. # Module-level variables for the tests
  32. def f(x, y=1):
  33. """ A module-level function for testing purposes.
  34. """
  35. return x ** 2 + y
  36. ###############################################################################
  37. # Helper function for the tests
  38. def check_identity_lazy(func, accumulator, location):
  39. """ Given a function and an accumulator (a list that grows every
  40. time the function is called), check that the function can be
  41. decorated by memory to be a lazy identity.
  42. """
  43. # Call each function with several arguments, and check that it is
  44. # evaluated only once per argument.
  45. memory = Memory(location=location, verbose=0)
  46. func = memory.cache(func)
  47. for i in range(3):
  48. for _ in range(2):
  49. assert func(i) == i
  50. assert len(accumulator) == i + 1
  51. def corrupt_single_cache_item(memory):
  52. single_cache_item, = memory.store_backend.get_items()
  53. output_filename = os.path.join(single_cache_item.path, 'output.pkl')
  54. with open(output_filename, 'w') as f:
  55. f.write('garbage')
  56. def monkeypatch_cached_func_warn(func, monkeypatch_fixture):
  57. # Need monkeypatch because pytest does not
  58. # capture stdlib logging output (see
  59. # https://github.com/pytest-dev/pytest/issues/2079)
  60. recorded = []
  61. def append_to_record(item):
  62. recorded.append(item)
  63. monkeypatch_fixture.setattr(func, 'warn', append_to_record)
  64. return recorded
  65. ###############################################################################
  66. # Tests
  67. def test_memory_integration(tmpdir):
  68. """ Simple test of memory lazy evaluation.
  69. """
  70. accumulator = list()
  71. # Rmk: this function has the same name than a module-level function,
  72. # thus it serves as a test to see that both are identified
  73. # as different.
  74. def f(l):
  75. accumulator.append(1)
  76. return l
  77. check_identity_lazy(f, accumulator, tmpdir.strpath)
  78. # Now test clearing
  79. for compress in (False, True):
  80. for mmap_mode in ('r', None):
  81. memory = Memory(location=tmpdir.strpath, verbose=10,
  82. mmap_mode=mmap_mode, compress=compress)
  83. # First clear the cache directory, to check that our code can
  84. # handle that
  85. # NOTE: this line would raise an exception, as the database file is
  86. # still open; we ignore the error since we want to test what
  87. # happens if the directory disappears
  88. shutil.rmtree(tmpdir.strpath, ignore_errors=True)
  89. g = memory.cache(f)
  90. g(1)
  91. g.clear(warn=False)
  92. current_accumulator = len(accumulator)
  93. out = g(1)
  94. assert len(accumulator) == current_accumulator + 1
  95. # Also, check that Memory.eval works similarly
  96. assert memory.eval(f, 1) == out
  97. assert len(accumulator) == current_accumulator + 1
  98. # Now do a smoke test with a function defined in __main__, as the name
  99. # mangling rules are more complex
  100. f.__module__ = '__main__'
  101. memory = Memory(location=tmpdir.strpath, verbose=0)
  102. memory.cache(f)(1)
  103. def test_no_memory():
  104. """ Test memory with location=None: no memoize """
  105. accumulator = list()
  106. def ff(l):
  107. accumulator.append(1)
  108. return l
  109. memory = Memory(location=None, verbose=0)
  110. gg = memory.cache(ff)
  111. for _ in range(4):
  112. current_accumulator = len(accumulator)
  113. gg(1)
  114. assert len(accumulator) == current_accumulator + 1
  115. def test_memory_kwarg(tmpdir):
  116. " Test memory with a function with keyword arguments."
  117. accumulator = list()
  118. def g(l=None, m=1):
  119. accumulator.append(1)
  120. return l
  121. check_identity_lazy(g, accumulator, tmpdir.strpath)
  122. memory = Memory(location=tmpdir.strpath, verbose=0)
  123. g = memory.cache(g)
  124. # Smoke test with an explicit keyword argument:
  125. assert g(l=30, m=2) == 30
  126. def test_memory_lambda(tmpdir):
  127. " Test memory with a function with a lambda."
  128. accumulator = list()
  129. def helper(x):
  130. """ A helper function to define l as a lambda.
  131. """
  132. accumulator.append(1)
  133. return x
  134. l = lambda x: helper(x)
  135. check_identity_lazy(l, accumulator, tmpdir.strpath)
  136. def test_memory_name_collision(tmpdir):
  137. " Check that name collisions with functions will raise warnings"
  138. memory = Memory(location=tmpdir.strpath, verbose=0)
  139. @memory.cache
  140. def name_collision(x):
  141. """ A first function called name_collision
  142. """
  143. return x
  144. a = name_collision
  145. @memory.cache
  146. def name_collision(x):
  147. """ A second function called name_collision
  148. """
  149. return x
  150. b = name_collision
  151. with warns(JobLibCollisionWarning) as warninfo:
  152. a(1)
  153. b(1)
  154. assert len(warninfo) == 1
  155. assert "collision" in str(warninfo[0].message)
  156. def test_memory_warning_lambda_collisions(tmpdir):
  157. # Check that multiple use of lambda will raise collisions
  158. memory = Memory(location=tmpdir.strpath, verbose=0)
  159. a = lambda x: x
  160. a = memory.cache(a)
  161. b = lambda x: x + 1
  162. b = memory.cache(b)
  163. with warns(JobLibCollisionWarning) as warninfo:
  164. assert a(0) == 0
  165. assert b(1) == 2
  166. assert a(1) == 1
  167. # In recent Python versions, we can retrieve the code of lambdas,
  168. # thus nothing is raised
  169. assert len(warninfo) == 4
  170. def test_memory_warning_collision_detection(tmpdir):
  171. # Check that collisions impossible to detect will raise appropriate
  172. # warnings.
  173. memory = Memory(location=tmpdir.strpath, verbose=0)
  174. a1 = eval('lambda x: x')
  175. a1 = memory.cache(a1)
  176. b1 = eval('lambda x: x+1')
  177. b1 = memory.cache(b1)
  178. with warns(JobLibCollisionWarning) as warninfo:
  179. a1(1)
  180. b1(1)
  181. a1(0)
  182. assert len(warninfo) == 2
  183. assert "cannot detect" in str(warninfo[0].message).lower()
  184. def test_memory_partial(tmpdir):
  185. " Test memory with functools.partial."
  186. accumulator = list()
  187. def func(x, y):
  188. """ A helper function to define l as a lambda.
  189. """
  190. accumulator.append(1)
  191. return y
  192. import functools
  193. function = functools.partial(func, 1)
  194. check_identity_lazy(function, accumulator, tmpdir.strpath)
  195. def test_memory_eval(tmpdir):
  196. " Smoke test memory with a function with a function defined in an eval."
  197. memory = Memory(location=tmpdir.strpath, verbose=0)
  198. m = eval('lambda x: x')
  199. mm = memory.cache(m)
  200. assert mm(1) == 1
  201. def count_and_append(x=[]):
  202. """ A function with a side effect in its arguments.
  203. Return the lenght of its argument and append one element.
  204. """
  205. len_x = len(x)
  206. x.append(None)
  207. return len_x
  208. def test_argument_change(tmpdir):
  209. """ Check that if a function has a side effect in its arguments, it
  210. should use the hash of changing arguments.
  211. """
  212. memory = Memory(location=tmpdir.strpath, verbose=0)
  213. func = memory.cache(count_and_append)
  214. # call the function for the first time, is should cache it with
  215. # argument x=[]
  216. assert func() == 0
  217. # the second time the argument is x=[None], which is not cached
  218. # yet, so the functions should be called a second time
  219. assert func() == 1
  220. @with_numpy
  221. @parametrize('mmap_mode', [None, 'r'])
  222. def test_memory_numpy(tmpdir, mmap_mode):
  223. " Test memory with a function with numpy arrays."
  224. accumulator = list()
  225. def n(l=None):
  226. accumulator.append(1)
  227. return l
  228. memory = Memory(location=tmpdir.strpath, mmap_mode=mmap_mode,
  229. verbose=0)
  230. cached_n = memory.cache(n)
  231. rnd = np.random.RandomState(0)
  232. for i in range(3):
  233. a = rnd.random_sample((10, 10))
  234. for _ in range(3):
  235. assert np.all(cached_n(a) == a)
  236. assert len(accumulator) == i + 1
  237. @with_numpy
  238. def test_memory_numpy_check_mmap_mode(tmpdir, monkeypatch):
  239. """Check that mmap_mode is respected even at the first call"""
  240. memory = Memory(location=tmpdir.strpath, mmap_mode='r', verbose=0)
  241. @memory.cache()
  242. def twice(a):
  243. return a * 2
  244. a = np.ones(3)
  245. b = twice(a)
  246. c = twice(a)
  247. assert isinstance(c, np.memmap)
  248. assert c.mode == 'r'
  249. assert isinstance(b, np.memmap)
  250. assert b.mode == 'r'
  251. # Corrupts the file, Deleting b and c mmaps
  252. # is necessary to be able edit the file
  253. del b
  254. del c
  255. gc.collect()
  256. corrupt_single_cache_item(memory)
  257. # Make sure that corrupting the file causes recomputation and that
  258. # a warning is issued.
  259. recorded_warnings = monkeypatch_cached_func_warn(twice, monkeypatch)
  260. d = twice(a)
  261. assert len(recorded_warnings) == 1
  262. exception_msg = 'Exception while loading results'
  263. assert exception_msg in recorded_warnings[0]
  264. # Asserts that the recomputation returns a mmap
  265. assert isinstance(d, np.memmap)
  266. assert d.mode == 'r'
  267. def test_memory_exception(tmpdir):
  268. """ Smoketest the exception handling of Memory.
  269. """
  270. memory = Memory(location=tmpdir.strpath, verbose=0)
  271. class MyException(Exception):
  272. pass
  273. @memory.cache
  274. def h(exc=0):
  275. if exc:
  276. raise MyException
  277. # Call once, to initialise the cache
  278. h()
  279. for _ in range(3):
  280. # Call 3 times, to be sure that the Exception is always raised
  281. with raises(MyException):
  282. h(1)
  283. def test_memory_ignore(tmpdir):
  284. " Test the ignore feature of memory "
  285. memory = Memory(location=tmpdir.strpath, verbose=0)
  286. accumulator = list()
  287. @memory.cache(ignore=['y'])
  288. def z(x, y=1):
  289. accumulator.append(1)
  290. assert z.ignore == ['y']
  291. z(0, y=1)
  292. assert len(accumulator) == 1
  293. z(0, y=1)
  294. assert len(accumulator) == 1
  295. z(0, y=2)
  296. assert len(accumulator) == 1
  297. def test_memory_args_as_kwargs(tmpdir):
  298. """Non-regression test against 0.12.0 changes.
  299. https://github.com/joblib/joblib/pull/751
  300. """
  301. memory = Memory(location=tmpdir.strpath, verbose=0)
  302. @memory.cache
  303. def plus_one(a):
  304. return a + 1
  305. # It's possible to call a positional arg as a kwarg.
  306. assert plus_one(1) == 2
  307. assert plus_one(a=1) == 2
  308. # However, a positional argument that joblib hadn't seen
  309. # before would cause a failure if it was passed as a kwarg.
  310. assert plus_one(a=2) == 3
  311. @parametrize('ignore, verbose, mmap_mode', [(['x'], 100, 'r'),
  312. ([], 10, None)])
  313. def test_partial_decoration(tmpdir, ignore, verbose, mmap_mode):
  314. "Check cache may be called with kwargs before decorating"
  315. memory = Memory(location=tmpdir.strpath, verbose=0)
  316. @memory.cache(ignore=ignore, verbose=verbose, mmap_mode=mmap_mode)
  317. def z(x):
  318. pass
  319. assert z.ignore == ignore
  320. assert z._verbose == verbose
  321. assert z.mmap_mode == mmap_mode
  322. def test_func_dir(tmpdir):
  323. # Test the creation of the memory cache directory for the function.
  324. memory = Memory(location=tmpdir.strpath, verbose=0)
  325. path = __name__.split('.')
  326. path.append('f')
  327. path = tmpdir.join('joblib', *path).strpath
  328. g = memory.cache(f)
  329. # Test that the function directory is created on demand
  330. func_id = _build_func_identifier(f)
  331. location = os.path.join(g.store_backend.location, func_id)
  332. assert location == path
  333. assert os.path.exists(path)
  334. assert memory.location == os.path.dirname(g.store_backend.location)
  335. with warns(DeprecationWarning) as w:
  336. assert memory.cachedir == g.store_backend.location
  337. assert len(w) == 1
  338. assert "The 'cachedir' attribute has been deprecated" in str(w[-1].message)
  339. # Test that the code is stored.
  340. # For the following test to be robust to previous execution, we clear
  341. # the in-memory store
  342. _FUNCTION_HASHES.clear()
  343. assert not g._check_previous_func_code()
  344. assert os.path.exists(os.path.join(path, 'func_code.py'))
  345. assert g._check_previous_func_code()
  346. # Test the robustness to failure of loading previous results.
  347. func_id, args_id = g._get_output_identifiers(1)
  348. output_dir = os.path.join(g.store_backend.location, func_id, args_id)
  349. a = g(1)
  350. assert os.path.exists(output_dir)
  351. os.remove(os.path.join(output_dir, 'output.pkl'))
  352. assert a == g(1)
  353. def test_persistence(tmpdir):
  354. # Test the memorized functions can be pickled and restored.
  355. memory = Memory(location=tmpdir.strpath, verbose=0)
  356. g = memory.cache(f)
  357. output = g(1)
  358. h = pickle.loads(pickle.dumps(g))
  359. func_id, args_id = h._get_output_identifiers(1)
  360. output_dir = os.path.join(h.store_backend.location, func_id, args_id)
  361. assert os.path.exists(output_dir)
  362. assert output == h.store_backend.load_item([func_id, args_id])
  363. memory2 = pickle.loads(pickle.dumps(memory))
  364. assert memory.store_backend.location == memory2.store_backend.location
  365. # Smoke test that pickling a memory with location=None works
  366. memory = Memory(location=None, verbose=0)
  367. pickle.loads(pickle.dumps(memory))
  368. g = memory.cache(f)
  369. gp = pickle.loads(pickle.dumps(g))
  370. gp(1)
  371. def test_call_and_shelve(tmpdir):
  372. # Test MemorizedFunc outputting a reference to cache.
  373. for func, Result in zip((MemorizedFunc(f, tmpdir.strpath),
  374. NotMemorizedFunc(f),
  375. Memory(location=tmpdir.strpath,
  376. verbose=0).cache(f),
  377. Memory(location=None).cache(f),
  378. ),
  379. (MemorizedResult, NotMemorizedResult,
  380. MemorizedResult, NotMemorizedResult)):
  381. assert func(2) == 5
  382. result = func.call_and_shelve(2)
  383. assert isinstance(result, Result)
  384. assert result.get() == 5
  385. result.clear()
  386. with raises(KeyError):
  387. result.get()
  388. result.clear() # Do nothing if there is no cache.
  389. def test_call_and_shelve_argument_hash(tmpdir):
  390. # Verify that a warning is raised when accessing arguments_hash
  391. # attribute from MemorizedResult
  392. func = Memory(location=tmpdir.strpath, verbose=0).cache(f)
  393. result = func.call_and_shelve(2)
  394. assert isinstance(result, MemorizedResult)
  395. with warns(DeprecationWarning) as w:
  396. assert result.argument_hash == result.args_id
  397. assert len(w) == 1
  398. assert "The 'argument_hash' attribute has been deprecated" \
  399. in str(w[-1].message)
  400. def test_call_and_shelve_lazily_load_stored_result(tmpdir):
  401. """Check call_and_shelve only load stored data if needed."""
  402. test_access_time_file = tmpdir.join('test_access')
  403. test_access_time_file.write('test_access')
  404. test_access_time = os.stat(test_access_time_file.strpath).st_atime
  405. # check file system access time stats resolution is lower than test wait
  406. # timings.
  407. time.sleep(0.5)
  408. assert test_access_time_file.read() == 'test_access'
  409. if test_access_time == os.stat(test_access_time_file.strpath).st_atime:
  410. # Skip this test when access time cannot be retrieved with enough
  411. # precision from the file system (e.g. NTFS on windows).
  412. pytest.skip("filesystem does not support fine-grained access time "
  413. "attribute")
  414. memory = Memory(location=tmpdir.strpath, verbose=0)
  415. func = memory.cache(f)
  416. func_id, argument_hash = func._get_output_identifiers(2)
  417. result_path = os.path.join(memory.store_backend.location,
  418. func_id, argument_hash, 'output.pkl')
  419. assert func(2) == 5
  420. first_access_time = os.stat(result_path).st_atime
  421. time.sleep(1)
  422. # Should not access the stored data
  423. result = func.call_and_shelve(2)
  424. assert isinstance(result, MemorizedResult)
  425. assert os.stat(result_path).st_atime == first_access_time
  426. time.sleep(1)
  427. # Read the stored data => last access time is greater than first_access
  428. assert result.get() == 5
  429. assert os.stat(result_path).st_atime > first_access_time
  430. def test_memorized_pickling(tmpdir):
  431. for func in (MemorizedFunc(f, tmpdir.strpath), NotMemorizedFunc(f)):
  432. filename = tmpdir.join('pickling_test.dat').strpath
  433. result = func.call_and_shelve(2)
  434. with open(filename, 'wb') as fp:
  435. pickle.dump(result, fp)
  436. with open(filename, 'rb') as fp:
  437. result2 = pickle.load(fp)
  438. assert result2.get() == result.get()
  439. os.remove(filename)
  440. def test_memorized_repr(tmpdir):
  441. func = MemorizedFunc(f, tmpdir.strpath)
  442. result = func.call_and_shelve(2)
  443. func2 = MemorizedFunc(f, tmpdir.strpath)
  444. result2 = func2.call_and_shelve(2)
  445. assert result.get() == result2.get()
  446. assert repr(func) == repr(func2)
  447. # Smoke test with NotMemorizedFunc
  448. func = NotMemorizedFunc(f)
  449. repr(func)
  450. repr(func.call_and_shelve(2))
  451. # Smoke test for message output (increase code coverage)
  452. func = MemorizedFunc(f, tmpdir.strpath, verbose=11, timestamp=time.time())
  453. result = func.call_and_shelve(11)
  454. result.get()
  455. func = MemorizedFunc(f, tmpdir.strpath, verbose=11)
  456. result = func.call_and_shelve(11)
  457. result.get()
  458. func = MemorizedFunc(f, tmpdir.strpath, verbose=5, timestamp=time.time())
  459. result = func.call_and_shelve(11)
  460. result.get()
  461. func = MemorizedFunc(f, tmpdir.strpath, verbose=5)
  462. result = func.call_and_shelve(11)
  463. result.get()
  464. def test_memory_file_modification(capsys, tmpdir, monkeypatch):
  465. # Test that modifying a Python file after loading it does not lead to
  466. # Recomputation
  467. dir_name = tmpdir.mkdir('tmp_import').strpath
  468. filename = os.path.join(dir_name, 'tmp_joblib_.py')
  469. content = 'def f(x):\n print(x)\n return x\n'
  470. with open(filename, 'w') as module_file:
  471. module_file.write(content)
  472. # Load the module:
  473. monkeypatch.syspath_prepend(dir_name)
  474. import tmp_joblib_ as tmp
  475. memory = Memory(location=tmpdir.strpath, verbose=0)
  476. f = memory.cache(tmp.f)
  477. # First call f a few times
  478. f(1)
  479. f(2)
  480. f(1)
  481. # Now modify the module where f is stored without modifying f
  482. with open(filename, 'w') as module_file:
  483. module_file.write('\n\n' + content)
  484. # And call f a couple more times
  485. f(1)
  486. f(1)
  487. # Flush the .pyc files
  488. shutil.rmtree(dir_name)
  489. os.mkdir(dir_name)
  490. # Now modify the module where f is stored, modifying f
  491. content = 'def f(x):\n print("x=%s" % x)\n return x\n'
  492. with open(filename, 'w') as module_file:
  493. module_file.write(content)
  494. # And call f more times prior to reloading: the cache should not be
  495. # invalidated at this point as the active function definition has not
  496. # changed in memory yet.
  497. f(1)
  498. f(1)
  499. # Now reload
  500. sys.stdout.write('Reloading\n')
  501. sys.modules.pop('tmp_joblib_')
  502. import tmp_joblib_ as tmp
  503. f = memory.cache(tmp.f)
  504. # And call f more times
  505. f(1)
  506. f(1)
  507. out, err = capsys.readouterr()
  508. assert out == '1\n2\nReloading\nx=1\n'
  509. def _function_to_cache(a, b):
  510. # Just a place holder function to be mutated by tests
  511. pass
  512. def _sum(a, b):
  513. return a + b
  514. def _product(a, b):
  515. return a * b
  516. def test_memory_in_memory_function_code_change(tmpdir):
  517. _function_to_cache.__code__ = _sum.__code__
  518. memory = Memory(location=tmpdir.strpath, verbose=0)
  519. f = memory.cache(_function_to_cache)
  520. assert f(1, 2) == 3
  521. assert f(1, 2) == 3
  522. with warns(JobLibCollisionWarning):
  523. # Check that inline function modification triggers a cache invalidation
  524. _function_to_cache.__code__ = _product.__code__
  525. assert f(1, 2) == 2
  526. assert f(1, 2) == 2
  527. def test_clear_memory_with_none_location():
  528. memory = Memory(location=None)
  529. memory.clear()
  530. def func_with_kwonly_args(a, b, *, kw1='kw1', kw2='kw2'):
  531. return a, b, kw1, kw2
  532. def func_with_signature(a: int, b: float) -> float:
  533. return a + b
  534. def test_memory_func_with_kwonly_args(tmpdir):
  535. memory = Memory(location=tmpdir.strpath, verbose=0)
  536. func_cached = memory.cache(func_with_kwonly_args)
  537. assert func_cached(1, 2, kw1=3) == (1, 2, 3, 'kw2')
  538. # Making sure that providing a keyword-only argument by
  539. # position raises an exception
  540. with raises(ValueError) as excinfo:
  541. func_cached(1, 2, 3, kw2=4)
  542. excinfo.match("Keyword-only parameter 'kw1' was passed as positional "
  543. "parameter")
  544. # Keyword-only parameter passed by position with cached call
  545. # should still raise ValueError
  546. func_cached(1, 2, kw1=3, kw2=4)
  547. with raises(ValueError) as excinfo:
  548. func_cached(1, 2, 3, kw2=4)
  549. excinfo.match("Keyword-only parameter 'kw1' was passed as positional "
  550. "parameter")
  551. # Test 'ignore' parameter
  552. func_cached = memory.cache(func_with_kwonly_args, ignore=['kw2'])
  553. assert func_cached(1, 2, kw1=3, kw2=4) == (1, 2, 3, 4)
  554. assert func_cached(1, 2, kw1=3, kw2='ignored') == (1, 2, 3, 4)
  555. def test_memory_func_with_signature(tmpdir):
  556. memory = Memory(location=tmpdir.strpath, verbose=0)
  557. func_cached = memory.cache(func_with_signature)
  558. assert func_cached(1, 2.) == 3.
  559. def _setup_toy_cache(tmpdir, num_inputs=10):
  560. memory = Memory(location=tmpdir.strpath, verbose=0)
  561. @memory.cache()
  562. def get_1000_bytes(arg):
  563. return 'a' * 1000
  564. inputs = list(range(num_inputs))
  565. for arg in inputs:
  566. get_1000_bytes(arg)
  567. func_id = _build_func_identifier(get_1000_bytes)
  568. hash_dirnames = [get_1000_bytes._get_output_identifiers(arg)[1]
  569. for arg in inputs]
  570. full_hashdirs = [os.path.join(get_1000_bytes.store_backend.location,
  571. func_id, dirname)
  572. for dirname in hash_dirnames]
  573. return memory, full_hashdirs, get_1000_bytes
  574. def test__get_items(tmpdir):
  575. memory, expected_hash_dirs, _ = _setup_toy_cache(tmpdir)
  576. items = memory.store_backend.get_items()
  577. hash_dirs = [ci.path for ci in items]
  578. assert set(hash_dirs) == set(expected_hash_dirs)
  579. def get_files_size(directory):
  580. full_paths = [os.path.join(directory, fn)
  581. for fn in os.listdir(directory)]
  582. return sum(os.path.getsize(fp) for fp in full_paths)
  583. expected_hash_cache_sizes = [get_files_size(hash_dir)
  584. for hash_dir in hash_dirs]
  585. hash_cache_sizes = [ci.size for ci in items]
  586. assert hash_cache_sizes == expected_hash_cache_sizes
  587. output_filenames = [os.path.join(hash_dir, 'output.pkl')
  588. for hash_dir in hash_dirs]
  589. expected_last_accesses = [
  590. datetime.datetime.fromtimestamp(os.path.getatime(fn))
  591. for fn in output_filenames]
  592. last_accesses = [ci.last_access for ci in items]
  593. assert last_accesses == expected_last_accesses
  594. def test__get_items_to_delete(tmpdir):
  595. memory, expected_hash_cachedirs, _ = _setup_toy_cache(tmpdir)
  596. items = memory.store_backend.get_items()
  597. # bytes_limit set to keep only one cache item (each hash cache
  598. # folder is about 1000 bytes + metadata)
  599. items_to_delete = memory.store_backend._get_items_to_delete('2K')
  600. nb_hashes = len(expected_hash_cachedirs)
  601. assert set.issubset(set(items_to_delete), set(items))
  602. assert len(items_to_delete) == nb_hashes - 1
  603. # Sanity check bytes_limit=2048 is the same as bytes_limit='2K'
  604. items_to_delete_2048b = memory.store_backend._get_items_to_delete(2048)
  605. assert sorted(items_to_delete) == sorted(items_to_delete_2048b)
  606. # bytes_limit greater than the size of the cache
  607. items_to_delete_empty = memory.store_backend._get_items_to_delete('1M')
  608. assert items_to_delete_empty == []
  609. # All the cache items need to be deleted
  610. bytes_limit_too_small = 500
  611. items_to_delete_500b = memory.store_backend._get_items_to_delete(
  612. bytes_limit_too_small)
  613. assert set(items_to_delete_500b), set(items)
  614. # Test LRU property: surviving cache items should all have a more
  615. # recent last_access that the ones that have been deleted
  616. items_to_delete_6000b = memory.store_backend._get_items_to_delete(6000)
  617. surviving_items = set(items).difference(items_to_delete_6000b)
  618. assert (max(ci.last_access for ci in items_to_delete_6000b) <=
  619. min(ci.last_access for ci in surviving_items))
  620. def test_memory_reduce_size(tmpdir):
  621. memory, _, _ = _setup_toy_cache(tmpdir)
  622. ref_cache_items = memory.store_backend.get_items()
  623. # By default memory.bytes_limit is None and reduce_size is a noop
  624. memory.reduce_size()
  625. cache_items = memory.store_backend.get_items()
  626. assert sorted(ref_cache_items) == sorted(cache_items)
  627. # No cache items deleted if bytes_limit greater than the size of
  628. # the cache
  629. memory.bytes_limit = '1M'
  630. memory.reduce_size()
  631. cache_items = memory.store_backend.get_items()
  632. assert sorted(ref_cache_items) == sorted(cache_items)
  633. # bytes_limit is set so that only two cache items are kept
  634. memory.bytes_limit = '3K'
  635. memory.reduce_size()
  636. cache_items = memory.store_backend.get_items()
  637. assert set.issubset(set(cache_items), set(ref_cache_items))
  638. assert len(cache_items) == 2
  639. # bytes_limit set so that no cache item is kept
  640. bytes_limit_too_small = 500
  641. memory.bytes_limit = bytes_limit_too_small
  642. memory.reduce_size()
  643. cache_items = memory.store_backend.get_items()
  644. assert cache_items == []
  645. def test_memory_clear(tmpdir):
  646. memory, _, _ = _setup_toy_cache(tmpdir)
  647. memory.clear()
  648. assert os.listdir(memory.store_backend.location) == []
  649. def fast_func_with_complex_output():
  650. complex_obj = ['a' * 1000] * 1000
  651. return complex_obj
  652. def fast_func_with_conditional_complex_output(complex_output=True):
  653. complex_obj = {str(i): i for i in range(int(1e5))}
  654. return complex_obj if complex_output else 'simple output'
  655. @with_multiprocessing
  656. def test_cached_function_race_condition_when_persisting_output(tmpdir, capfd):
  657. # Test race condition where multiple processes are writing into
  658. # the same output.pkl. See
  659. # https://github.com/joblib/joblib/issues/490 for more details.
  660. memory = Memory(location=tmpdir.strpath)
  661. func_cached = memory.cache(fast_func_with_complex_output)
  662. Parallel(n_jobs=2)(delayed(func_cached)() for i in range(3))
  663. stdout, stderr = capfd.readouterr()
  664. # Checking both stdout and stderr (ongoing PR #434 may change
  665. # logging destination) to make sure there is no exception while
  666. # loading the results
  667. exception_msg = 'Exception while loading results'
  668. assert exception_msg not in stdout
  669. assert exception_msg not in stderr
  670. @with_multiprocessing
  671. def test_cached_function_race_condition_when_persisting_output_2(tmpdir,
  672. capfd):
  673. # Test race condition in first attempt at solving
  674. # https://github.com/joblib/joblib/issues/490. The race condition
  675. # was due to the delay between seeing the cache directory created
  676. # (interpreted as the result being cached) and the output.pkl being
  677. # pickled.
  678. memory = Memory(location=tmpdir.strpath)
  679. func_cached = memory.cache(fast_func_with_conditional_complex_output)
  680. Parallel(n_jobs=2)(delayed(func_cached)(True if i % 2 == 0 else False)
  681. for i in range(3))
  682. stdout, stderr = capfd.readouterr()
  683. # Checking both stdout and stderr (ongoing PR #434 may change
  684. # logging destination) to make sure there is no exception while
  685. # loading the results
  686. exception_msg = 'Exception while loading results'
  687. assert exception_msg not in stdout
  688. assert exception_msg not in stderr
  689. def test_memory_recomputes_after_an_error_while_loading_results(
  690. tmpdir, monkeypatch):
  691. memory = Memory(location=tmpdir.strpath)
  692. def func(arg):
  693. # This makes sure that the timestamp returned by two calls of
  694. # func are different. This is needed on Windows where
  695. # time.time resolution may not be accurate enough
  696. time.sleep(0.01)
  697. return arg, time.time()
  698. cached_func = memory.cache(func)
  699. input_arg = 'arg'
  700. arg, timestamp = cached_func(input_arg)
  701. # Make sure the function is correctly cached
  702. assert arg == input_arg
  703. # Corrupting output.pkl to make sure that an error happens when
  704. # loading the cached result
  705. corrupt_single_cache_item(memory)
  706. # Make sure that corrupting the file causes recomputation and that
  707. # a warning is issued.
  708. recorded_warnings = monkeypatch_cached_func_warn(cached_func, monkeypatch)
  709. recomputed_arg, recomputed_timestamp = cached_func(arg)
  710. assert len(recorded_warnings) == 1
  711. exception_msg = 'Exception while loading results'
  712. assert exception_msg in recorded_warnings[0]
  713. assert recomputed_arg == arg
  714. assert recomputed_timestamp > timestamp
  715. # Corrupting output.pkl to make sure that an error happens when
  716. # loading the cached result
  717. corrupt_single_cache_item(memory)
  718. reference = cached_func.call_and_shelve(arg)
  719. try:
  720. reference.get()
  721. raise AssertionError(
  722. "It normally not possible to load a corrupted"
  723. " MemorizedResult"
  724. )
  725. except KeyError as e:
  726. message = "is corrupted"
  727. assert message in str(e.args)
  728. def test_deprecated_cachedir_behaviour(tmpdir):
  729. # verify the right deprecation warnings are raised when using cachedir
  730. # option instead of new location parameter.
  731. with warns(None) as w:
  732. memory = Memory(cachedir=tmpdir.strpath, verbose=0)
  733. assert memory.store_backend.location.startswith(tmpdir.strpath)
  734. assert len(w) == 1
  735. assert "The 'cachedir' parameter has been deprecated" in str(w[-1].message)
  736. with warns(None) as w:
  737. memory = Memory()
  738. assert memory.cachedir is None
  739. assert len(w) == 1
  740. assert "The 'cachedir' attribute has been deprecated" in str(w[-1].message)
  741. error_regex = """You set both "location='.+ and "cachedir='.+"""
  742. with raises(ValueError, match=error_regex):
  743. memory = Memory(location=tmpdir.strpath, cachedir=tmpdir.strpath,
  744. verbose=0)
  745. class IncompleteStoreBackend(StoreBackendBase):
  746. """This backend cannot be instanciated and should raise a TypeError."""
  747. pass
  748. class DummyStoreBackend(StoreBackendBase):
  749. """A dummy store backend that does nothing."""
  750. def _open_item(self, *args, **kwargs):
  751. """Open an item on store."""
  752. "Does nothing"
  753. def _item_exists(self, location):
  754. """Check if an item location exists."""
  755. "Does nothing"
  756. def _move_item(self, src, dst):
  757. """Move an item from src to dst in store."""
  758. "Does nothing"
  759. def create_location(self, location):
  760. """Create location on store."""
  761. "Does nothing"
  762. def exists(self, obj):
  763. """Check if an object exists in the store"""
  764. return False
  765. def clear_location(self, obj):
  766. """Clear object on store"""
  767. "Does nothing"
  768. def get_items(self):
  769. """Returns the whole list of items available in cache."""
  770. return []
  771. def configure(self, location, *args, **kwargs):
  772. """Configure the store"""
  773. "Does nothing"
  774. @parametrize("invalid_prefix", [None, dict(), list()])
  775. def test_register_invalid_store_backends_key(invalid_prefix):
  776. # verify the right exceptions are raised when passing a wrong backend key.
  777. with raises(ValueError) as excinfo:
  778. register_store_backend(invalid_prefix, None)
  779. excinfo.match(r'Store backend name should be a string*')
  780. def test_register_invalid_store_backends_object():
  781. # verify the right exceptions are raised when passing a wrong backend
  782. # object.
  783. with raises(ValueError) as excinfo:
  784. register_store_backend("fs", None)
  785. excinfo.match(r'Store backend should inherit StoreBackendBase*')
  786. def test_memory_default_store_backend():
  787. # test an unknow backend falls back into a FileSystemStoreBackend
  788. with raises(TypeError) as excinfo:
  789. Memory(location='/tmp/joblib', backend='unknown')
  790. excinfo.match(r"Unknown location*")
  791. def test_warning_on_unknown_location_type():
  792. class NonSupportedLocationClass:
  793. pass
  794. unsupported_location = NonSupportedLocationClass()
  795. with warns(UserWarning) as warninfo:
  796. _store_backend_factory("local", location=unsupported_location)
  797. expected_mesage = ("Instanciating a backend using a "
  798. "NonSupportedLocationClass as a location is not "
  799. "supported by joblib")
  800. assert expected_mesage in str(warninfo[0].message)
  801. def test_instanciate_incomplete_store_backend():
  802. # Verify that registering an external incomplete store backend raises an
  803. # exception when one tries to instanciate it.
  804. backend_name = "isb"
  805. register_store_backend(backend_name, IncompleteStoreBackend)
  806. assert (backend_name, IncompleteStoreBackend) in _STORE_BACKENDS.items()
  807. with raises(TypeError) as excinfo:
  808. _store_backend_factory(backend_name, "fake_location")
  809. excinfo.match(r"Can't instantiate abstract class "
  810. "IncompleteStoreBackend with abstract methods*")
  811. def test_dummy_store_backend():
  812. # Verify that registering an external store backend works.
  813. backend_name = "dsb"
  814. register_store_backend(backend_name, DummyStoreBackend)
  815. assert (backend_name, DummyStoreBackend) in _STORE_BACKENDS.items()
  816. backend_obj = _store_backend_factory(backend_name, "dummy_location")
  817. assert isinstance(backend_obj, DummyStoreBackend)
  818. def test_instanciate_store_backend_with_pathlib_path():
  819. # Instanciate a FileSystemStoreBackend using a pathlib.Path object
  820. path = pathlib.Path("some_folder")
  821. backend_obj = _store_backend_factory("local", path)
  822. assert backend_obj.location == "some_folder"
  823. def test_filesystem_store_backend_repr(tmpdir):
  824. # Verify string representation of a filesystem store backend.
  825. repr_pattern = 'FileSystemStoreBackend(location="{location}")'
  826. backend = FileSystemStoreBackend()
  827. assert backend.location is None
  828. repr(backend) # Should not raise an exception
  829. assert str(backend) == repr_pattern.format(location=None)
  830. # backend location is passed explicitely via the configure method (called
  831. # by the internal _store_backend_factory function)
  832. backend.configure(tmpdir.strpath)
  833. assert str(backend) == repr_pattern.format(location=tmpdir.strpath)
  834. repr(backend) # Should not raise an exception
  835. def test_memory_objects_repr(tmpdir):
  836. # Verify printable reprs of MemorizedResult, MemorizedFunc and Memory.
  837. def my_func(a, b):
  838. return a + b
  839. memory = Memory(location=tmpdir.strpath, verbose=0)
  840. memorized_func = memory.cache(my_func)
  841. memorized_func_repr = 'MemorizedFunc(func={func}, location={location})'
  842. assert str(memorized_func) == memorized_func_repr.format(
  843. func=my_func,
  844. location=memory.store_backend.location)
  845. memorized_result = memorized_func.call_and_shelve(42, 42)
  846. memorized_result_repr = ('MemorizedResult(location="{location}", '
  847. 'func="{func}", args_id="{args_id}")')
  848. assert str(memorized_result) == memorized_result_repr.format(
  849. location=memory.store_backend.location,
  850. func=memorized_result.func_id,
  851. args_id=memorized_result.args_id)
  852. assert str(memory) == 'Memory(location={location})'.format(
  853. location=memory.store_backend.location)
  854. def test_memorized_result_pickle(tmpdir):
  855. # Verify a MemoryResult object can be pickled/depickled. Non regression
  856. # test introduced following issue
  857. # https://github.com/joblib/joblib/issues/747
  858. memory = Memory(location=tmpdir.strpath)
  859. @memory.cache
  860. def g(x):
  861. return x**2
  862. memorized_result = g.call_and_shelve(4)
  863. memorized_result_pickle = pickle.dumps(memorized_result)
  864. memorized_result_loads = pickle.loads(memorized_result_pickle)
  865. assert memorized_result.store_backend.location == \
  866. memorized_result_loads.store_backend.location
  867. assert memorized_result.func == memorized_result_loads.func
  868. assert memorized_result.args_id == memorized_result_loads.args_id
  869. assert str(memorized_result) == str(memorized_result_loads)
  870. def compare(left, right, ignored_attrs=None):
  871. if ignored_attrs is None:
  872. ignored_attrs = []
  873. left_vars = vars(left)
  874. right_vars = vars(right)
  875. assert set(left_vars.keys()) == set(right_vars.keys())
  876. for attr in left_vars.keys():
  877. if attr in ignored_attrs:
  878. continue
  879. assert left_vars[attr] == right_vars[attr]
  880. @pytest.mark.parametrize('memory_kwargs',
  881. [{'compress': 3, 'verbose': 2},
  882. {'mmap_mode': 'r', 'verbose': 5, 'bytes_limit': 1e6,
  883. 'backend_options': {'parameter': 'unused'}}])
  884. def test_memory_pickle_dump_load(tmpdir, memory_kwargs):
  885. memory = Memory(location=tmpdir.strpath, **memory_kwargs)
  886. memory_reloaded = pickle.loads(pickle.dumps(memory))
  887. # Compare Memory instance before and after pickle roundtrip
  888. compare(memory.store_backend, memory_reloaded.store_backend)
  889. compare(memory, memory_reloaded,
  890. ignored_attrs=set(['store_backend', 'timestamp']))
  891. assert hash(memory) == hash(memory_reloaded)
  892. func_cached = memory.cache(f)
  893. func_cached_reloaded = pickle.loads(pickle.dumps(func_cached))
  894. # Compare MemorizedFunc instance before/after pickle roundtrip
  895. compare(func_cached.store_backend, func_cached_reloaded.store_backend)
  896. compare(func_cached, func_cached_reloaded,
  897. ignored_attrs=set(['store_backend', 'timestamp']))
  898. assert hash(func_cached) == hash(func_cached_reloaded)
  899. # Compare MemorizedResult instance before/after pickle roundtrip
  900. memorized_result = func_cached.call_and_shelve(1)
  901. memorized_result_reloaded = pickle.loads(pickle.dumps(memorized_result))
  902. compare(memorized_result.store_backend,
  903. memorized_result_reloaded.store_backend)
  904. compare(memorized_result, memorized_result_reloaded,
  905. ignored_attrs=set(['store_backend', 'timestamp']))
  906. assert hash(memorized_result) == hash(memorized_result_reloaded)