Skip to content

Commit 597b309

Browse files
leezuawaelchli
andauthored
Fix Trainer.plugins type declaration (#7288)
* Fix trainer.plugins type declaration * Don't ClusterEnvironment(Plugin) * fix import error, yapf formatter * Add test Co-authored-by: Adrian Wälchli <[email protected]>
1 parent f135deb commit 597b309

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytorch_lightning.core.step_result import Result
3131
from pytorch_lightning.loggers import LightningLoggerBase
3232
from pytorch_lightning.plugins import Plugin
33+
from pytorch_lightning.plugins.environments import ClusterEnvironment
3334
from pytorch_lightning.profiler import BaseProfiler
3435
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
3536
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
@@ -137,7 +138,7 @@ def __init__(
137138
terminate_on_nan: bool = False,
138139
auto_scale_batch_size: Union[str, bool] = False,
139140
prepare_data_per_node: bool = True,
140-
plugins: Optional[Union[Plugin, str, list]] = None,
141+
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
141142
amp_backend: str = 'native',
142143
amp_level: str = 'O2',
143144
distributed_backend: Optional[str] = None,

tests/utilities/test_cli.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
2929
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
30+
from pytorch_lightning.plugins.environments import SLURMEnvironment
3031
from pytorch_lightning.utilities import _TPU_AVAILABLE
3132
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
3233
from tests.helpers import BoringDataModule, BoringModel
@@ -280,6 +281,22 @@ def on_fit_start(self):
280281
assert cli.trainer.ran_asserts
281282

282283

284+
def test_lightning_cli_args_cluster_environments(tmpdir):
285+
plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')]
286+
287+
class TestModel(BoringModel):
288+
289+
def on_fit_start(self):
290+
# Ensure SLURMEnvironment is set, instead of default LightningEnvironment
291+
assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment)
292+
self.trainer.ran_asserts = True
293+
294+
with mock.patch('sys.argv', ['any.py', f'--trainer.plugins={json.dumps(plugins)}']):
295+
cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))
296+
297+
assert cli.trainer.ran_asserts
298+
299+
283300
def test_lightning_cli_args(tmpdir):
284301

285302
cli_args = [
@@ -390,6 +407,7 @@ def test_lightning_cli_print_config():
390407
def test_lightning_cli_submodules(tmpdir):
391408

392409
class MainModule(BoringModel):
410+
393411
def __init__(
394412
self,
395413
submodule1: LightningModule,

0 commit comments

Comments
 (0)