Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions pytorch_lightning/plugins/environments/slurm_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,6 @@ def __init__(self, auto_requeue: bool = True) -> None:
def creates_processes_externally(self) -> bool:
return True

@staticmethod
def job_id() -> Optional[int]:
job_id = os.environ.get("SLURM_JOB_ID")
if job_id:
try:
job_id = int(job_id)
except ValueError:
job_id = None

# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash"
if in_slurm_interactive_mode:
job_id = None
return job_id

@property
def main_address(self) -> str:
# figure out the root node addr
Expand Down Expand Up @@ -100,6 +85,25 @@ def detect() -> bool:
"""Returns ``True`` if the current process was launched on a SLURM cluster."""
return "SLURM_NTASKS" in os.environ

@staticmethod
def job_name() -> Optional[str]:
return os.environ.get("SLURM_JOB_NAME")

@staticmethod
def job_id() -> Optional[int]:
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = SLURMEnvironment.job_name() == "bash"
if in_slurm_interactive_mode:
return None

job_id = os.environ.get("SLURM_JOB_ID")
if job_id is None:
return None
try:
return int(job_id)
except ValueError:
return None

def world_size(self) -> int:
return int(os.environ["SLURM_NTASKS"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def _is_slurm_managing_tasks(self) -> bool:
if (
(not self.use_ddp and not self.use_ddp2)
or not SLURMEnvironment.detect()
or os.environ.get("SLURM_JOB_NAME") == "bash" # in interactive mode we don't manage tasks
or SLURMEnvironment.job_name() == "bash" # in interactive mode we don't manage tasks
):
return False

Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/environments/test_slurm_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_default_attributes():
assert env.creates_processes_externally
assert env.main_address == "127.0.0.1"
assert env.main_port == 12910
assert env.job_name() is None
assert env.job_id() is None

with pytest.raises(KeyError):
# world size is required to be passed as env variable
env.world_size()
Expand All @@ -48,6 +50,7 @@ def test_default_attributes():
"SLURM_LOCALID": "2",
"SLURM_PROCID": "1",
"SLURM_NODEID": "3",
"SLURM_JOB_NAME": "JOB",
},
)
def test_attributes_from_environment_variables(caplog):
Expand All @@ -61,6 +64,7 @@ def test_attributes_from_environment_variables(caplog):
assert env.global_rank() == 1
assert env.local_rank() == 2
assert env.node_rank() == 3
assert env.job_name() == "JOB"
# setter should be no-op
with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
env.set_global_rank(100)
Expand Down