Skip to content

Commit 824c11d

Browse files
committed
update test to ensure spawn result
1 parent d2c27dc commit 824c11d

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

tests/lite/test_lite.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,24 @@ def test_unsupported_strategy(strategy):
5959
EmptyLite(strategy=strategy)
6060

6161

62-
def test_run_input_output():
63-
"""Test that the dynamically patched run() method receives the input arguments and returns the result."""
64-
65-
class Lite(LightningLite):
62+
class LiteReturnSpawnResult(LightningLite):
63+
def run(self, *args, **kwargs):
64+
return args, kwargs, "result", self.local_rank
6665

67-
run_args = ()
68-
run_kwargs = {}
6966

70-
def run(self, *args, **kwargs):
71-
self.run_args = args
72-
self.run_kwargs = kwargs
73-
return "result"
74-
75-
lite = Lite()
67+
@pytest.mark.parametrize(
68+
"accelerator, strategy, devices",
69+
[
70+
("cpu", None, None),
71+
("cpu", "ddp_spawn", 2),
72+
pytest.param("tpu", "tpu_spawn", 1, marks=RunIf(tpu=True)),
73+
],
74+
)
75+
def test_run_input_output(accelerator, strategy, devices):
76+
"""Test that the dynamically patched run() method receives the input arguments and returns the result."""
77+
lite = LiteReturnSpawnResult(accelerator=accelerator, strategy=strategy, devices=devices)
7678
result = lite.run(1, 2, three=3)
77-
assert result == "result"
78-
assert lite.run_args == (1, 2)
79-
assert lite.run_kwargs == {"three": 3}
79+
assert result == ((1, 2), {"three": 3}, "result", 0)
8080

8181

8282
def test_setup_optimizers():

0 commit comments

Comments
 (0)