test_dask.py

from __future__ import print_function, division, absolute_import

import os
import pytest
from random import random
from uuid import uuid4
from time import sleep

from .. import Parallel, delayed, parallel_backend
from ..parallel import ThreadingBackend, AutoBatchingMixin
from .._dask import DaskDistributedBackend

distributed = pytest.importorskip('distributed')
from distributed import Client, LocalCluster, get_client
from distributed.metrics import time
from distributed.utils_test import cluster, inc


def noop(*args, **kwargs):
    pass


def slow_raise_value_error(condition, duration=0.05):
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")


def count_events(event_name, client):
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    event_counts = {}
    for w, events in worker_events.items():
        event_counts[w] = len([event for event in list(events)
                               if event[1] == event_name])
    return event_counts
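
# Illustrative sketch of what count_events returns: client.run() maps each
# worker address to the result of the callable, so the helper yields a dict
# such as {'tcp://127.0.0.1:40123': 2, 'tcp://127.0.0.1:40124': 0} (addresses
# hypothetical), counting how often ``event_name`` appears in each worker log.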


def test_simple(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                with pytest.raises(ValueError):
                    Parallel()(delayed(slow_raise_value_error)(i == 3)
                               for i in range(10))

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    assert (DaskDistributedBackend.compute_batch_size
            is AutoBatchingMixin.compute_batch_size)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(
                        delayed(lambda: None)()
                        for _ in range(int(1e4))
                    )
                    assert backend._effective_batch_size > 10
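
# Note on the assertion above: AutoBatchingMixin grows the effective batch
# size when individual tasks complete much faster than the dispatch overhead,
# so 10,000 near-instant tasks should end with a batch size well above 1. The
# exact final value depends on timing, hence the loose ``> 10`` bound.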


def random2():
    return random()


def test_dont_assume_function_purity(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y


@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    from joblib._dask import Batch
    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = 'batch_of_inc_4_calls'
    else:
        tasks = [
            delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)
        ]
        batch_repr = 'mixed_batch_of_inc_4_calls'

    assert repr(Batch(tasks)) == batch_repr

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                _ = Parallel(batch_size=2, pre_dispatch='all')(tasks)

            def f(dask_scheduler):
                return list(dask_scheduler.transition_log)

            batch_repr = batch_repr.replace('4', '2')
            log = client.run_on_scheduler(f)
            assert all('batch_of_inc' in tup[0] for tup in log)
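
# Illustrative only: the Batch repr encodes the wrapped function name and the
# number of calls, so e.g. repr(Batch([delayed(inc)(0), delayed(inc)(1)]))
# would read 'batch_of_inc_2_calls' under the naming scheme asserted above.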


def test_no_undesired_distributed_cache_hit(loop):
    # Dask has a pickle cache for callables that are called many times.
    # Because the dask backend used to wrap both the function and the
    # arguments in instances of the Batch callable class, this caching
    # mechanism could lead to bugs as described in:
    # https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has since been refactored to stop bundling the
    # arguments as an attribute of the Batch instance, which avoids the
    # problem. This test serves as a non-regression check.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip('numpy')
    X = np.arange(int(1e6))

    def isolated_operation(list_, X=None):
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_backend('dask') as (ba, _):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(
                delayed(isolated_operation)(list_) for list_ in lists
            )

        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events('receive-from-scatter', client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_backend('dask') as (ba, _):
            # Pass a large array, which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, X=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events('receive-from-scatter', client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close()
        cluster.close()


class CountSerialized(object):
    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        return self.x + getattr(other, 'x', other)

    __radd__ = __add__

    def __reduce__(self):
        self.count += 1
        return (CountSerialized, (self.x,))


def add5(a, b, c, d=0, e=0):
    return a + b + c + d + e
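
# Illustrative sketch of how CountSerialized tracks serialization (a local
# check, not part of the distributed tests): every pickling round-trip calls
# __reduce__ exactly once on the source object, e.g.
#     import pickle
#     obj = CountSerialized(1)
#     pickle.loads(pickle.dumps(obj))
#     assert obj.count == 1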


def test_manual_scatter(loop):
    x = CountSerialized(1)
    y = CountSerialized(2)
    z = CountSerialized(3)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask', scatter=[x, y]) as (ba, _):
                f = delayed(add5)
                tasks = [f(x, y, z, d=4, e=5),
                         f(x, z, y, d=5, e=4),
                         f(y, x, z, d=x, e=5),
                         f(z, z, x, d=z, e=y)]
                expected = [func(*args, **kwargs)
                            for func, args, kwargs in tasks]
                results = Parallel()(tasks)

            # Scatter must take a list/tuple
            with pytest.raises(TypeError):
                with parallel_backend('dask', loop=loop, scatter=1):
                    pass

    assert results == expected

    # Scattered variables only serialized once
    assert x.count == 1
    assert y.count == 1
    # Depending on the version of distributed, the unscattered z variable
    # is either pickled 4 or 6 times, possibly because of the memoization
    # of objects that appear several times in the arguments of a delayed
    # task.
    assert z.count in (4, 6)


def test_auto_scatter(loop):
    np = pytest.importorskip('numpy')
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(delayed(noop)(data, data, i, opt=data)
                           for i, data in enumerate(data_to_process))
            # By default, large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events('receive-from-scatter', client)
            # assert counts[a['address']] + counts[b['address']] == 2
            assert 2 <= counts[a['address']] + counts[b['address']] <= 4

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events('receive-from-scatter', client)
            assert counts[a['address']] == 0
            assert counts[b['address']] == 0


@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    np = pytest.importorskip('numpy')

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        client = get_client()  # noqa
        with parallel_backend("dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(
                    NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as _:
            with parallel_backend("dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(
                        my_array[i:], i) for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop):
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)()
                    for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)()
                    for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.

    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)()
                        for _ in range(2)
                    )
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1


def test_errors(loop):
    with pytest.raises(ValueError) as info:
        with parallel_backend('dask'):
            pass

    assert "create a dask client" in str(info.value).lower()


def test_correct_nested_backend(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            # No requirement: the nested backend should be the dask backend.
            with parallel_backend('dask') as (ba, _):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1))
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Requiring shared memory: the nested backend should be threading.
            with parallel_backend('dask') as (ba, _):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require='sharedmem')
                    for _ in range(1))
                assert isinstance(result[0][0][0], ThreadingBackend)


def outer(nested_require):
    return Parallel(n_jobs=2, prefer='threads')(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    return Parallel(n_jobs=2, require=require)(
        delayed(inner)() for _ in range(1)
    )


def inner():
    return Parallel()._backend
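
# outer -> middle -> inner build a three-level nesting: ``inner`` reports the
# backend instance a default Parallel() would use at that depth, which is what
# test_correct_nested_backend inspects via result[0][0][0].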


def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_backend('dask'):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


def _worker_address(_):
    from distributed import get_worker
    return get_worker().address


def test_dask_backend_keywords(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask', workers=a['address']) as (ba, _):
                seq = Parallel()(
                    delayed(_worker_address)(i) for i in range(10))
                assert seq == [a['address']] * 10

            with parallel_backend('dask', workers=b['address']) as (ba, _):
                seq = Parallel()(
                    delayed(_worker_address)(i) for i in range(10))
                assert seq == [b['address']] * 10
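
# The ``workers=...`` keyword above is presumably forwarded by the dask
# backend when submitting tasks to the scheduler, which is why every call is
# expected to run on the single named worker address.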


def test_cleanup(loop):
    with Client(processes=False, loop=loop) as client:
        with parallel_backend('dask'):
            Parallel()(delayed(inc)(i) for i in range(10))

        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5

        assert not client.futures


@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    distributed.__version__ <= '2.1.1' and distributed.__version__ >= '1.28.0',
    reason="distributed bug - https://github.com/dask/distributed/pull/2841")
def test_wait_for_workers(cluster_strategy):
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Ask the cluster to start workers. This is a non-blocking call and
        # new workers might take time to connect, in which case the Parallel
        # call should wait for at least one worker to come up before starting
        # to schedule work.
        cluster.scale(2)
    try:
        with parallel_backend('dask'):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


def test_wait_for_workers_timeout():
    # Start a cluster with 0 workers:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_backend('dask', wait_for_workers_timeout=0.1):
            # Short timeout: the DaskDistributedBackend raises a TimeoutError
            # with an informative message:
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_backend('dask', wait_for_workers_timeout=0):
            # No timeout: fall back to the generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()