|
27 | 27 |
|
28 | 28 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
29 | 29 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
| 30 | +from pytorch_lightning.plugins.environments import SLURMEnvironment |
30 | 31 | from pytorch_lightning.utilities import _TPU_AVAILABLE |
31 | 32 | from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback |
32 | 33 | from tests.helpers import BoringDataModule, BoringModel |
@@ -280,6 +281,22 @@ def on_fit_start(self): |
280 | 281 | assert cli.trainer.ran_asserts |
281 | 282 |
|
282 | 283 |
|
| 284 | +def test_lightning_cli_args_cluster_environments(tmpdir): |
| 285 | + plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')] |
| 286 | + |
| 287 | + class TestModel(BoringModel): |
| 288 | + |
| 289 | + def on_fit_start(self): |
| 290 | + # Ensure SLURMEnvironment is set, instead of default LightningEnvironment |
| 291 | + assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment) |
| 292 | + self.trainer.ran_asserts = True |
| 293 | + |
| 294 | + with mock.patch('sys.argv', ['any.py', f'--trainer.plugins={json.dumps(plugins)}']): |
| 295 | + cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) |
| 296 | + |
| 297 | + assert cli.trainer.ran_asserts |
| 298 | + |
| 299 | + |
283 | 300 | def test_lightning_cli_args(tmpdir): |
284 | 301 |
|
285 | 302 | cli_args = [ |
@@ -390,6 +407,7 @@ def test_lightning_cli_print_config(): |
390 | 407 | def test_lightning_cli_submodules(tmpdir): |
391 | 408 |
|
392 | 409 | class MainModule(BoringModel): |
| 410 | + |
393 | 411 | def __init__( |
394 | 412 | self, |
395 | 413 | submodule1: LightningModule, |
|
0 commit comments