Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions pytorch_lightning/plugins/environments/lsf_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@ class LSFEnvironment(ClusterEnvironment):
LSB_JOBID:
The LSF assigned job ID

-LSB_HOSTS:
-The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."
+LSB_MCPU_HOSTS:
+The hosts used in the job. This string is expected to have the format
+"<node0_name> <node0_num_procs> <node1_name> ..."

JSM_NAMESPACE_LOCAL_RANK:
The node local rank for the task. This environment variable is set by jsrun

JSM_NAMESPACE_SIZE:
The world size for the task. This environment variable is set by jsrun

+More information about environment variables for LSF can be found
+`here <https://www.ibm.com/docs/en/spectrum-lsf/10.1.0?topic=variables-environment-variable-reference>`_.
"""

def __init__(self):
Expand All @@ -49,7 +53,7 @@ def __init__(self):
@staticmethod
def is_using_lsf() -> bool:
"""Returns ``True`` if the current process was launched using the jsrun command."""
-required_env_vars = ("LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
+required_env_vars = ("LSB_JOBID", "LSB_MCPU_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
return all(v in os.environ for v in required_env_vars)

@property
Expand All @@ -58,7 +62,7 @@ def creates_processes_externally(self) -> bool:

@property
def main_address(self) -> str:
-"""The main address is read from a list of hosts contained in the environment variable `LSB_HOSTS`."""
+"""The main address is read from a list of hosts contained in the environment variable `LSB_MCPU_HOSTS`."""
return self._main_address

@property
Expand Down Expand Up @@ -107,7 +111,7 @@ def local_rank(self):

def node_rank(self):
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
-environment variable `LSB_HOSTS`."""
+environment variable `LSB_MCPU_HOSTS`."""
hosts = self._read_hosts()
count = {}
for host in hosts:
Expand All @@ -119,19 +123,20 @@ def node_rank(self):

@staticmethod
def _read_hosts():
-hosts = os.environ.get("LSB_HOSTS")
-if not hosts:
-raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
-hosts = hosts.split()
-if len(hosts) < 2:
+hosts_config = os.environ.get("LSB_MCPU_HOSTS", "")
+if not hosts_config:
+raise ValueError("Could not find hosts in environment variable LSB_MCPU_HOSTS")
+host_config = hosts_config.split()
+
+if len(host_config) % 2 != 0:
raise ValueError(
-'Cannot parse hosts from LSB_HOSTS environment variable. Expected format: "batch <rank_0_host> ..."'
+"Cannot parse hosts from LSB_MCPU_HOSTS environment variable. Expected format:"
+' "<node0_name> <node0_num_procs> <node1_name> ..."'
)
-return hosts
+return host_config[::2]

def _get_main_address(self) -> str:
-hosts = self._read_hosts()
-return hosts[1]
+return self._read_hosts()[0]

@staticmethod
def _get_main_port() -> int:
Expand Down
16 changes: 9 additions & 7 deletions tests/plugins/environments/test_lsf_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,25 @@
from pytorch_lightning.plugins.environments import LSFEnvironment


-@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
+@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1", "LSB_JOBID": "1234"})
def test_missing_lsb_hosts():
"""Test an error when the lsb hosts list cannot be found."""
-del os.environ["LSB_HOSTS"]
-with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_HOSTS"):
+del os.environ["LSB_MCPU_HOSTS"]
+with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_MCPU_HOSTS"):
LSFEnvironment()


-@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
+@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1", "LSB_JOBID": "1234"})
def test_missing_lsb_job_id():
"""Test an error when the job id cannot be found."""
del os.environ["LSB_JOBID"]
with pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"):
LSFEnvironment()


-@mock.patch.dict(os.environ, {"MASTER_PORT": "4321", "LSB_JOBID": "1234", "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1"})
+@mock.patch.dict(
+os.environ, {"MASTER_PORT": "4321", "LSB_JOBID": "1234", "LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1"}
+)
def test_manual_main_port_and_address():
"""Test a user can set the port manually through the MASTER_PORT env variable."""
env = LSFEnvironment()
Expand All @@ -45,7 +47,7 @@ def test_manual_main_port_and_address():
@mock.patch.dict(
os.environ,
{
-"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3",
+"LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1 10.10.10.2 1 10.10.10.3 1",
"LSB_JOBID": "1234",
"JSM_NAMESPACE_SIZE": "4",
"JSM_NAMESPACE_RANK": "3",
Expand All @@ -69,7 +71,7 @@ def test_attributes_from_environment_variables():


@mock.patch("socket.gethostname", return_value="host2")
-@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch host0 host1 host2 host3", "LSB_JOBID": "1234"})
+@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "host0 1 host1 1 host2 1 host3 1", "LSB_JOBID": "1234"})
def test_node_rank(_):
env = LSFEnvironment()
assert env.node_rank() == 2