diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d8dc8647bc45..168ce93c103c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -146,6 +146,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307)) +- Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102)) + + ### Changed diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 436d40f660e7a..173e533fb0882 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -148,6 +148,7 @@ Cluster Environments ClusterEnvironment LightningEnvironment + LSFEnvironment TorchElasticEnvironment KubeflowEnvironment SLURMEnvironment diff --git a/pytorch_lightning/plugins/environments/__init__.py b/pytorch_lightning/plugins/environments/__init__.py index c7199ece84e31..1878a725071ad 100644 --- a/pytorch_lightning/plugins/environments/__init__.py +++ b/pytorch_lightning/plugins/environments/__init__.py @@ -14,5 +14,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401 diff --git a/pytorch_lightning/plugins/environments/lsf_environment.py b/pytorch_lightning/plugins/environments/lsf_environment.py new file mode 100644 index 
class LSFEnvironment(ClusterEnvironment):
    """
    An environment for running on clusters managed by the LSF resource manager.

    It is expected that any execution using this ClusterEnvironment was executed
    using the Job Step Manager i.e. ``jsrun``.

    This plugin expects the following environment variables.

    LSB_JOBID:
        The LSF assigned job ID

    LSB_HOSTS:
        The hosts used in the job. This string is expected to have the format "batch ...."

    JSM_NAMESPACE_LOCAL_RANK:
        The node local rank for the task. This environment variable is set by jsrun

    JSM_NAMESPACE_RANK:
        The global rank for the task. This environment variable is set by jsrun

    JSM_NAMESPACE_SIZE:
        The world size for the task. This environment variable is set by jsrun

    ``MASTER_PORT`` may optionally be set to override the port that is otherwise derived from ``LSB_JOBID``.
    """

    def __init__(self):
        # Both values are resolved eagerly, so a missing LSB_HOSTS / LSB_JOBID raises here
        # rather than on first use.
        self._master_address = self._get_master_address()
        self._master_port = self._get_master_port()
        log.debug(f"MASTER_ADDR: {self._master_address}")
        log.debug(f"MASTER_PORT: {self._master_port}")

    @staticmethod
    def is_using_lsf() -> bool:
        """ Returns ``True`` if the current process was launched using the jsrun command. """
        # Every variable that the accessors of this class read must be present.
        # ``JSM_NAMESPACE_RANK`` is included because ``global_rank`` raises without it.
        required_env_vars = (
            "LSB_JOBID",
            "LSB_HOSTS",
            "JSM_NAMESPACE_LOCAL_RANK",
            "JSM_NAMESPACE_RANK",
            "JSM_NAMESPACE_SIZE",
        )
        return all(v in os.environ for v in required_env_vars)

    def creates_children(self) -> bool:
        """ Always returns ``True`` for this environment. """
        return True

    def master_address(self) -> str:
        """ The master address is read from a list of hosts contained in the environment variable `LSB_HOSTS`. """
        return self._master_address

    def master_port(self) -> int:
        """ The master port is taken from `MASTER_PORT` if set, otherwise calculated from the LSF job ID. """
        return self._master_port

    def world_size(self) -> int:
        """ The world size is read from the environment variable `JSM_NAMESPACE_SIZE`. """
        return self._read_int_env("JSM_NAMESPACE_SIZE", "world size")

    def set_world_size(self, size: int) -> None:
        """ Ignored: the world size always comes from the environment. """
        log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        """ The global rank is read from the environment variable `JSM_NAMESPACE_RANK`. """
        return self._read_int_env("JSM_NAMESPACE_RANK", "global rank")

    def set_global_rank(self, rank: int) -> None:
        """ Ignored: the global rank always comes from the environment. """
        log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        """ The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`. """
        return self._read_int_env("JSM_NAMESPACE_LOCAL_RANK", "local rank")

    def node_rank(self) -> int:
        """
        The node rank is determined by the position of the current hostname in the list of hosts stored in
        the environment variable `LSB_HOSTS`.

        Entries whose name contains "batch" or "login" are skipped, and repeated hosts are counted once.

        Raises:
            ValueError: If `LSB_HOSTS` cannot be read, or if the current hostname is not one of the
                remaining hosts.
        """
        ranks = {}
        for host in self._read_hosts():
            if "batch" in host or "login" in host:
                continue
            # first occurrence of a host fixes its rank; later duplicates leave it untouched
            ranks.setdefault(host, len(ranks))

        hostname = socket.gethostname()
        if hostname not in ranks:
            known_hosts = ", ".join(ranks) or "none"
            raise ValueError(
                f"Cannot determine node rank: hostname {hostname!r} was not found among the"
                f" compute hosts listed in LSB_HOSTS ({known_hosts})."
            )
        return ranks[hostname]

    @staticmethod
    def _read_int_env(var: str, description: str) -> int:
        """
        Read the environment variable ``var`` and convert it to ``int``.

        Args:
            var: name of the environment variable to read
            description: human readable name of the quantity, used in the error message

        Raises:
            ValueError: If the variable is not set, or its value is not a valid integer.
        """
        value = os.environ.get(var)
        if value is None:
            raise ValueError(
                f"Cannot determine {description} from environment variable {var}."
                " Make sure you run your executable with `jsrun`"
            )
        return int(value)

    @staticmethod
    def _read_hosts():
        """
        Return the whitespace separated entries of `LSB_HOSTS` as a list.

        Raises:
            ValueError: If `LSB_HOSTS` is unset or empty, or contains fewer than two entries.
        """
        hosts = os.environ.get("LSB_HOSTS")
        if not hosts:
            raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
        hosts = hosts.split()
        if len(hosts) < 2:
            raise ValueError(
                "Cannot parse hosts from LSB_HOSTS environment variable."
                ' Expected format: "batch ..."'
            )
        return hosts

    def _get_master_address(self) -> str:
        """ Return the second entry of `LSB_HOSTS`, which is used as the master address. """
        # NOTE(review): this assumes exactly one leading launch ("batch") entry, whereas ``node_rank``
        # skips every entry containing "batch" or "login" - confirm both agree on real clusters.
        hosts = self._read_hosts()
        return hosts[1]

    @staticmethod
    def _get_master_port() -> int:
        """
        A helper function for accessing the master port.
        Uses the LSF job ID so all ranks can compute the master port.

        Raises:
            ValueError: If neither `MASTER_PORT` nor `LSB_JOBID` is set, or the value used is not an integer.
        """
        # check for user-specified master port
        port = os.environ.get("MASTER_PORT")
        if port:
            log.debug(f"using externally specified master port: {port}")
            return int(port)

        jobid = os.environ.get("LSB_JOBID")
        if not jobid:
            raise ValueError("Could not find job id in environment variable LSB_JOBID")
        # all ports should be in the 10k+ range
        port = int(jobid) % 1000 + 10000
        log.debug(f"calculated LSF master port: {port}")
        return port
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os from unittest import mock diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py index 8ebcec953fcc8..29917877b2cf5 100644 --- a/tests/plugins/environments/test_lightning_environment.py +++ b/tests/plugins/environments/test_lightning_environment.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from unittest import mock diff --git a/tests/plugins/environments/test_lsf_environment.py b/tests/plugins/environments/test_lsf_environment.py new file mode 100644 index 0000000000000..fd8beec7bb4ac --- /dev/null +++ b/tests/plugins/environments/test_lsf_environment.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# All tests patch ``os.environ`` with ``clear=True`` so that variables leaking in from the surrounding
# shell or CI job (e.g. MASTER_PORT, LSB_*, JSM_*) cannot change the outcome.


