@@ -73,7 +73,14 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):
     -------
     model if rank zero, None otherwise
     """
-    data, labels = zip(*list_of_parts)  # Prepare data
+    # Prepare data
+    if len(list_of_parts[0]) == 3:
+        data, labels, weight = zip(*list_of_parts)
+        weight = concat(weight)
+    else:
+        data, labels = zip(*list_of_parts)
+        weight = None
+
     data = concat(data)  # Concatenate many parts into one
     labels = concat(labels)
     if dmatrix_kwargs is None:
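Downstream of this hunk, `train_part` presumably forwards the concatenated `weight` when it builds the training matrix. As a standalone sketch of the underlying xgboost API this feeds into (the data here is made up for illustration):

```python
import numpy as np
import xgboost as xgb

# Illustrative data; in train_part, `data`, `labels`, and `weight` arrive
# as concatenated chunks gathered on each Dask worker.
data = np.random.rand(100, 4)
labels = np.random.randint(0, 2, size=100)
weight = np.random.uniform(0.1, 1.0, size=100)  # one weight per row

# xgb.DMatrix accepts per-row weights via its `weight` argument
dtrain = xgb.DMatrix(data, label=labels, weight=weight)
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)
```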
@@ -99,7 +106,7 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):


 @gen.coroutine
-def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
+def _train(client, params, data, labels, sample_weight, dmatrix_kwargs={}, **kwargs):
     """
     Asynchronous version of train

@@ -117,8 +124,18 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
         assert label_parts.ndim == 1 or label_parts.shape[1] == 1
         label_parts = label_parts.flatten().tolist()

-    # Arrange parts into pairs. This enforces co-locality
-    parts = list(map(delayed, zip(data_parts, label_parts)))
+    if sample_weight is not None:
+        sample_weight_parts = sample_weight.to_delayed()
+        if isinstance(sample_weight_parts, np.ndarray):
+            assert sample_weight_parts.ndim == 1 or sample_weight_parts.shape[1] == 1
+            sample_weight_parts = sample_weight_parts.flatten().tolist()
+
+        # Arrange parts into triples. This enforces co-locality
+        parts = list(map(delayed, zip(data_parts, label_parts, sample_weight_parts)))
+    else:
+        # Arrange parts into pairs. This enforces co-locality
+        parts = list(map(delayed, zip(data_parts, label_parts)))
+
     parts = client.compute(parts)  # Start computation in the background
     yield wait(parts)

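Wrapping each zipped tuple in a single `delayed` object is what enforces co-locality: the scheduler treats the tuple as one task, so a chunk of data, its labels, and its weights all materialize on the same worker. A minimal sketch of the pattern, with hypothetical in-memory chunks standing in for `.to_delayed()` output:

```python
import numpy as np
from dask import delayed
from distributed import Client, wait

client = Client()  # assumes a running local cluster

# Hypothetical chunks; names are illustrative, not from this patch
data_parts = [np.random.rand(5, 2), np.random.rand(5, 2)]
label_parts = [np.zeros(5), np.ones(5)]
weight_parts = [np.full(5, 0.5), np.full(5, 2.0)]

# Each triple becomes one task, so its three chunks stay together
parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
futures = client.compute(parts)
wait(futures)
```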
@@ -158,7 +175,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     raise gen.Return(result)


-def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
+def train(client, params, data, labels, sample_weight=None, dmatrix_kwargs={}, **kwargs):
     """ Train an XGBoost model on a Dask Cluster

     This starts XGBoost on all Dask workers, moves input data to those workers,
@@ -188,7 +205,7 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     predict
     """
     return client.sync(_train, client, params, data,
-                       labels, dmatrix_kwargs, **kwargs)
+                       labels, sample_weight, dmatrix_kwargs, **kwargs)


 def _predict_part(part, model=None):
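With the new keyword in place, a call would presumably look like the sketch below. This assumes the package imports as `dask_xgboost` and that `X`, `y`, and the weights share the same row chunking so the parts zip into triples; shapes and parameters are illustrative:

```python
import dask.array as da
from distributed import Client
import dask_xgboost as dxgb

client = Client()  # assumes a running local cluster

# Matching row chunks so data, labels, and weights stay aligned
X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.randint(0, 2, size=1000, chunks=100)
w = da.random.uniform(0.5, 1.5, size=1000, chunks=100)

params = {'objective': 'binary:logistic'}
booster = dxgb.train(client, params, X, y, sample_weight=w,
                     num_boost_round=10)
```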
@@ -258,7 +275,7 @@ def predict(client, model, data):

 class XGBRegressor(xgb.XGBRegressor):

-    def fit(self, X, y=None):
+    def fit(self, X, y=None, sample_weight=None):
         """Fit the gradient boosting model

         Parameters
@@ -279,6 +296,7 @@ def fit(self, X, y=None):
         client = default_client()
         xgb_options = self.get_xgb_params()
         self._Booster = train(client, xgb_options, X, y,
+                              sample_weight,
                               num_boost_round=self.n_estimators)
         return self

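At the estimator level the weights then pass straight through `fit`. A usage sketch, again assuming the `dask_xgboost` import name and matching chunking:

```python
import dask.array as da
from distributed import Client
from dask_xgboost import XGBRegressor

client = Client()  # fit() picks up the default client internally

X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=100)
w = da.random.uniform(0.5, 1.5, size=1000, chunks=100)

est = XGBRegressor(n_estimators=10)
est.fit(X, y, sample_weight=w)  # weights flow through to train()
preds = est.predict(X)
```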
@@ -289,7 +307,7 @@ def predict(self, X):

 class XGBClassifier(xgb.XGBClassifier):

-    def fit(self, X, y=None, classes=None):
+    def fit(self, X, y=None, classes=None, sample_weight=None):
         """Fit a gradient boosting classifier

         Parameters
@@ -301,6 +319,8 @@ def fit(self, X, y=None, classes=None):
         classes : sequence, optional
             The unique values in `y`. If not specified, this will be
             eagerly computed from `y` before training.
+        sample_weight : array-like [n_samples]
+            Weights for each training sample

         Returns
         -------
@@ -345,9 +365,9 @@ def fit(self, X, y=None, classes=None):

         # TODO: auto label-encode y
         # that will require a dependency on dask-ml
-        # TODO: sample weight

         self._Booster = train(client, xgb_options, X, y,
+                              sample_weight,
                               num_boost_round=self.n_estimators)
         return self
