@@ -265,6 +265,10 @@ def call(hook, fn, *args, **kwargs):
             d = {'name': hook}
             if args:
                 d['args'] = args
+            elif hook == 'train':
+                # DeepSpeed calls `train(mode)` but we do not. Standardize
+                # https://github.com/microsoft/DeepSpeed/pull/571
+                d['args'] = (True, )
             if kwargs:
                 d['kwargs'] = kwargs
             called.append(d)
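The hunk above standardizes how the recording wrapper logs `Module.train()`: DeepSpeed invokes it explicitly as `train(mode)`, while Lightning calls it with no arguments, so the recorder normalizes both to `train(True)`. For context, here is a minimal sketch of the recording pattern these tests rely on, assuming a bare `called` list and rebinding via `functools.partial` (the rebinding itself is outside this diff):

from functools import partial

called = []

def call(hook, fn, *args, **kwargs):
    # run the real hook first, then record its name and arguments
    out = fn(*args, **kwargs)
    d = {'name': hook}
    if args:
        d['args'] = args
    elif hook == 'train':
        # normalize `train()` to `train(True)` so DeepSpeed's explicit
        # `train(mode)` calls produce an identical record
        d['args'] = (True, )
    if kwargs:
        d['kwargs'] = kwargs
    called.append(d)
    return out

# hypothetical usage: rebind a hook on a model so every call is recorded
# model.on_fit_start = partial(call, 'on_fit_start', model.on_fit_start)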
@@ -283,12 +287,13 @@ def test_epoch_end(self, *args, **kwargs):
         pass
 
     @staticmethod
-    def _train_batch(trainer, model, batches, current_epoch=0):
+    def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
+        using_native_amp = kwargs.get('amp_backend') == 'native'
         out = []
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `on_batch_{start,end}`
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
@@ -301,14 +306,15 @@ def _train_batch(trainer, model, batches, current_epoch=0):
                 dict(name='on_before_zero_grad', args=(ANY, )),
                 dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
                 # TODO: `on_before_backward`
-                dict(name='backward', args=(ANY, ANY, 0)),
+                # DeepSpeed handles backward internally
+                *([dict(name='backward', args=(ANY, ANY, 0))] if kwargs.get('plugins') != 'deepspeed' else []),
                 dict(name='Callback.on_after_backward', args=(trainer, model)),
                 dict(name='on_after_backward'),
                 # TODO: `on_before_optimizer_step`
                 dict(
                     name='optimizer_step',
                     args=(current_epoch, i, ANY, 0, ANY),
-                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
+                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp)
                 ),
                 dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
                 dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
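A quick aside on the `*([...] if ... else [])` idiom used above: the conditional produces either a one-element list or an empty one, and the star splices it into the surrounding list literal, so the `backward` expectation simply disappears for DeepSpeed runs. A self-contained illustration:

plugins = 'deepspeed'
expected = [
    'optimizer_zero_grad',
    # spliced in only when the condition holds; empty otherwise
    *(['backward'] if plugins != 'deepspeed' else []),
    'optimizer_step',
]
assert expected == ['optimizer_zero_grad', 'optimizer_step']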
@@ -317,14 +323,14 @@ def _train_batch(trainer, model, batches, current_epoch=0):
         return out
 
     @staticmethod
-    def _eval_epoch(fn, trainer, model, batches, key):
+    def _eval_epoch(fn, trainer, model, batches, key, device=torch.device('cpu')):
         outputs = {key: ANY}
         return [
             dict(name='Callback.on_epoch_start', args=(trainer, model)),
             dict(name='on_epoch_start'),
             dict(name=f'Callback.on_{fn}_epoch_start', args=(trainer, model)),
             dict(name=f'on_{fn}_epoch_start'),
-            *HookedModel._eval_batch(fn, trainer, model, batches, key),
+            *HookedModel._eval_batch(fn, trainer, model, batches, key, device=device),
             dict(name=f'{fn}_epoch_end', args=([outputs] * batches, )),
             dict(name=f'Callback.on_{fn}_epoch_end', args=(trainer, model)),
             dict(name=f'on_{fn}_epoch_end'),
@@ -333,13 +339,13 @@ def _eval_epoch(fn, trainer, model, batches, key):
         ]
 
     @staticmethod
-    def _eval_batch(fn, trainer, model, batches, key):
+    def _eval_batch(fn, trainer, model, batches, key, device=torch.device('cpu')):
         out = []
         outputs = {key: ANY}
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
-                dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+                dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
                 # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
@@ -357,10 +363,10 @@ def _predict_batch(trainer, model, batches):
         out = []
         for i in range(batches):
             out.extend([
-                # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
                 dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
                 dict(name='on_after_batch_transfer', args=(ANY, 0)),
+                # TODO: `{,Callback}.on_batch_{start,end}`
                 dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_predict_batch_start', args=(ANY, i, 0)),
                 dict(name='forward', args=(ANY, )),
@@ -372,7 +378,17 @@ def _predict_batch(trainer, model, batches):
         return out
 
 
-def test_trainer_model_hook_system_fit(tmpdir):
+@pytest.mark.parametrize(
+    'kwargs',
+    [
+        {},
+        # these precision plugins modify the optimization flow, so testing them explicitly
+        pytest.param(dict(gpus=1, precision=16, plugins='deepspeed'), marks=RunIf(deepspeed=True, min_gpus=1)),
+        pytest.param(dict(gpus=1, precision=16, amp_backend='native'), marks=RunIf(amp_native=True, min_gpus=1)),
+        pytest.param(dict(gpus=1, precision=16, amp_backend='apex'), marks=RunIf(amp_apex=True, min_gpus=1)),
+    ]
+)
+def test_trainer_model_hook_system_fit(tmpdir, kwargs):
     called = []
     model = HookedModel(called)
     callback = HookedCallback(called)
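`RunIf` is the test suite's conditional-skip helper; each `pytest.param` above only runs when the required hardware and packages are present. A minimal sketch of how such a helper can fold requirements into a single `pytest.mark.skipif`, assuming simplified checks (the real helper validates many more conditions):

import importlib.util

import pytest
import torch


def RunIf(min_gpus=0, deepspeed=False, amp_native=False, amp_apex=False):
    # collect failed requirements and fold them into one skipif mark
    conditions, reasons = [], []
    if min_gpus:
        conditions.append(torch.cuda.device_count() < min_gpus)
        reasons.append(f'requires {min_gpus}+ GPUs')
    if deepspeed:
        conditions.append(importlib.util.find_spec('deepspeed') is None)
        reasons.append('requires deepspeed')
    if amp_apex:
        conditions.append(importlib.util.find_spec('apex') is None)
        reasons.append('requires apex')
    if amp_native:
        conditions.append(not hasattr(torch.cuda, 'amp'))
        reasons.append('requires native AMP (torch>=1.6)')
    return pytest.mark.skipif(any(conditions), reason=' and '.join(reasons))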
@@ -385,13 +401,17 @@ def test_trainer_model_hook_system_fit(tmpdir):
         limit_val_batches=val_batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
-        callbacks=[callback]
+        callbacks=[callback],
+        **kwargs,
     )
+
     assert called == [
         dict(name='Callback.on_init_start', args=(trainer, )),
         dict(name='Callback.on_init_end', args=(trainer, )),
     ]
+
     trainer.fit(model)
+
     saved_ckpt = {
         'callbacks': ANY,
         'epoch': 1,
@@ -401,19 +421,31 @@ def test_trainer_model_hook_system_fit(tmpdir):
         'pytorch-lightning_version': __version__,
         'state_dict': ANY,
     }
+    if kwargs.get('amp_backend') == 'native':
+        saved_ckpt['native_amp_scaling_state'] = ANY
+    elif kwargs.get('amp_backend') == 'apex':
+        saved_ckpt['amp_scaling_state'] = ANY
+    device = torch.device('cuda:0' if 'gpus' in kwargs else 'cpu')
+
     expected = [
         dict(name='Callback.on_init_start', args=(trainer, )),
         dict(name='Callback.on_init_end', args=(trainer, )),
         dict(name='prepare_data'),
         dict(name='configure_callbacks'),
         dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
+        # DeepSpeed needs the batch size to figure out throughput logging
+        *([dict(name='train_dataloader')] if kwargs.get('plugins') == 'deepspeed' else []),
         dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
         dict(name='setup', kwargs=dict(stage='fit')),
         dict(name='configure_sharded_model'),
         dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
-        dict(name='configure_optimizers'),
+        # DeepSpeed skips initializing optimizers here as they are handled via config
+        *([dict(name='configure_optimizers')] if kwargs.get('plugins') != 'deepspeed' else []),
         dict(name='Callback.on_fit_start', args=(trainer, model)),
         dict(name='on_fit_start'),
+        # TODO: explore whether DeepSpeed can have the same flow for optimizers
+        # DeepSpeed did not find any optimizer in the config so they are loaded here
+        *([dict(name='configure_optimizers')] if kwargs.get('plugins') == 'deepspeed' else []),
         dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
         dict(name='on_pretrain_routine_start'),
         dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
@@ -426,14 +458,14 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='zero_grad'),
         dict(name='Callback.on_validation_start', args=(trainer, model)),
         dict(name='on_validation_start'),
-        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
         dict(name='Callback.on_validation_end', args=(trainer, model)),
         dict(name='on_validation_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_validation_model_train'),
         dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
         # duplicate `train` because `_run_train` calls it again in case validation wasn't run
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_train_dataloader'),
         dict(name='train_dataloader'),
         dict(name='Callback.on_train_start', args=(trainer, model)),
@@ -442,19 +474,19 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='on_epoch_start'),
         dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
         dict(name='on_train_epoch_start'),
-        *model._train_batch(trainer, model, train_batches),
+        *model._train_batch(trainer, model, train_batches, device=device, **kwargs),
         dict(name='train', args=(False, )),
         dict(name='on_validation_model_eval'),
         dict(name='zero_grad'),
         dict(name='Callback.on_validation_start', args=(trainer, model)),
         dict(name='on_validation_start'),
-        *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
+        *model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
         dict(name='Callback.on_validation_end', args=(trainer, model)),
         # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
         dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
         dict(name='on_save_checkpoint', args=(saved_ckpt, )),
         dict(name='on_validation_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_validation_model_train'),
         dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
         dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
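The `train(False)` / `train(True)` pairs above mirror `torch.nn.Module.train(mode)`: `eval()` is just `train(False)`, which is why the recorder sees the validation-mode toggles as `train` calls. A small factual illustration:

import torch

m = torch.nn.Linear(2, 2)
m.eval()    # equivalent to m.train(False)
assert m.training is False
m.train()   # mode defaults to True
assert m.training is True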
@@ -542,7 +574,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         dict(name='on_pretrain_routine_start'),
         dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
         dict(name='on_pretrain_routine_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name='on_train_dataloader'),
         dict(name='train_dataloader'),
         # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
@@ -610,7 +642,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
         *model._eval_epoch(noun, trainer, model, batches, key),
         dict(name=f'Callback.on_{noun}_end', args=(trainer, model)),
         dict(name=f'on_{noun}_end'),
-        dict(name='train'),
+        dict(name='train', args=(True, )),
         dict(name=f'on_{noun}_model_train'),
     ]
     expected = [
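A note on the assertions throughout this file: `ANY` is `unittest.mock.ANY`, a sentinel that compares equal to any object, which lets the expected lists pin down hook names, call order, and literal arguments (like the dataloader index) without hard-coding tensors or batch contents. For illustration:

from unittest.mock import ANY

# ANY matches the values we cannot predict, while literals such as the
# dataloader index still have to match exactly
recorded = dict(name='transfer_batch_to_device', args=('some batch', 'cpu-device', 0))
expected = dict(name='transfer_batch_to_device', args=(ANY, ANY, 0))
assert recorded == expected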