@@ -59,24 +59,24 @@ def test_unsupported_strategy(strategy):
5959 EmptyLite (strategy = strategy )
6060
6161
62- def test_run_input_output ():
63- """Test that the dynamically patched run() method receives the input arguments and returns the result."""
64-
65- class Lite (LightningLite ):
62+ class LiteReturnSpawnResult (LightningLite ):
63+ def run (self , * args , ** kwargs ):
64+ return args , kwargs , "result" , self .local_rank
6665
67- run_args = ()
68- run_kwargs = {}
6966
70- def run (self , * args , ** kwargs ):
71- self .run_args = args
72- self .run_kwargs = kwargs
73- return "result"
74-
75- lite = Lite ()
67+ @pytest .mark .parametrize (
68+ "accelerator, strategy, devices" ,
69+ [
70+ ("cpu" , None , None ),
71+ ("cpu" , "ddp_spawn" , 2 ),
72+ pytest .param ("tpu" , "tpu_spawn" , 1 , marks = RunIf (tpu = True )),
73+ ],
74+ )
75+ def test_run_input_output (accelerator , strategy , devices ):
76+ """Test that the dynamically patched run() method receives the input arguments and returns the result."""
77+ lite = LiteReturnSpawnResult (accelerator = accelerator , strategy = strategy , devices = devices )
7678 result = lite .run (1 , 2 , three = 3 )
77- assert result == "result"
78- assert lite .run_args == (1 , 2 )
79- assert lite .run_kwargs == {"three" : 3 }
79+ assert result == ((1 , 2 ), {"three" : 3 }, "result" , 0 )
8080
8181
8282def test_setup_optimizers ():
0 commit comments