Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from toolz import first, assoc
from tornado import gen
from dask import delayed
from distributed.client import _wait, default_client
from distributed.client import wait, default_client
from distributed.utils import sync
import xgboost as xgb

Expand Down Expand Up @@ -105,7 +105,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
# Arrange parts into pairs. This enforces co-locality
parts = list(map(delayed, zip(data_parts, label_parts)))
parts = client.compute(parts) # Start computation in the background
yield _wait(parts)
yield wait(parts)

# Because XGBoost-python doesn't yet allow iterative training, we need to
# find the locations of all chunks and map them to particular Dask workers
Expand All @@ -119,19 +119,20 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):

# Start the XGBoost tracker on the Dask scheduler
host, port = parse_host_port(client.scheduler.address)
env = yield client._run_on_scheduler(start_tracker,
host.strip('/:'),
len(worker_map))
env = yield client.run_on_scheduler(start_tracker,
host.strip('/:'),
len(worker_map))

# Tell each worker to train on the chunks/parts that it has locally
futures = [client.submit(train_part, env,
assoc(params, 'nthreads', ncores[worker]),
list_of_parts, workers=worker,
allow_other_workers=True,
dmatrix_kwargs=dmatrix_kwargs, **kwargs)
for worker, list_of_parts in worker_map.items()]

# Get the results, only one will be non-None
results = yield client._gather(futures)
results = yield client.gather(futures)
result = [v for v in results if v][0]
raise gen.Return(result)

Expand Down