From 6f6502b804ec74d8378b4b0e2598f319965cd792 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 9 Apr 2021 21:28:48 -0700 Subject: [PATCH 1/7] Update ddp.py --- pytorch_lightning/plugins/training_type/ddp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7e9624d9a0122..a03daac427041 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -280,7 +280,12 @@ def pre_dispatch(self): self.barrier() - def post_dispatch(self): + def post_dispatch(self) -> None: + # If we've spawned processes within the trainer, remove the populated environment variables + # RFC: should we use environment variables specific to lightning spawning? world size is also used by torchelastic + # why doesn't this happen in teardown? + if self.cluster_environment.creates_children: + return if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] From f61aedadd520cec1366edbca1e29e3606221d4f6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 9 Apr 2021 21:35:31 -0700 Subject: [PATCH 2/7] Update ddp.py --- pytorch_lightning/plugins/training_type/ddp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a03daac427041..30cfac1a08ea3 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -281,9 +281,7 @@ def pre_dispatch(self): self.barrier() def post_dispatch(self) -> None: - # If we've spawned processes within the trainer, remove the populated environment variables - # RFC: should we use environment variables specific to lightning spawning? world size is also used by torchelastic - # why doesn't this happen in teardown? 
+ # If the plugin launched subprocesses, clean up the populated environment variable(s) if self.cluster_environment.creates_children: return if "WORLD_SIZE" in os.environ: From 413bdb91f498b1306e39bc3e43b367da7d88556a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 07:55:27 -0700 Subject: [PATCH 3/7] teardown-env --- .../plugins/environments/cluster_environment.py | 4 ++++ .../plugins/environments/lightning_environment.py | 4 ++++ pytorch_lightning/plugins/training_type/ddp.py | 6 +----- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index 9728fba932874..e59ba2037eed1 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -52,3 +52,7 @@ def local_rank(self) -> int: @abstractmethod def node_rank(self) -> int: """ The rank (index) of the node on which the current process runs. """ + + def teardown(self) -> None: + """ Clean up any environment variables after execution finishes. 
""" + pass diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py index 67752535fe4e1..25da0cfb691e8 100644 --- a/pytorch_lightning/plugins/environments/lightning_environment.py +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -68,6 +68,10 @@ def node_rank(self) -> int: group_rank = os.environ.get("GROUP_RANK", 0) return int(os.environ.get("NODE_RANK", group_rank)) + def teardown(self) -> None: + if "WORLD_SIZE" in os.environ: + del os.environ["WORLD_SIZE"] + def find_free_network_port() -> int: """ diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 30cfac1a08ea3..977145a4cc7ba 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -281,11 +281,7 @@ def pre_dispatch(self): self.barrier() def post_dispatch(self) -> None: - # If the plugin launched subprocesses, clean up the populated environment variable(s) - if self.cluster_environment.creates_children: - return - if "WORLD_SIZE" in os.environ: - del os.environ["WORLD_SIZE"] + self.cluster_environment.teardown() def barrier(self, *args, **kwargs): if torch_distrib.is_initialized(): From 181aba506f6a551b2d908dc4a7e0b806a8e392af Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 08:00:46 -0700 Subject: [PATCH 4/7] Update cluster_environment.py --- pytorch_lightning/plugins/environments/cluster_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index e59ba2037eed1..ed6172ae663ce 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -54,5 +54,5 @@ def node_rank(self) -> int: """ The rank (index) of the node on which the current 
process runs. """ def teardown(self) -> None: - """ Clean up any environment variables after execution finishes. """ + """ Clean up any state set after execution finishes. """ pass From ef66fa30e8aab035b81a9ca9f157d656752800d9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 08:32:08 -0700 Subject: [PATCH 5/7] Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7211081bb374..ac783f90aa355 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [1.3.0] - 2021-MM-DD ### Added +- Added a `teardown` hook to `ClusterEnvironment` ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) + - Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) @@ -195,6 +197,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed +- Fixed incorrect removal of `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) + - Set better defaults for `rank_zero_only.rank` when training is launched with SLURM and torchelastic: * Support SLURM and torchelastic global rank environment variables ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715)) From bf4eb9c6c5b5e399e4dbdd7f9ec3a0a2b8f89e77 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 08:37:01 -0700 Subject: [PATCH 6/7] Update test_lightning_environment.py --- .../environments/test_lightning_environment.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py index 3f89b88bfc215..8ebcec953fcc8 100644 --- a/tests/plugins/environments/test_lightning_environment.py +++ b/tests/plugins/environments/test_lightning_environment.py @@ -55,3 +55,14 @@ def test_random_master_port(): assert isinstance(port, int) # repeated calls do not generate a new port number assert env.master_port() == port + + +@mock.patch.dict(os.environ, { + "WORLD_SIZE": "1", +}) +def test_teardown(): + """ Test that teardown removes the WORLD_SIZE environment variable. """ + env = LightningEnvironment() + assert "WORLD_SIZE" in os.environ + env.teardown() + assert "WORLD_SIZE" not in os.environ From 3139f05f6b28b6bad1adce95440af07a91128844 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 08:39:36 -0700 Subject: [PATCH 7/7] Update CHANGELOG.md --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac783f90aa355..277fee3463e22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [1.3.0] - 2021-MM-DD ### Added + - Added a `teardown` hook to `ClusterEnvironment` ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) @@ -197,6 +198,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed + - Fixed incorrect removal of `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) @@ -247,7 +249,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898)) -- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917)) +- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917)) - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))