Skip to content

Commit ea88105

Browse files
carmocca and awaelchli authored
Parametrize fit hook test with different precision plugins (#8070)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 7b6d0a8 commit ea88105

File tree

3 files changed

+55
-24
lines changed

3 files changed

+55
-24
lines changed

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -64,15 +64,14 @@ def backward(
6464
) -> Tensor:
6565
if is_overridden('backward', model):
6666
warning_cache.warn(
67-
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
68-
"backward logic outside of the LightningModule"
67+
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
68+
" the backward logic internally."
6969
)
7070
# todo: hack around for deepspeed engine to call backward
7171
deepspeed_engine = model.trainer.model
7272
deepspeed_engine.backward(closure_loss, *args, **kwargs)
7373
# once backward has been applied, release graph
7474
closure_loss = closure_loss.detach()
75-
7675
return closure_loss
7776

7877
def clip_gradients(

tests/models/test_hooks.py

Lines changed: 52 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -265,6 +265,10 @@ def call(hook, fn, *args, **kwargs):
265265
d = {'name': hook}
266266
if args:
267267
d['args'] = args
268+
elif hook == 'train':
269+
# DeepSpeed calls `train(mode)` but we do not. Standardize
270+
# https://github.com/microsoft/DeepSpeed/pull/571
271+
d['args'] = (True, )
268272
if kwargs:
269273
d['kwargs'] = kwargs
270274
called.append(d)
@@ -283,12 +287,13 @@ def test_epoch_end(self, *args, **kwargs):
283287
pass
284288

285289
@staticmethod
286-
def _train_batch(trainer, model, batches, current_epoch=0):
290+
def _train_batch(trainer, model, batches, device=torch.device('cpu'), current_epoch=0, **kwargs):
291+
using_native_amp = kwargs.get('amp_backend') == 'native'
287292
out = []
288293
for i in range(batches):
289294
out.extend([
290295
dict(name='on_before_batch_transfer', args=(ANY, 0)),
291-
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
296+
dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
292297
dict(name='on_after_batch_transfer', args=(ANY, 0)),
293298
# TODO: `on_batch_{start,end}`
294299
dict(name='Callback.on_batch_start', args=(trainer, model)),
@@ -301,14 +306,15 @@ def _train_batch(trainer, model, batches, current_epoch=0):
301306
dict(name='on_before_zero_grad', args=(ANY, )),
302307
dict(name='optimizer_zero_grad', args=(current_epoch, i, ANY, 0)),
303308
# TODO: `on_before_backward`
304-
dict(name='backward', args=(ANY, ANY, 0)),
309+
# DeepSpeed handles backward internally
310+
*([dict(name='backward', args=(ANY, ANY, 0))] if kwargs.get('plugins') != 'deepspeed' else []),
305311
dict(name='Callback.on_after_backward', args=(trainer, model)),
306312
dict(name='on_after_backward'),
307313
# TODO: `on_before_optimizer_step`
308314
dict(
309315
name='optimizer_step',
310316
args=(current_epoch, i, ANY, 0, ANY),
311-
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False)
317+
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp)
312318
),
313319
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
314320
dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)),
@@ -317,14 +323,14 @@ def _train_batch(trainer, model, batches, current_epoch=0):
317323
return out
318324

319325
@staticmethod
320-
def _eval_epoch(fn, trainer, model, batches, key):
326+
def _eval_epoch(fn, trainer, model, batches, key, device=torch.device('cpu')):
321327
outputs = {key: ANY}
322328
return [
323329
dict(name='Callback.on_epoch_start', args=(trainer, model)),
324330
dict(name='on_epoch_start'),
325331
dict(name=f'Callback.on_{fn}_epoch_start', args=(trainer, model)),
326332
dict(name=f'on_{fn}_epoch_start'),
327-
*HookedModel._eval_batch(fn, trainer, model, batches, key),
333+
*HookedModel._eval_batch(fn, trainer, model, batches, key, device=device),
328334
dict(name=f'{fn}_epoch_end', args=([outputs] * batches, )),
329335
dict(name=f'Callback.on_{fn}_epoch_end', args=(trainer, model)),
330336
dict(name=f'on_{fn}_epoch_end'),
@@ -333,13 +339,13 @@ def _eval_epoch(fn, trainer, model, batches, key):
333339
]
334340

