@@ -418,11 +418,7 @@ class Prec(PrecisionPlugin):
418418 class TrainTypePlugin (DDPPlugin ):
419419 pass
420420
421- ttp = TrainTypePlugin (
422- device = torch .device ("cpu" ),
423- accelerator = Accel (),
424- precision_plugin = Prec ()
425- )
421+ ttp = TrainTypePlugin (device = torch .device ("cpu" ), accelerator = Accel (), precision_plugin = Prec ())
426422 trainer = Trainer (strategy = ttp , fast_dev_run = True , num_processes = 2 )
427423 assert isinstance (trainer .accelerator , Accel )
428424 assert isinstance (trainer .training_type_plugin , TrainTypePlugin )
@@ -1038,10 +1034,13 @@ def test_unsupported_tpu_choice(monkeypatch):
10381034 with pytest .raises (MisconfigurationException , match = r"accelerator='tpu', precision=64\)` is not implemented" ):
10391035 Trainer (accelerator = "tpu" , precision = 64 )
10401036
1041- with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but native AMP is not supported" ):
1042- Trainer (accelerator = "tpu" , precision = 16 )
1043- with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but apex AMP is not supported" ):
1044- Trainer (accelerator = "tpu" , precision = 16 , amp_backend = "apex" )
1037+ with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `SingleTPUPlugin`" ):
1038+ with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but native AMP is not supported" ):
1039+ Trainer (accelerator = "tpu" , precision = 16 )
1040+
1041+ with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `SingleTPUPlugin`" ):
1042+ with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but apex AMP is not supported" ):
1043+ Trainer (accelerator = "tpu" , precision = 16 , amp_backend = "apex" )
10451044
10461045
10471046def test_unsupported_ipu_choice (monkeypatch ):
0 commit comments