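"""Tests for joblib's dask.distributed backend (``parallel_backend('dask')``)."""
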
from __future__ import print_function, division, absolute_import

import os

import pytest
from random import random
from uuid import uuid4
from time import sleep

from .. import Parallel, delayed, parallel_backend
from ..parallel import ThreadingBackend, AutoBatchingMixin
from .._dask import DaskDistributedBackend

distributed = pytest.importorskip('distributed')
from distributed import Client, LocalCluster, get_client
from distributed.metrics import time
from distributed.utils_test import cluster, inc


def noop(*args, **kwargs):
    pass


def slow_raise_value_error(condition, duration=0.05):
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")


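# Count, for each dask worker, how many events named event_name appear in
# that worker's log.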
def count_events(event_name, client):
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    event_counts = {}
    for w, events in worker_events.items():
        event_counts[w] = len([event for event in list(events)
                               if event[1] == event_name])
    return event_counts


def test_simple(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                with pytest.raises(ValueError):
                    Parallel()(delayed(slow_raise_value_error)(i == 3)
                               for i in range(10))

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    assert (DaskDistributedBackend.compute_batch_size
            is AutoBatchingMixin.compute_batch_size)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(
                        delayed(lambda: None)()
                        for _ in range(int(1e4))
                    )
                    assert backend._effective_batch_size > 10


def random2():
    return random()


def test_dont_assume_function_purity(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y


- @pytest.mark.parametrize("mixed", [True, False])
- def test_dask_funcname(loop, mixed):
- from joblib._dask import Batch
- if not mixed:
- tasks = [delayed(inc)(i) for i in range(4)]
- batch_repr = 'batch_of_inc_4_calls'
- else:
- tasks = [
- delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)
- ]
- batch_repr = 'mixed_batch_of_inc_4_calls'
- assert repr(Batch(tasks)) == batch_repr
- with cluster() as (s, [a, b]):
- with Client(s['address'], loop=loop) as client:
- with parallel_backend('dask') as (ba, _):
- _ = Parallel(batch_size=2, pre_dispatch='all')(tasks)
- def f(dask_scheduler):
- return list(dask_scheduler.transition_log)
- batch_repr = batch_repr.replace('4', '2')
- log = client.run_on_scheduler(f)
- assert all('batch_of_inc' in tup[0] for tup in log)
def test_no_undesired_distributed_cache_hit(loop):
    # Dask has a pickle cache for callables that are called many times.
    # Because the dask backend used to wrap both the functions and the
    # arguments under instances of the Batch callable class, this caching
    # mechanism could lead to bugs as described in:
    # https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as a non-regression check.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip('numpy')
    X = np.arange(int(1e6))

    def isolated_operation(list_, X=None):
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_backend('dask') as (ba, _):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(
                delayed(isolated_operation)(list_) for list_ in lists
            )

        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events('receive-from-scatter', client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_backend('dask') as (ba, _):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, X=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events('receive-from-scatter', client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close()
        cluster.close()


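# Wrapper that records how many times an instance has been pickled:
# __reduce__ increments the counter each time the object is serialized.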
class CountSerialized(object):
    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        return self.x + getattr(other, 'x', other)

    __radd__ = __add__

    def __reduce__(self):
        self.count += 1
        return (CountSerialized, (self.x,))


def add5(a, b, c, d=0, e=0):
    return a + b + c + d + e


def test_manual_scatter(loop):
    x = CountSerialized(1)
    y = CountSerialized(2)
    z = CountSerialized(3)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask', scatter=[x, y]) as (ba, _):
                f = delayed(add5)
                tasks = [f(x, y, z, d=4, e=5),
                         f(x, z, y, d=5, e=4),
                         f(y, x, z, d=x, e=5),
                         f(z, z, x, d=z, e=y)]
                expected = [func(*args, **kwargs)
                            for func, args, kwargs in tasks]
                results = Parallel()(tasks)

            # Scatter must take a list/tuple
            with pytest.raises(TypeError):
                with parallel_backend('dask', loop=loop, scatter=1):
                    pass

    assert results == expected

    # Scattered variables are only serialized once
    assert x.count == 1
    assert y.count == 1

    # Depending on the version of distributed, the unscattered z variable
    # is either pickled 4 or 6 times, possibly because of the memoization
    # of objects that appear several times in the arguments of a delayed
    # task.
    assert z.count in (4, 6)


def test_auto_scatter(loop):
    np = pytest.importorskip('numpy')
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(delayed(noop)(data, data, i, opt=data)
                           for i, data in enumerate(data_to_process))

            # By default, large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events('receive-from-scatter', client)
            # assert counts[a['address']] + counts[b['address']] == 2
            assert 2 <= counts[a['address']] + counts[b['address']] <= 4

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))

            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events('receive-from-scatter', client)
            assert counts[a['address']] == 0
            assert counts[b['address']] == 0


- @pytest.mark.parametrize("retry_no", list(range(2)))
- def test_nested_scatter(loop, retry_no):
- np = pytest.importorskip('numpy')
- NUM_INNER_TASKS = 10
- NUM_OUTER_TASKS = 10
- def my_sum(x, i, j):
- return np.sum(x)
- def outer_function_joblib(array, i):
- client = get_client() # noqa
- with parallel_backend("dask"):
- results = Parallel()(
- delayed(my_sum)(array[j:], i, j) for j in range(
- NUM_INNER_TASKS)
- )
- return sum(results)
- with cluster() as (s, [a, b]):
- with Client(s['address'], loop=loop) as _:
- with parallel_backend("dask"):
- my_array = np.ones(10000)
- _ = Parallel()(
- delayed(outer_function_joblib)(
- my_array[i:], i) for i in range(NUM_OUTER_TASKS)
- )
def test_nested_backend_context_manager(loop):
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:
            with parallel_backend('dask') as (ba, _):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)()
                    for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)()
                    for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.

    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask') as (ba, _):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)()
                        for _ in range(2)
                    )
                for backend_type, nested_n_jobs in all_nested_n_jobs:
                    assert backend_type == "DaskDistributedBackend"
                    assert nested_n_jobs == -1


def test_errors(loop):
    with pytest.raises(ValueError) as info:
        with parallel_backend('dask'):
            pass

    assert "create a dask client" in str(info.value).lower()


def test_correct_nested_backend(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            # No specific requirement: the nested backend should also be the
            # dask backend.
            with parallel_backend('dask') as (ba, _):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1))
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Nested calls require shared memory: the nested backend should
            # be the threading backend.
            with parallel_backend('dask') as (ba, _):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require='sharedmem')
                    for _ in range(1))
                assert isinstance(result[0][0][0], ThreadingBackend)


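# Helpers for test_correct_nested_backend: three levels of nested Parallel
# calls that return the innermost active backend instance.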
def outer(nested_require):
    return Parallel(n_jobs=2, prefer='threads')(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    return Parallel(n_jobs=2, require=require)(
        delayed(inner)() for _ in range(1)
    )


def inner():
    return Parallel()._backend


def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_backend('dask'):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


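# Return the address of the dask worker the task runs on.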
def _worker_address(_):
    from distributed import get_worker
    return get_worker().address


def test_dask_backend_keywords(loop):
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as client:  # noqa: F841
            with parallel_backend('dask', workers=a['address']) as (ba, _):
                seq = Parallel()(
                    delayed(_worker_address)(i) for i in range(10))
                assert seq == [a['address']] * 10

            with parallel_backend('dask', workers=b['address']) as (ba, _):
                seq = Parallel()(
                    delayed(_worker_address)(i) for i in range(10))
                assert seq == [b['address']] * 10


def test_cleanup(loop):
    with Client(processes=False, loop=loop) as client:
        with parallel_backend('dask'):
            Parallel()(delayed(inc)(i) for i in range(10))

        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5
        assert not client.futures


- @pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
- @pytest.mark.skipif(
- distributed.__version__ <= '2.1.1' and distributed.__version__ >= '1.28.0',
- reason="distributed bug - https://github.com/dask/distributed/pull/2841")
- def test_wait_for_workers(cluster_strategy):
- cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
- client = Client(cluster)
- if cluster_strategy == "adaptive":
- cluster.adapt(minimum=0, maximum=2)
- elif cluster_strategy == "late_scaling":
- # Tell the cluster to start workers but this is a non-blocking call
- # and new workers might take time to connect. In this case the Parallel
- # call should wait for at least one worker to come up before starting
- # to schedule work.
- cluster.scale(2)
- try:
- with parallel_backend('dask'):
- # The following should wait a bit for at least one worker to
- # become available.
- Parallel()(delayed(inc)(i) for i in range(10))
- finally:
- client.close()
- cluster.close()
def test_wait_for_workers_timeout():
    # Start a cluster with 0 workers:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_backend('dask', wait_for_workers_timeout=0.1):
            # Short timeout: the dask backend raises an informative
            # TimeoutError.
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_backend('dask', wait_for_workers_timeout=0):
            # No timeout: fall back to the generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()
- cluster.close()
|