335341
@staticmethod
336-
def _eval_batch(fn, trainer, model, batches, key):
342+
def _eval_batch(fn, trainer, model, batches, key, device=torch.device('cpu')):
337343
out = []
338344
outputs = {key: ANY}
339345
for i in range(batches):
340346
out.extend([
341347
dict(name='on_before_batch_transfer', args=(ANY, 0)),
342-
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
348+
dict(name='transfer_batch_to_device', args=(ANY, device, 0)),
343349
dict(name='on_after_batch_transfer', args=(ANY, 0)),
344350
# TODO: `{,Callback}.on_batch_{start,end}`
345351
dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
@@ -357,10 +363,10 @@ def _predict_batch(trainer, model, batches):
357363
out = []
358364
for i in range(batches):
359365
out.extend([
360-
# TODO: `{,Callback}.on_batch_{start,end}`
361366
dict(name='on_before_batch_transfer', args=(ANY, 0)),
362367
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
363368
dict(name='on_after_batch_transfer', args=(ANY, 0)),
369+
# TODO: `{,Callback}.on_batch_{start,end}`
364370
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
365371
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
366372
dict(name='forward', args=(ANY, )),
@@ -372,7 +378,17 @@ def _predict_batch(trainer, model, batches):
372378
return out
373379

374380

375-
def test_trainer_model_hook_system_fit(tmpdir):
381+
@pytest.mark.parametrize(
382+
'kwargs',
383+
[
384+
{},
385+
# these precision plugins modify the optimization flow, so testing them explicitly
386+
pytest.param(dict(gpus=1, precision=16, plugins='deepspeed'), marks=RunIf(deepspeed=True, min_gpus=1)),
387+
pytest.param(dict(gpus=1, precision=16, amp_backend='native'), marks=RunIf(amp_native=True, min_gpus=1)),
388+
pytest.param(dict(gpus=1, precision=16, amp_backend='apex'), marks=RunIf(amp_apex=True, min_gpus=1)),
389+
]
390+
)
391+
def test_trainer_model_hook_system_fit(tmpdir, kwargs):
376392
called = []
377393
model = HookedModel(called)
378394
callback = HookedCallback(called)
@@ -385,13 +401,17 @@ def test_trainer_model_hook_system_fit(tmpdir):
385401
limit_val_batches=val_batches,
386402
progress_bar_refresh_rate=0,
387403
weights_summary=None,
388-
callbacks=[callback]
404+
callbacks=[callback],
405+
**kwargs,
389406
)
407+
390408
assert called == [
391409
dict(name='Callback.on_init_start', args=(trainer, )),
392410
dict(name='Callback.on_init_end', args=(trainer, )),
393411
]
412+
394413
trainer.fit(model)
414+
395415
saved_ckpt = {
396416
'callbacks': ANY,
397417
'epoch': 1,
@@ -401,19 +421,31 @@ def test_trainer_model_hook_system_fit(tmpdir):
401421
'pytorch-lightning_version': __version__,
402422
'state_dict': ANY,
403423
}
424+
if kwargs.get('amp_backend') == 'native':
425+
saved_ckpt['native_amp_scaling_state'] = ANY
426+
elif kwargs.get('amp_backend') == 'apex':
427+
saved_ckpt['amp_scaling_state'] = ANY
428+
device = torch.device('cuda:0' if 'gpus' in kwargs else 'cpu')
429+
404430
expected = [
405431
dict(name='Callback.on_init_start', args=(trainer, )),
406432
dict(name='Callback.on_init_end', args=(trainer, )),
407433
dict(name='prepare_data'),
408434
dict(name='configure_callbacks'),
409435
dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
436+
# DeepSpeed needs the batch size to figure out throughput logging
437+
*([dict(name='train_dataloader')] if kwargs.get('plugins') == 'deepspeed' else []),
410438
dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
411439
dict(name='setup', kwargs=dict(stage='fit')),
412440
dict(name='configure_sharded_model'),
413441
dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
414-
dict(name='configure_optimizers'),
442+
# DeepSpeed skips initializing optimizers here as they are handled via config
443+
*([dict(name='configure_optimizers')] if kwargs.get('plugins') != 'deepspeed' else []),
415444
dict(name='Callback.on_fit_start', args=(trainer, model)),
416445
dict(name='on_fit_start'),
446+
# TODO: explore whether DeepSpeed can have the same flow for optimizers
447+
# DeepSpeed did not find any optimizer in the config so they are loaded here
448+
*([dict(name='configure_optimizers')] if kwargs.get('plugins') == 'deepspeed' else []),
417449
dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
418450
dict(name='on_pretrain_routine_start'),
419451
dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
@@ -426,14 +458,14 @@ def test_trainer_model_hook_system_fit(tmpdir):
426458
dict(name='zero_grad'),
427459
dict(name='Callback.on_validation_start', args=(trainer, model)),
428460
dict(name='on_validation_start'),
429-
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
461+
*model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
430462
dict(name='Callback.on_validation_end', args=(trainer, model)),
431463
dict(name='on_validation_end'),
432-
dict(name='train'),
464+
dict(name='train', args=(True, )),
433465
dict(name='on_validation_model_train'),
434466
dict(name='Callback.on_sanity_check_end', args=(trainer, model)),
435467
# duplicate `train` because `_run_train` calls it again in case validation wasn't run
436-
dict(name='train'),
468+
dict(name='train', args=(True, )),
437469
dict(name='on_train_dataloader'),
438470
dict(name='train_dataloader'),
439471
dict(name='Callback.on_train_start', args=(trainer, model)),
@@ -442,19 +474,19 @@ def test_trainer_model_hook_system_fit(tmpdir):
442474
dict(name='on_epoch_start'),
443475
dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
444476
dict(name='on_train_epoch_start'),
445-
*model._train_batch(trainer, model, train_batches),
477+
*model._train_batch(trainer, model, train_batches, device=device, **kwargs),
446478
dict(name='train', args=(False, )),
447479
dict(name='on_validation_model_eval'),
448480
dict(name='zero_grad'),
449481
dict(name='Callback.on_validation_start', args=(trainer, model)),
450482
dict(name='on_validation_start'),
451-
*model._eval_epoch('validation', trainer, model, val_batches, 'x'),
483+
*model._eval_epoch('validation', trainer, model, val_batches, 'x', device=device),
452484
dict(name='Callback.on_validation_end', args=(trainer, model)),
453485
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
454486
dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
455487
dict(name='on_save_checkpoint', args=(saved_ckpt, )),
456488
dict(name='on_validation_end'),
457-
dict(name='train'),
489+
dict(name='train', args=(True, )),
458490
dict(name='on_validation_model_train'),
459491
dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
460492
dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)),
@@ -542,7 +574,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
542574
dict(name='on_pretrain_routine_start'),
543575
dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
544576
dict(name='on_pretrain_routine_end'),
545-
dict(name='train'),
577+
dict(name='train', args=(True, )),
546578
dict(name='on_train_dataloader'),
547579
dict(name='train_dataloader'),
548580
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
@@ -610,7 +642,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
610642
*model._eval_epoch(noun, trainer, model, batches, key),
611643
dict(name=f'Callback.on_{noun}_end', args=(trainer, model)),
612644
dict(name=f'on_{noun}_end'),
613-
dict(name='train'),
645+
dict(name='train', args=(True, )),
614646
dict(name=f'on_{noun}_model_train'),
615647
]
616648
expected = [

tests/plugins/test_deepspeed_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
256256
gpus=1,
257257
precision=16,
258258
)
259-
with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
259+
with pytest.warns(UserWarning, match='will be ignored since DeepSpeed handles the backward'):
260260
trainer.fit(model)
261261

262262

0 commit comments

Comments
 (0)