
Commit 18a3578

amogkam committed
Automatically set sync_batchnorm for training_type_plugin (#6536)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: Kaushik Bokka <[email protected]>
1 parent c4f262f commit 18a3578

File tree: 2 files changed, +46 −0 lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 0 deletions
@@ -425,6 +425,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
         if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
             training_type.num_nodes = self.num_nodes
 
+        # Automatically set sync_batchnorm if None.
+        # Useful for custom plugins.
+        if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
+            training_type.sync_batchnorm = self.sync_batchnorm
+
         return training_type
 
     def select_accelerator(self) -> Accelerator:
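
For context, here is a minimal usage sketch of the behaviour this hunk enables, mirroring the new test below. The plugin name MyClusterPlugin is illustrative only; everything else uses the APIs already shown in this commit. The same pattern already existed for num_nodes; this change extends it to sync_batchnorm.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin


class MyClusterPlugin(DDPPlugin):
    """Illustrative custom plugin; the name is made up for this sketch."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Leave the flag unresolved; the accelerator connector copies
        # Trainer(sync_batchnorm=...) onto the plugin when this is None.
        self.sync_batchnorm = None


plugin = MyClusterPlugin()
trainer = Trainer(plugins=[plugin], sync_batchnorm=True)
# Once the connector has resolved the plugin (and certainly after
# trainer.fit(...)), plugin.sync_batchnorm is True, which is exactly
# what the new test below asserts.
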
New file · Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
+class CustomParallelPlugin(DDPPlugin):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Set to None so it will be overwritten by the accelerator connector.
+        self.sync_batchnorm = None
+
+
+@RunIf(skip_windows=True)
+def test_sync_batchnorm_set(tmpdir):
+    """Tests if sync_batchnorm is automatically set for custom plugin."""
+    model = BoringModel()
+    plugin = CustomParallelPlugin()
+    assert plugin.sync_batchnorm is None
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+    trainer.fit(model)
+    assert plugin.sync_batchnorm is True
