From 35e8e0548e6404f027bb22241741978b7b73aaed Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 11:25:14 -0700 Subject: [PATCH 01/15] auto sync batchnorm --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 99d716f6b5a8c..20e15e731a7bc 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -426,6 +426,10 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes + if hasattr(training_type, 'sync_batchnorm') and getattr( + training_type, 'sync_batchnorm') is None: + training_type.sync_batchnorm = self.sync_batchnorm + return training_type def select_accelerator(self) -> Accelerator: From 4ce6906ea7e095a0479a039acd8d09e10571a1cb Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 11:25:42 -0700 Subject: [PATCH 02/15] change --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 20e15e731a7bc..351b0a83dcefd 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -426,6 +426,8 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes + # Automatically set sync_batchnorm if not already set. + # Useful for custom plugins. if hasattr(training_type, 'sync_batchnorm') and getattr( training_type, 'sync_batchnorm') is None: training_type.sync_batchnorm = self.sync_batchnorm From 6e29ef67937860606c02bbee16125410d004bb39 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 11:47:28 -0700 Subject: [PATCH 03/15] add test --- .../connectors/accelerator_connector.py | 5 ++--- tests/plugins/test_custom_plugin.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 tests/plugins/test_custom_plugin.py diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 351b0a83dcefd..438e45c9b586b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -426,10 +426,9 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes - # Automatically set sync_batchnorm if not already set. + # Automatically set sync_batchnorm. # Useful for custom plugins. 
- if hasattr(training_type, 'sync_batchnorm') and getattr( - training_type, 'sync_batchnorm') is None: + if hasattr(training_type, 'sync_batchnorm'): training_type.sync_batchnorm = self.sync_batchnorm return training_type diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py new file mode 100644 index 0000000000000..0bb8a1b3722ca --- /dev/null +++ b/tests/plugins/test_custom_plugin.py @@ -0,0 +1,22 @@ +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDPPlugin +from tests.helpers import BoringModel + + +class CustomParallelPlugin(DDPPlugin): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +def test_sync_batchnorm_set(tmpdir): + model = BoringModel() + plugin = CustomParallelPlugin() + assert plugin.sync_batchnorm == False + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + assert plugin.sync_batchnorm == True \ No newline at end of file From 65579b013983ce797646dbd660f2cf5c904dbfc2 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 11:48:23 -0700 Subject: [PATCH 04/15] new line --- tests/plugins/test_custom_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 0bb8a1b3722ca..28b078630aef3 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -19,4 +19,4 @@ def test_sync_batchnorm_set(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - assert plugin.sync_batchnorm == True \ No newline at end of file + assert plugin.sync_batchnorm == True From d1ca0a58fa6f1b6f0fd23a7aa0bfacb1ca2ce39f Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 11:49:34 -0700 Subject: [PATCH 05/15] formatting --- tests/plugins/test_custom_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 28b078630aef3..8a64b80e71c02 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -11,7 +11,7 @@ def __init__(self, **kwargs): def test_sync_batchnorm_set(tmpdir): model = BoringModel() plugin = CustomParallelPlugin() - assert plugin.sync_batchnorm == False + assert plugin.sync_batchnorm is False trainer = Trainer( max_epochs=1, plugins=[plugin], @@ -19,4 +19,4 @@ def test_sync_batchnorm_set(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - assert plugin.sync_batchnorm == True + assert plugin.sync_batchnorm is True From d6987164025736677df23e3e338b08f439166f81 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 12:21:16 -0700 Subject: [PATCH 06/15] Update tests/plugins/test_custom_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- tests/plugins/test_custom_plugin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 8a64b80e71c02..9b6561870ba2b 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -3,9 +3,7 @@ from tests.helpers import BoringModel -class CustomParallelPlugin(DDPPlugin): - def __init__(self, **kwargs): - super().__init__(**kwargs) +class CustomParallelPlugin(DDPPlugin): ... 
def test_sync_batchnorm_set(tmpdir): From 8f08efcc97b0613182a2fc0b837a32df39f9ca22 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 12:23:27 -0700 Subject: [PATCH 07/15] Update tests/plugins/test_custom_plugin.py --- tests/plugins/test_custom_plugin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 9b6561870ba2b..44f033c544ead 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -3,7 +3,8 @@ from tests.helpers import BoringModel -class CustomParallelPlugin(DDPPlugin): ... +class CustomParallelPlugin(DDPPlugin): + ... def test_sync_batchnorm_set(tmpdir): From 1c8cecf9c5b7a8d54634637c0c858055991713cb Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 15 Mar 2021 12:27:32 -0700 Subject: [PATCH 08/15] Update tests/plugins/test_custom_plugin.py --- tests/plugins/test_custom_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 44f033c544ead..9ed73e90f1438 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -3,7 +3,7 @@ from tests.helpers import BoringModel -class CustomParallelPlugin(DDPPlugin): +class CustomParallelPlugin(DDPPlugin): ... From 0f431734fceb22be77ccdc6354f4d63f51a4c309 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 16 Mar 2021 11:48:58 -0700 Subject: [PATCH 09/15] wip --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- tests/plugins/test_custom_plugin.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 438e45c9b586b..27a636e8228a5 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -426,9 +426,9 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes - # Automatically set sync_batchnorm. + # Automatically set sync_batchnorm if not None. # Useful for custom plugins. 
- if hasattr(training_type, 'sync_batchnorm'): + if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: training_type.sync_batchnorm = self.sync_batchnorm return training_type diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 8a64b80e71c02..b0ac7c06f0914 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -5,7 +5,7 @@ class CustomParallelPlugin(DDPPlugin): def __init__(self, **kwargs): - super().__init__(**kwargs) + super().__init__(sync_batchnorm=None, **kwargs) def test_sync_batchnorm_set(tmpdir): From 162a2b1d61708454c6a1b72d1584a4a3b226a3d5 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 16 Mar 2021 19:18:21 -0700 Subject: [PATCH 10/15] Update pytorch_lightning/trainer/connectors/accelerator_connector.py Co-authored-by: Roger Shieh --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 27a636e8228a5..67bb6653edc31 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -426,7 +426,7 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: training_type.num_nodes = self.num_nodes - # Automatically set sync_batchnorm if not None. + # Automatically set sync_batchnorm if None. # Useful for custom plugins. if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: training_type.sync_batchnorm = self.sync_batchnorm From 91deb34cebe89cfbdb0e558c60eca82acd85bd79 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Wed, 17 Mar 2021 16:29:24 -0700 Subject: [PATCH 11/15] skip test on gpu --- tests/plugins/test_custom_plugin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index b44dc08f67c3d..e555e3039d7f6 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -1,3 +1,5 @@ +import pytest +import torch from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel @@ -9,7 +11,7 @@ def __init__(self, **kwargs): # Set to None so it will be overwritten by the accelerator connector. self.sync_batchnorm = None - +@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") def test_sync_batchnorm_set(tmpdir): """Tests if sync_batchnorm is automatically set for custom plugin.""" model = BoringModel() From e2fafe54da1c51342a9bbd19a1a62d28db90a775 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Wed, 17 Mar 2021 16:31:02 -0700 Subject: [PATCH 12/15] formatting --- tests/plugins/test_custom_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index e555e3039d7f6..d967b897282a6 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -11,6 +11,7 @@ def __init__(self, **kwargs): # Set to None so it will be overwritten by the accelerator connector. 
self.sync_batchnorm = None + @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") def test_sync_batchnorm_set(tmpdir): """Tests if sync_batchnorm is automatically set for custom plugin.""" From e033f51984fb9c0be7e652dd99a541e8d20239f7 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 19 Mar 2021 12:41:55 +0530 Subject: [PATCH 13/15] skip test for Windows --- tests/plugins/test_custom_plugin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index d967b897282a6..b7e2b671bedb4 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -1,17 +1,21 @@ import pytest import torch + from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel +from tests.helpers.runif import RunIf class CustomParallelPlugin(DDPPlugin): + def __init__(self, **kwargs): super().__init__(**kwargs) # Set to None so it will be overwritten by the accelerator connector. self.sync_batchnorm = None +@RunIf(skip_windows=True) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") def test_sync_batchnorm_set(tmpdir): """Tests if sync_batchnorm is automatically set for custom plugin.""" From 84a1a1eb3a23e734dc9b1fb7997f56132998fb04 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 19 Mar 2021 09:08:37 -0700 Subject: [PATCH 14/15] Update tests/plugins/test_custom_plugin.py --- tests/plugins/test_custom_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index b7e2b671bedb4..3ead830ae8fd7 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -16,7 +16,6 @@ def __init__(self, **kwargs): @RunIf(skip_windows=True) -@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") def test_sync_batchnorm_set(tmpdir): """Tests if sync_batchnorm is automatically set for custom plugin.""" model = BoringModel() From 3ff62e1bffbc1742c084324f44535814392355e5 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sat, 20 Mar 2021 01:19:14 +0530 Subject: [PATCH 15/15] Remove unused imports --- tests/plugins/test_custom_plugin.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 3ead830ae8fd7..872b49ef48635 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -1,6 +1,16 @@ -import pytest -import torch - +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin from tests.helpers import BoringModel
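
A minimal usage sketch (not part of the patch series) of the behaviour these commits introduce: a custom training-type plugin that leaves sync_batchnorm as None has it filled in from the Trainer flag by the accelerator connector. The CustomParallelPlugin and BoringModel names below mirror the test added in this series; the rest is assumed boilerplate, not code from the patches.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin
    from tests.helpers import BoringModel


    class CustomParallelPlugin(DDPPlugin):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Left as None so the accelerator connector overwrites it
            # with the Trainer's sync_batchnorm flag.
            self.sync_batchnorm = None


    # resolve_training_type_plugin copies Trainer(sync_batchnorm=True)
    # onto the plugin, so no manual wiring is needed on the plugin side.
    plugin = CustomParallelPlugin()
    trainer = Trainer(max_epochs=1, plugins=[plugin], sync_batchnorm=True)
    trainer.fit(BoringModel())
    assert plugin.sync_batchnorm is True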