diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index ad657e1e19564..bde236c672837 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -38,21 +38,6 @@ def __init__(self, auto_requeue: bool = True) -> None: def creates_processes_externally(self) -> bool: return True - @staticmethod - def job_id() -> Optional[int]: - job_id = os.environ.get("SLURM_JOB_ID") - if job_id: - try: - job_id = int(job_id) - except ValueError: - job_id = None - - # in interactive mode, don't make logs use the same job id - in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" - if in_slurm_interactive_mode: - job_id = None - return job_id - @property def main_address(self) -> str: # figure out the root node addr @@ -100,6 +85,25 @@ def detect() -> bool: """Returns ``True`` if the current process was launched on a SLURM cluster.""" return "SLURM_NTASKS" in os.environ + @staticmethod + def job_name() -> Optional[str]: + return os.environ.get("SLURM_JOB_NAME") + + @staticmethod + def job_id() -> Optional[int]: + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = SLURMEnvironment.job_name() == "bash" + if in_slurm_interactive_mode: + return None + + job_id = os.environ.get("SLURM_JOB_ID") + if job_id is None: + return None + try: + return int(job_id) + except ValueError: + return None + def world_size(self) -> int: return int(os.environ["SLURM_NTASKS"]) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 2e5e1c48c5785..18a4da416946d 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -1009,7 +1009,7 @@ def _is_slurm_managing_tasks(self) -> bool: if ( (not self.use_ddp and not self.use_ddp2) or not SLURMEnvironment.detect() - or os.environ.get("SLURM_JOB_NAME") == "bash" # in interactive mode we don't manage tasks + or SLURMEnvironment.job_name() == "bash" # in interactive mode we don't manage tasks ): return False diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py index 71e87adb26b69..aa8db284a1c64 100644 --- a/tests/plugins/environments/test_slurm_environment.py +++ b/tests/plugins/environments/test_slurm_environment.py @@ -27,7 +27,9 @@ def test_default_attributes(): assert env.creates_processes_externally assert env.main_address == "127.0.0.1" assert env.main_port == 12910 + assert env.job_name() is None assert env.job_id() is None + with pytest.raises(KeyError): # world size is required to be passed as env variable env.world_size() @@ -48,6 +50,7 @@ def test_default_attributes(): "SLURM_LOCALID": "2", "SLURM_PROCID": "1", "SLURM_NODEID": "3", + "SLURM_JOB_NAME": "JOB", }, ) def test_attributes_from_environment_variables(caplog): @@ -61,6 +64,7 @@ def test_attributes_from_environment_variables(caplog): assert env.global_rank() == 1 assert env.local_rank() == 2 assert env.node_rank() == 3 + assert env.job_name() == "JOB" # setter should be no-op with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): env.set_global_rank(100)