@mock.patch.dict(os.environ, {"LSB_JOBID": "1234"}, clear=True)
def test_missing_lsb_hosts():
    """ Test an error when the lsb hosts list cannot be found. """
    with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_HOSTS"):
        LSFEnvironment()


@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1"}, clear=True)
def test_missing_lsb_job_id():
    """ Test an error when the job id cannot be found. """
    # MASTER_PORT must be absent as well, otherwise the job id is never consulted
    with pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"):
        LSFEnvironment()


@mock.patch.dict(
    os.environ,
    {
        "MASTER_PORT": "4321",
        "LSB_JOBID": "1234",
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1",
    },
    clear=True,
)
def test_manual_master_port_and_address():
    """ Test a user can set the port manually through the MASTER_PORT env variable. """
    env = LSFEnvironment()
    assert env.master_port() == 4321
    # the address still comes from the host list
    assert env.master_address() == "10.10.10.0"


@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3",
        "LSB_JOBID": "1234",
        "JSM_NAMESPACE_SIZE": "4",
        "JSM_NAMESPACE_RANK": "3",
        "JSM_NAMESPACE_LOCAL_RANK": "1",
    },
    clear=True,
)
def test_attributes_from_environment_variables():
    """ Test that the LSF environment takes the attributes from the environment variables. """
    env = LSFEnvironment()
    assert env.creates_children()
    assert env.master_address() == "10.10.10.0"
    # 1234 % 1000 + 10000
    assert env.master_port() == 10234
    assert env.world_size() == 4
    assert env.global_rank() == 3
    assert env.local_rank() == 1
    # the setters are no-ops: the values keep coming from the environment
    env.set_global_rank(100)
    assert env.global_rank() == 3
    env.set_world_size(100)
    assert env.world_size() == 4
    assert LSFEnvironment.is_using_lsf()


@mock.patch("socket.gethostname", return_value="host2")
@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch host0 host1 host2 host3",
        "LSB_JOBID": "1234",
    },
    clear=True,
)
def test_node_rank(_):
    """ Test that the node rank is the position of the current host among the non-batch hosts. """
    env = LSFEnvironment()
    assert env.node_rank() == 2
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os from unittest import mock