diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7211081bb374..277fee3463e22 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a `teardown` hook to `ClusterEnvironment` ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942))
+
+
 - Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
 
 
@@ -196,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed incorrect removal of the `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942))
+
+
 - Set better defaults for `rank_zero_only.rank` when training is launched with SLURM and torchelastic:
     * Support SLURM and torchelastic global rank environment variables ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))
@@ -243,7 +249,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
-- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917)) 
+- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
 
 - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py
index 9728fba932874..ed6172ae663ce 100644
--- a/pytorch_lightning/plugins/environments/cluster_environment.py
+++ b/pytorch_lightning/plugins/environments/cluster_environment.py
@@ -52,3 +52,7 @@ def local_rank(self) -> int:
     @abstractmethod
     def node_rank(self) -> int:
         """ The rank (index) of the node on which the current process runs. """
+
+    def teardown(self) -> None:
+        """ Clean up any state set after execution finishes. """
+        pass
diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py
index 67752535fe4e1..25da0cfb691e8 100644
--- a/pytorch_lightning/plugins/environments/lightning_environment.py
+++ b/pytorch_lightning/plugins/environments/lightning_environment.py
@@ -68,6 +68,10 @@ def node_rank(self) -> int:
         group_rank = os.environ.get("GROUP_RANK", 0)
         return int(os.environ.get("NODE_RANK", group_rank))
 
+    def teardown(self) -> None:
+        if "WORLD_SIZE" in os.environ:
+            del os.environ["WORLD_SIZE"]
+
 
 def find_free_network_port() -> int:
     """
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 7e9624d9a0122..977145a4cc7ba 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -280,9 +280,8 @@ def pre_dispatch(self):
 
         self.barrier()
 
-    def post_dispatch(self):
-        if "WORLD_SIZE" in os.environ:
-            del os.environ["WORLD_SIZE"]
+    def post_dispatch(self) -> None:
+        self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs):
         if torch_distrib.is_initialized():
diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py
index 3f89b88bfc215..8ebcec953fcc8 100644
--- a/tests/plugins/environments/test_lightning_environment.py
+++ b/tests/plugins/environments/test_lightning_environment.py
@@ -55,3 +55,14 @@ def test_random_master_port():
     assert isinstance(port, int)
     # repeated calls do not generate a new port number
     assert env.master_port() == port
+
+
+@mock.patch.dict(os.environ, {
+    "WORLD_SIZE": "1",
+})
+def test_teardown():
+    """ Test that teardown() removes the WORLD_SIZE environment variable. """
+    env = LightningEnvironment()
+    assert "WORLD_SIZE" in os.environ
+    env.teardown()
+    assert "WORLD_SIZE" not in os.environ