@@ -304,22 +304,22 @@ def test_accelerator_gpu():
304304 assert isinstance (trainer .accelerator , GPUAccelerator )
305305
306306
307- @pytest .mark .parametrize (["devices" , "plugin " ], [(1 , SingleDeviceStrategy ), (5 , DDPSpawnStrategy )])
308- def test_accelerator_cpu_with_devices (devices , plugin ):
307+ @pytest .mark .parametrize (["devices" , "strategy_class " ], [(1 , SingleDeviceStrategy ), (5 , DDPSpawnStrategy )])
308+ def test_accelerator_cpu_with_devices (devices , strategy_class ):
309309 trainer = Trainer (accelerator = "cpu" , devices = devices )
310310 assert trainer .num_devices == devices
311- assert isinstance (trainer .strategy , plugin )
311+ assert isinstance (trainer .strategy , strategy_class )
312312 assert isinstance (trainer .accelerator , CPUAccelerator )
313313
314314
315315@RunIf (min_cuda_gpus = 2 )
316316@pytest .mark .parametrize (
317- ["devices" , "plugin " ], [(1 , SingleDeviceStrategy ), ([1 ], SingleDeviceStrategy ), (2 , DDPSpawnStrategy )]
317+ ["devices" , "strategy_class " ], [(1 , SingleDeviceStrategy ), ([1 ], SingleDeviceStrategy ), (2 , DDPSpawnStrategy )]
318318)
319- def test_accelerator_gpu_with_devices (devices , plugin ):
319+ def test_accelerator_gpu_with_devices (devices , strategy_class ):
320320 trainer = Trainer (accelerator = "gpu" , devices = devices )
321321 assert trainer .num_devices == len (devices ) if isinstance (devices , list ) else devices
322- assert isinstance (trainer .strategy , plugin )
322+ assert isinstance (trainer .strategy , strategy_class )
323323 assert isinstance (trainer .accelerator , GPUAccelerator )
324324
325325
@@ -356,28 +356,28 @@ def test_exception_invalid_strategy():
356356
357357
358358@pytest .mark .parametrize (
359- ["strategy" , "plugin " ],
359+ ["strategy" , "strategy_class " ],
360360 [
361361 ("ddp_spawn" , DDPSpawnStrategy ),
362362 ("ddp_spawn_find_unused_parameters_false" , DDPSpawnStrategy ),
363363 ("ddp" , DDPStrategy ),
364364 ("ddp_find_unused_parameters_false" , DDPStrategy ),
365365 ],
366366)
367- def test_strategy_choice_cpu_str (tmpdir , strategy , plugin ):
367+ def test_strategy_choice_cpu_str (strategy , strategy_class ):
368368 trainer = Trainer (strategy = strategy , accelerator = "cpu" , devices = 2 )
369- assert isinstance (trainer .strategy , plugin )
369+ assert isinstance (trainer .strategy , strategy_class )
370370
371371
372- @pytest .mark .parametrize ("plugin " , [DDPSpawnStrategy , DDPStrategy ])
373- def test_strategy_choice_cpu_plugin ( tmpdir , plugin ):
374- trainer = Trainer (strategy = plugin (), accelerator = "cpu" , devices = 2 )
375- assert isinstance (trainer .strategy , plugin )
372+ @pytest .mark .parametrize ("strategy_class " , [DDPSpawnStrategy , DDPStrategy ])
373+ def test_strategy_choice_cpu_instance ( strategy_class ):
374+ trainer = Trainer (strategy = strategy_class (), accelerator = "cpu" , devices = 2 )
375+ assert isinstance (trainer .strategy , strategy_class )
376376
377377
378378@RunIf (min_cuda_gpus = 2 )
379379@pytest .mark .parametrize (
380- ["strategy" , "plugin " ],
380+ ["strategy" , "strategy_class " ],
381381 [
382382 ("ddp_spawn" , DDPSpawnStrategy ),
383383 ("ddp_spawn_find_unused_parameters_false" , DDPSpawnStrategy ),
@@ -390,29 +390,29 @@ def test_strategy_choice_cpu_plugin(tmpdir, plugin):
390390 pytest .param ("deepspeed" , DeepSpeedStrategy , marks = RunIf (deepspeed = True )),
391391 ],
392392)
393- def test_strategy_choice_gpu_str (tmpdir , strategy , plugin ):
393+ def test_strategy_choice_gpu_str (strategy , strategy_class ):
394394 trainer = Trainer (strategy = strategy , accelerator = "gpu" , devices = 2 )
395- assert isinstance (trainer .strategy , plugin )
395+ assert isinstance (trainer .strategy , strategy_class )
396396
397397
398398@RunIf (min_cuda_gpus = 2 )
399- @pytest .mark .parametrize ("plugin " , [DDPSpawnStrategy , DDPStrategy ])
400- def test_strategy_choice_gpu_plugin ( tmpdir , plugin ):
401- trainer = Trainer (strategy = plugin (), accelerator = "gpu" , devices = 2 )
402- assert isinstance (trainer .strategy , plugin )
399+ @pytest .mark .parametrize ("strategy_class " , [DDPSpawnStrategy , DDPStrategy ])
400+ def test_strategy_choice_gpu_instance ( strategy_class ):
401+ trainer = Trainer (strategy = strategy_class (), accelerator = "gpu" , devices = 2 )
402+ assert isinstance (trainer .strategy , strategy_class )
403403
404404
405405@RunIf (min_cuda_gpus = 2 )
406- @pytest .mark .parametrize ("plugin " , [DDPSpawnStrategy , DDPStrategy ])
407- def test_device_type_when_training_plugin_gpu_passed ( tmpdir , plugin ):
406+ @pytest .mark .parametrize ("strategy_class " , [DDPSpawnStrategy , DDPStrategy ])
407+ def test_device_type_when_strategy_instance_gpu_passed ( strategy_class ):
408408
409- trainer = Trainer (strategy = plugin (), accelerator = "gpu" , devices = 2 )
410- assert isinstance (trainer .strategy , plugin )
409+ trainer = Trainer (strategy = strategy_class (), accelerator = "gpu" , devices = 2 )
410+ assert isinstance (trainer .strategy , strategy_class )
411411 assert isinstance (trainer .accelerator , GPUAccelerator )
412412
413413
414414@pytest .mark .parametrize ("precision" , [1 , 12 , "invalid" ])
415- def test_validate_precision_type (tmpdir , precision ):
415+ def test_validate_precision_type (precision ):
416416
417417 with pytest .raises (MisconfigurationException , match = f"Precision { repr (precision )} is invalid" ):
418418 Trainer (precision = precision )
@@ -423,7 +423,7 @@ def test_amp_level_raises_error_with_native():
423423 _ = Trainer (amp_level = "O2" , amp_backend = "native" , precision = 16 )
424424
425425
426- def test_strategy_choice_ddp_spawn_cpu (tmpdir ):
426+ def test_strategy_choice_ddp_spawn_cpu ():
427427 trainer = Trainer (fast_dev_run = True , strategy = "ddp_spawn" , accelerator = "cpu" , devices = 2 )
428428 assert isinstance (trainer .accelerator , CPUAccelerator )
429429 assert isinstance (trainer .strategy , DDPSpawnStrategy )
0 commit comments