|
20 | 20 | import torch |
21 | 21 | import torch.distributed |
22 | 22 |
|
| 23 | +import pytorch_lightning |
23 | 24 | from pytorch_lightning import Trainer |
24 | 25 | from pytorch_lightning.accelerators.accelerator import Accelerator |
25 | 26 | from pytorch_lightning.accelerators.cpu import CPUAccelerator |
@@ -392,19 +393,31 @@ def test_dist_backend_accelerator_mapping(*_): |
392 | 393 | assert trainer.strategy.local_rank == 0 |
393 | 394 |
|
394 | 395 |
|
395 | | -@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) |
396 | 396 | @mock.patch("torch.cuda.device_count", return_value=2) |
397 | | -def test_ipython_incompatible_backend_error(*_): |
| 397 | +def test_ipython_incompatible_backend_error(_, monkeypatch): |
| 398 | + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) |
398 | 399 | with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): |
399 | 400 | Trainer(strategy="ddp", gpus=2) |
400 | 401 |
|
401 | 402 | with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"): |
402 | 403 | Trainer(strategy="ddp2", gpus=2) |
403 | 404 |
|
| 405 | + with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"): |
| 406 | + Trainer(strategy="ddp_spawn") |
404 | 407 |
|
405 | | -@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) |
406 | | -def test_ipython_compatible_backend(*_): |
407 | | - Trainer(strategy="ddp_spawn", num_processes=2) |
| 408 | + with pytest.raises(MisconfigurationException, match=r"strategy='ddp_sharded_spawn'\)`.*is not compatible"): |
| 409 | + Trainer(strategy="ddp_sharded_spawn") |
| 410 | + |
| 411 | + with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): |
| 412 | + # Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu |
| 413 | + Trainer(strategy="dp") |
| 414 | + |
| 415 | + |
| 416 | +@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")]) |
| 417 | +def test_ipython_compatible_backend(trainer_kwargs, monkeypatch): |
| 418 | + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) |
| 419 | + trainer = Trainer(**trainer_kwargs) |
| 420 | + assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible |
408 | 421 |
|
409 | 422 |
|
410 | 423 | @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")]) |
|
0 commit comments