Commit f44ded5

Merge b79156e into 49c579f
2 parents 49c579f + b79156e

4 files changed: +32 −4 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
 
 
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 8 additions & 2 deletions
@@ -19,6 +19,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
 class DataParallelPlugin(ParallelPlugin):
@@ -46,8 +47,13 @@ def reduce(self, tensor, *args, **kwargs):
         if isinstance(tensor, Result):
             tensor.dp_reduce()
 
-        elif isinstance(tensor, torch.Tensor):
-            tensor = tensor.mean()
+        else:
+
+            def _reduce(tensor: torch.Tensor):
+                dtype_tensor = tensor.dtype
+                return tensor.float().mean().type(dtype_tensor)
+
+            tensor = apply_to_collection(tensor, torch.Tensor, _reduce)
 
         return tensor
tests/accelerators/test_dp.py

Lines changed: 19 additions & 0 deletions
@@ -123,3 +123,22 @@ def test_dp_test(tmpdir):
     new_weights = model.layer_0.weight.clone().detach().cpu()
 
     assert torch.all(torch.eq(old_weights, new_weights))
+
+
+@RunIf(min_gpus=2)
+def test_dp_training_step_dict(tmpdir):
+    """
+    This test verifies that DP properly reduces dictionaries
+    """
+
+    model = BoringModel()
+    model.training_step_end = None
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        gpus=2,
+        accelerator='dp',
+    )
+    trainer.fit(model)
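
As a note on what the test exercises: with training_step_end set to None, DP itself must reduce the dictionary that training_step returns on each replica. A hedged sketch of such a model follows; it is illustrative only, not BoringModel's actual internals, which live in tests/helpers.

import torch

import pytorch_lightning as pl


class DictOutputModel(pl.LightningModule):
    """Illustrative model whose training_step returns a dict, like BoringModel."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Under accelerator='dp' each replica returns its own dict; with no
        # training_step_end, DataParallelPlugin.reduce must average the
        # tensors inside it -- the collection case fixed by this commit.
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)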

tests/callbacks/test_pruning.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import OrderedDict
 from logging import INFO
 
@@ -22,7 +21,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     seed_everything(0)
 
     class TestPruning(ModelPruning):
+
         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             super().on_save_checkpoint(trainer, pl_module, checkpoint)
             assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
