diff --git a/pytorch_lightning/plugins/environments/lsf_environment.py b/pytorch_lightning/plugins/environments/lsf_environment.py index 3b67edd8b4091..f7a07449de7b2 100644 --- a/pytorch_lightning/plugins/environments/lsf_environment.py +++ b/pytorch_lightning/plugins/environments/lsf_environment.py @@ -30,14 +30,18 @@ class LSFEnvironment(ClusterEnvironment): LSB_JOBID: The LSF assigned job ID - LSB_HOSTS: - The hosts used in the job. This string is expected to have the format "batch ...." + LSB_MCPU_HOSTS: + The hosts used in the job. This string is expected to have the format + " ..." JSM_NAMESPACE_LOCAL_RANK: The node local rank for the task. This environment variable is set by jsrun JSM_NAMESPACE_SIZE: The world size for the task. This environment variable is set by jsrun + + More information about environment variables for LSF can be found + `here `_. """ def __init__(self): @@ -49,7 +53,7 @@ def __init__(self): @staticmethod def is_using_lsf() -> bool: """Returns ``True`` if the current process was launched using the jsrun command.""" - required_env_vars = ("LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE") + required_env_vars = ("LSB_JOBID", "LSB_MCPU_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE") return all(v in os.environ for v in required_env_vars) @property @@ -58,7 +62,7 @@ def creates_processes_externally(self) -> bool: @property def main_address(self) -> str: - """The main address is read from a list of hosts contained in the environment variable `LSB_HOSTS`.""" + """The main address is read from a list of hosts contained in the environment variable `LSB_MCPU_HOSTS`.""" return self._main_address @property @@ -107,7 +111,7 @@ def local_rank(self): def node_rank(self): """The node rank is determined by the position of the current hostname in the list of hosts stored in the - environment variable `LSB_HOSTS`.""" + environment variable `LSB_MCPU_HOSTS`.""" hosts = self._read_hosts() count = {} for host in hosts: @@ -119,19 +123,20 @@ def node_rank(self): @staticmethod def _read_hosts(): - hosts = os.environ.get("LSB_HOSTS") - if not hosts: - raise ValueError("Could not find hosts in environment variable LSB_HOSTS") - hosts = hosts.split() - if len(hosts) < 2: + hosts_config = os.environ.get("LSB_MCPU_HOSTS", "") + if not hosts_config: + raise ValueError("Could not find hosts in environment variable LSB_MCPU_HOSTS") + host_config = hosts_config.split() + + if len(host_config) % 2 != 0: raise ValueError( - 'Cannot parse hosts from LSB_HOSTS environment variable. Expected format: "batch ..."' + "Cannot parse hosts from LSB_MCPU_HOSTS environment variable. Expected format:" + ' " ..."' ) - return hosts + return host_config[::2] def _get_main_address(self) -> str: - hosts = self._read_hosts() - return hosts[1] + return self._read_hosts()[0] @staticmethod def _get_main_port() -> int: diff --git a/tests/plugins/environments/test_lsf_environment.py b/tests/plugins/environments/test_lsf_environment.py index f438b236d8d37..2a23dd7f5dc41 100644 --- a/tests/plugins/environments/test_lsf_environment.py +++ b/tests/plugins/environments/test_lsf_environment.py @@ -19,15 +19,15 @@ from pytorch_lightning.plugins.environments import LSFEnvironment -@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"}) +@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1", "LSB_JOBID": "1234"}) def test_missing_lsb_hosts(): """Test an error when the lsb hosts list cannot be found.""" - del os.environ["LSB_HOSTS"] - with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_HOSTS"): + del os.environ["LSB_MCPU_HOSTS"] + with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_MCPU_HOSTS"): LSFEnvironment() -@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"}) +@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1", "LSB_JOBID": "1234"}) def test_missing_lsb_job_id(): """Test an error when the job id cannot be found.""" del os.environ["LSB_JOBID"] @@ -35,7 +35,9 @@ def test_missing_lsb_job_id(): LSFEnvironment() -@mock.patch.dict(os.environ, {"MASTER_PORT": "4321", "LSB_JOBID": "1234", "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1"}) +@mock.patch.dict( + os.environ, {"MASTER_PORT": "4321", "LSB_JOBID": "1234", "LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1"} +) def test_manual_main_port_and_address(): """Test a user can set the port manually through the MASTER_PORT env variable.""" env = LSFEnvironment() @@ -45,7 +47,7 @@ def test_manual_main_port_and_address(): @mock.patch.dict( os.environ, { - "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3", + "LSB_MCPU_HOSTS": "10.10.10.0 1 10.10.10.1 1 10.10.10.2 1 10.10.10.3 1", "LSB_JOBID": "1234", "JSM_NAMESPACE_SIZE": "4", "JSM_NAMESPACE_RANK": "3", @@ -69,7 +71,7 @@ def test_attributes_from_environment_variables(): @mock.patch("socket.gethostname", return_value="host2") -@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch host0 host1 host2 host3", "LSB_JOBID": "1234"}) +@mock.patch.dict(os.environ, {"LSB_MCPU_HOSTS": "host0 1 host1 1 host2 1 host3 1", "LSB_JOBID": "1234"}) def test_node_rank(_): env = LSFEnvironment() assert env.node_rank() == 2