Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Commit fb1bd39

Browse files
author
Tomas Laube
committed
support sample weights
1 parent 4661c8a commit fb1bd39

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

dask_xgboost/core.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,14 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):
7373
-------
7474
model if rank zero, None otherwise
7575
"""
76-
data, labels = zip(*list_of_parts) # Prepare data
76+
# Prepare data
77+
if len(list_of_parts[0]) == 3:
78+
data, labels, weight = zip(*list_of_parts)
79+
weight = concat(weight)
80+
else:
81+
data, labels = zip(*list_of_parts)
82+
weight = None
83+
7784
data = concat(data) # Concatenate many parts into one
7885
labels = concat(labels)
7986
if dmatrix_kwargs is None:
@@ -99,7 +106,7 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):
99106

100107

101108
@gen.coroutine
102-
def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
109+
def _train(client, params, data, labels, sample_weight, dmatrix_kwargs={}, **kwargs):
103110
"""
104111
Asynchronous version of train
105112
@@ -117,8 +124,18 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
117124
assert label_parts.ndim == 1 or label_parts.shape[1] == 1
118125
label_parts = label_parts.flatten().tolist()
119126

120-
# Arrange parts into pairs. This enforces co-locality
121-
parts = list(map(delayed, zip(data_parts, label_parts)))
127+
if sample_weight is not None:
128+
sample_weight_parts = sample_weight.to_delayed()
129+
if isinstance(sample_weight_parts, np.ndarray):
130+
assert sample_weight_parts.ndim == 1 or sample_weight_parts.shape[1] == 1
131+
sample_weight_parts = sample_weight_parts.flatten().tolist()
132+
133+
# Arrange parts into (data, label, weight) triples. This enforces co-locality
134+
parts = list(map(delayed, zip(data_parts, label_parts, sample_weight_parts)))
135+
else:
136+
# Arrange parts into pairs. This enforces co-locality
137+
parts = list(map(delayed, zip(data_parts, label_parts)))
138+
122139
parts = client.compute(parts) # Start computation in the background
123140
yield wait(parts)
124141

@@ -158,7 +175,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
158175
raise gen.Return(result)
159176

160177

161-
def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
178+
def train(client, params, data, labels, sample_weight=None, dmatrix_kwargs={}, **kwargs):
162179
""" Train an XGBoost model on a Dask Cluster
163180
164181
This starts XGBoost on all Dask workers, moves input data to those workers,
@@ -188,7 +205,7 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
188205
predict
189206
"""
190207
return client.sync(_train, client, params, data,
191-
labels, dmatrix_kwargs, **kwargs)
208+
labels, sample_weight, dmatrix_kwargs, **kwargs)
192209

193210

194211
def _predict_part(part, model=None):
@@ -258,7 +275,7 @@ def predict(client, model, data):
258275

259276
class XGBRegressor(xgb.XGBRegressor):
260277

261-
def fit(self, X, y=None):
278+
def fit(self, X, y=None, sample_weight=None):
262279
"""Fit the gradient boosting model
263280
264281
Parameters
@@ -279,6 +296,7 @@ def fit(self, X, y=None):
279296
client = default_client()
280297
xgb_options = self.get_xgb_params()
281298
self._Booster = train(client, xgb_options, X, y,
299+
sample_weight,
282300
num_boost_round=self.n_estimators)
283301
return self
284302

@@ -289,7 +307,7 @@ def predict(self, X):
289307

290308
class XGBClassifier(xgb.XGBClassifier):
291309

292-
def fit(self, X, y=None, classes=None):
310+
def fit(self, X, y=None, classes=None, sample_weight=None):
293311
"""Fit a gradient boosting classifier
294312
295313
Parameters
@@ -301,6 +319,8 @@ def fit(self, X, y=None, classes=None):
301319
classes : sequence, optional
302320
The unique values in `y`. If not specified, this will be
303321
eagerly computed from `y` before training.
322+
sample_weight : array-like [n_samples]
323+
Weights for each training sample
304324
305325
Returns
306326
-------
@@ -345,9 +365,9 @@ def fit(self, X, y=None, classes=None):
345365

346366
# TODO: auto label-encode y
347367
# that will require a dependency on dask-ml
348-
# TODO: sample weight
349368

350369
self._Booster = train(client, xgb_options, X, y,
370+
sample_weight,
351371
num_boost_round=self.n_estimators)
352372
return self
353373

dask_xgboost/tests/test_core.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
X = df.values
3434
y = labels.values
35+
weight = np.ones(10)
3536

3637

3738
def test_classifier(loop): # noqa
@@ -40,7 +41,8 @@ def test_classifier(loop): # noqa
4041
a = dxgb.XGBClassifier()
4142
X2 = da.from_array(X, 5)
4243
y2 = da.from_array(y, 5)
43-
a.fit(X2, y2)
44+
weight1 = da.from_array(weight, 5)
45+
a.fit(X2, y2, sample_weight=weight1)
4446
p1 = a.predict(X2)
4547

4648
b = xgb.XGBClassifier()
@@ -123,13 +125,17 @@ def test_regressor(loop): # noqa
123125
with cluster() as (s, [a, b]):
124126
with Client(s['address'], loop=loop):
125127
a = dxgb.XGBRegressor()
126-
X2 = da.from_array(X, 5)
127-
y2 = da.from_array(y, 5)
128-
a.fit(X2, y2)
128+
X2 = da.from_array(X, 10)
129+
y2 = da.from_array(y, 10)
130+
weight1 = da.from_array(weight, 10)
131+
a.fit(X2, y2, weight1)
129132
p1 = a.predict(X2)
130133

131134
b = xgb.XGBRegressor()
132135
b.fit(X, y)
136+
137+
np.testing.assert_array_almost_equal(a.feature_importances_,
138+
b.feature_importances_)
133139
assert_eq(p1, b.predict(X))
134140

135141

@@ -163,7 +169,7 @@ def test_dmatrix_kwargs(c, s, a, b):
163169
xgb.rabit.init() # workaround for "Doing rabit call after Finalize"
164170
dX = da.from_array(X, chunks=(2, 2))
165171
dy = da.from_array(y, chunks=(2,))
166-
dbst = yield dxgb.train(c, param, dX, dy, {"missing": 0.0})
172+
dbst = yield dxgb.train(c, param, dX, dy, dmatrix_kwargs={"missing": 0.0})
167173

168174
# Distributed model matches local model with dmatrix kwargs
169175
dtrain = xgb.DMatrix(X, label=y, missing=0.0)

0 commit comments

Comments
 (0)