
Commit 03bb389

ethanwharris and Borda authored
Fix double precision + ddp_spawn (#6924)
* Initial fix
* Initial fix
* Initial fix
* Updates
* Updates
* Update typing and docs
* Undo accidental refactor
* Remove unused imports
* Add DDP double precision test
* Remove unused variable
* Update CHANGELOG.md
* Fix test
* Update tests
* Formatting
* Revert bad change
* Add back changes
* Correct wrapping order
* Improve unwrapping
* Correct wrapping order
* Fix... finally
* Respond to comments
* Drop ddp test
* Simplify ddp spawn test
* Simplify ddp spawn test

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 195b24b commit 03bb389

File tree

CHANGELOG.md
pytorch_lightning/overrides/base.py
pytorch_lightning/plugins/precision/double.py
pytorch_lightning/plugins/training_type/training_type_plugin.py
tests/accelerators/test_ddp.py
tests/overrides/test_base.py
tests/plugins/test_double_plugin.py

7 files changed: +186 -59 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -179,6 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed formatting of info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
 
 
+- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
+
+
 ## [1.3.2] - 2021-05-18
 
 ### Changed
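
The underlying failure is a pickling one: `accelerator='ddp_spawn'` sends the model to the worker processes via `pickle`, and the previous `DoublePrecisionPlugin` patched `training_step`, `forward`, etc. with locally defined closures (the `_DoublePrecisionPatch` class removed further down), which `pickle` cannot serialize. The commit replaces the patching with a wrapper module. A minimal standalone sketch of the difference (toy code, not the Lightning implementation):

```python
import pickle

import torch
import torch.nn as nn


class ToyModule(nn.Module):
    def forward(self, x):
        return x * 2


def patch_forward(model: nn.Module) -> nn.Module:
    """Old-style approach: overwrite the bound method with a closure stored on the instance."""
    old_forward = model.forward

    def new_forward(*args, **kwargs):
        args = tuple(a.double() if isinstance(a, torch.Tensor) else a for a in args)
        return old_forward(*args, **kwargs)

    model.forward = new_forward
    return model


class DoubleWrapper(nn.Module):
    """New-style approach: a wrapper module whose state is plain, picklable attributes."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        args = tuple(a.double() if isinstance(a, torch.Tensor) else a for a in args)
        return self.module(*args, **kwargs)


try:
    pickle.dumps(patch_forward(ToyModule()))
except (AttributeError, pickle.PicklingError) as err:
    print("patched module cannot be pickled:", err)  # local closures are not picklable

pickle.dumps(DoubleWrapper(ToyModule()))  # fine: an ordinary nn.Module hierarchy
```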

pytorch_lightning/overrides/base.py

Lines changed: 46 additions & 8 deletions
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Union
+
 import torch
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
@@ -19,9 +21,44 @@
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 
 
-class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
+class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
 
     def __init__(self, pl_module: 'pl.LightningModule') -> None:
+        """
+        Wraps the user's LightningModule. Requires overriding all ``*_step`` methods and ``forward`` so that it can
+        safely be wrapped by a ``_LightningModuleWrapperBase`` and a ``*DataParallel``.
+
+        Args:
+            pl_module: the model to wrap
+        """
+        super().__init__()
+        self.module = pl_module
+
+        # set the parameters_to_ignore from LightningModule.
+        self._ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])
+
+    def training_step(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def test_step(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def on_post_move_to_device(self) -> None:
+        pass
+
+
+class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
+
+    def __init__(self, pl_module: Union['pl.LightningModule', _LightningPrecisionModuleWrapperBase]):
         """
         Wraps the user's LightningModule and redirects the forward call to the appropriate
         method, either ``training_step``, ``validation_step`` or ``test_step``.
@@ -39,8 +76,9 @@ def __init__(self, pl_module: 'pl.LightningModule') -> None:
         # set the parameters_to_ignore from LightningModule.
         self._ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])
 
-    def forward(self, *inputs, **kwargs):
-        trainer = self.module.trainer
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
+        lightning_module = unwrap_lightning_module(self.module)
+        trainer = lightning_module.trainer
 
         if trainer and trainer.training:
             output = self.module.training_step(*inputs, **kwargs)
@@ -49,7 +87,7 @@ def forward(self, *inputs, **kwargs):
             # it is done manually in ``LightningModule.manual_backward``
             # `require_backward_grad_sync` will be reset in the
            # ddp_plugin ``post_training_step`` hook
-            if not self.module.automatic_optimization:
+            if not lightning_module.automatic_optimization:
                 trainer.model.require_backward_grad_sync = False
         elif trainer and trainer.testing:
             output = self.module.test_step(*inputs, **kwargs)
@@ -62,14 +100,14 @@ def forward(self, *inputs, **kwargs):
 
         return output
 
-    def on_post_move_to_device(self):
+    def on_post_move_to_device(self) -> None:
         pass
 
 
 def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule':
     model = wrapped_model
     if isinstance(model, (DistributedDataParallel, DataParallel)):
-        model = model.module
-    if isinstance(model, _LightningModuleWrapperBase):
-        model = model.module
+        model = unwrap_lightning_module(model.module)
+    if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
+        model = unwrap_lightning_module(model.module)
     return model
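
With the recursive rewrite, `unwrap_lightning_module` can peel an arbitrary stack of wrappers, which is exactly the shape produced by `precision=64` plus DDP: `DistributedDataParallel(_LightningModuleWrapperBase(LightningDoublePrecisionModule(model)))`. A small sketch, assuming the classes from this commit are importable (the `DistributedDataParallel` layer is only mentioned in a comment because constructing it needs an initialized process group):

```python
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.plugins.precision.double import LightningDoublePrecisionModule
from tests.helpers.boring_model import BoringModel

model = BoringModel()
# Nesting as produced by DoublePrecisionPlugin.connect() followed by the DDP plugin;
# in a real run a DistributedDataParallel layer would sit on top of this stack as well.
wrapped = _LightningModuleWrapperBase(LightningDoublePrecisionModule(model))

# The recursion peels every wrapper layer and hands back the user's LightningModule.
assert unwrap_lightning_module(wrapped) is model
```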

pytorch_lightning/plugins/precision/double.py

Lines changed: 52 additions & 43 deletions
@@ -12,28 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from functools import wraps
 from typing import Any, Generator, List, Tuple
 
 import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
-class _DoublePrecisionPatch:
-    """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown."""
+class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
+    """
+    LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double
+    (``torch.float64``) precision.
 
-    def __init__(self, model: nn.Module, method_name: str, old_method: Any) -> None:
-        self.model = model
-        self.method_name = method_name
-        self.old_method = old_method
+    Args:
+        pl_module: the model to wrap
+    """
 
-    def teardown(self) -> None:
-        setattr(self.model, self.method_name, self.old_method)
+    def __init__(self, pl_module: LightningModule):
+        super().__init__(pl_module)
 
     @staticmethod
     def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
@@ -43,55 +44,63 @@ def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
 
     @staticmethod
     def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision)
-
-    @classmethod
-    def patch(cls, model: nn.Module, method_name: str) -> '_DoublePrecisionPatch':
-        old_method = getattr(model, method_name)
-
-        @wraps(old_method)
-        def new_method(*args: Any, **kwargs: Any) -> Any:
-            return old_method(
-                *_DoublePrecisionPatch._move_float_tensors_to_double(args),
-                **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs)
-            )
-
-        setattr(model, method_name, new_method if callable(old_method) else old_method)
-        return cls(model, method_name, old_method)
+        return apply_to_collection(
+            collection,
+            torch.Tensor,
+            LightningDoublePrecisionModule._to_double_precision,
+        )
+
+    def training_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.training_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.validation_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def test_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.test_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module.predict_step(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        return self.module(
+            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
+            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
        )
 
 
 class DoublePrecisionPlugin(PrecisionPlugin):
-    """Plugin for training with double (``torch.float64``) precision."""
+    """ Plugin for training with double (``torch.float64``) precision. """
 
     precision: int = 64
 
-    def __init__(self) -> None:
-        super().__init__()
-        self.patches: List[_DoublePrecisionPatch] = []
-
     def connect(
         self,
         model: nn.Module,
         optimizers: List[Optimizer],
         lr_schedulers: List[Any],
-    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`,
-        `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter
-        `optimizers` or `lr_schedulers`."""
+    ) -> Tuple[nn.Module, List['Optimizer'], List[Any]]:
+        """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert
+        incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
+        `lr_schedulers`.
+        """
         model = model.to(dtype=torch.float64)
-        if isinstance(model, LightningModule):
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step'))
-            self.patches.append(_DoublePrecisionPatch.patch(model, 'forward'))
+        model = LightningDoublePrecisionModule(model)
 
         return super().connect(model, optimizers, lr_schedulers)
 
-    def post_dispatch(self) -> None:
-        while len(self.patches) > 0:
-            self.patches.pop().teardown()
-
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 4 additions & 4 deletions
@@ -161,19 +161,19 @@ def start_predicting(self, trainer: 'pl.Trainer') -> None:
         self._results = trainer.run_stage()
 
     def training_step(self, *args, **kwargs):
-        return self.lightning_module.training_step(*args, **kwargs)
+        return self.model.training_step(*args, **kwargs)
 
     def post_training_step(self):
         pass
 
     def validation_step(self, *args, **kwargs):
-        return self.lightning_module.validation_step(*args, **kwargs)
+        return self.model.validation_step(*args, **kwargs)
 
     def test_step(self, *args, **kwargs):
-        return self.lightning_module.test_step(*args, **kwargs)
+        return self.model.test_step(*args, **kwargs)
 
     def predict_step(self, *args, **kwargs):
-        return self.lightning_module.predict_step(*args, **kwargs)
+        return self.model.predict_step(*args, **kwargs)
 
     def training_step_end(self, output):
         return output
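
The switch from `self.lightning_module` to `self.model` matters because `lightning_module` unwraps down to the raw user module, which would skip the precision wrapper's `*_step` conversions; dispatching through `self.model` keeps every wrapper in the call path. A simplified, hypothetical illustration (toy classes, not the actual plugin code):

```python
import torch


class UserModule(torch.nn.Module):
    def training_step(self, batch):
        # Under precision=64 the user's code legitimately expects float64 inputs.
        assert batch.dtype == torch.float64, "inputs were not converted"
        return batch.sum()


class DoubleWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def training_step(self, batch):
        # The wrapper is the only place where the float64 conversion happens.
        return self.module.training_step(batch.double())


wrapped = DoubleWrapper(UserModule())

wrapped.training_step(torch.rand(2))           # ok: the wrapper converts the batch
# wrapped.module.training_step(torch.rand(2))  # would fail the assert: calling the
#                                              # unwrapped module bypasses the conversion
```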

tests/accelerators/test_ddp.py

Lines changed: 10 additions & 2 deletions
@@ -123,7 +123,8 @@ def setup(self, stage: Optional[str] = None) -> None:
 
 
 @RunIf(min_gpus=2, min_torch="1.8.1", special=True)
-def test_ddp_wrapper(tmpdir):
+@pytest.mark.parametrize("precision", [16, 32])
+def test_ddp_wrapper(tmpdir, precision):
     """
     Test parameters to ignore are carried over for DDP.
     """
@@ -150,5 +151,12 @@ def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule')
             assert trainer.training_type_plugin.model.module._ddp_params_and_buffers_to_ignore == ('something')
 
     model = CustomModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="ddp", gpus=2, callbacks=CustomCallback())
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        precision=precision,
+        accelerator="ddp",
+        gpus=2,
+        callbacks=CustomCallback(),
+    )
     trainer.fit(model)

tests/overrides/test_base.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+from torch.nn import DataParallel
+
+from pytorch_lightning.overrides.base import (
+    _LightningModuleWrapperBase,
+    _LightningPrecisionModuleWrapperBase,
+    unwrap_lightning_module,
+)
+from tests.helpers import BoringModel
+
+
+@pytest.mark.parametrize("wrapper_class", [
+    _LightningModuleWrapperBase,
+    _LightningPrecisionModuleWrapperBase,
+])
+def test_wrapper_device_dtype(wrapper_class):
+    model = BoringModel()
+    wrapped_model = wrapper_class(model)
+
+    wrapped_model.to(dtype=torch.float16)
+    assert model.dtype == torch.float16
+
+
+def test_unwrap_lightning_module():
+    model = BoringModel()
+    wrapped_model = _LightningPrecisionModuleWrapperBase(model)
+    wrapped_model = _LightningModuleWrapperBase(wrapped_model)
+    wrapped_model = DataParallel(wrapped_model)
+
+    assert unwrap_lightning_module(wrapped_model) == model

tests/plugins/test_double_plugin.py

Lines changed: 27 additions & 2 deletions
@@ -11,12 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pickle
+from unittest.mock import MagicMock
+
 import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DoublePrecisionPlugin
 from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
 class RandomFloatIntDataset(Dataset):
@@ -121,7 +126,6 @@ def predict_dataloader(self):
 @pytest.mark.parametrize('boring_model', (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward))
 def test_double_precision(tmpdir, boring_model):
     model = boring_model()
-    original_training_step = model.training_step
 
     trainer = Trainer(
         max_epochs=2,
@@ -134,4 +138,25 @@ def test_double_precision(tmpdir, boring_model):
     trainer.test(model)
     trainer.predict(model)
 
-    assert model.training_step == original_training_step
+
+@RunIf(min_gpus=2)
+def test_double_precision_ddp(tmpdir):
+    model = DoublePrecisionBoringModel()
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        accelerator='ddp_spawn',
+        gpus=2,
+        fast_dev_run=2,
+        precision=64,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+
+def test_double_precision_pickle(tmpdir):
+    model = BoringModel()
+    plugin = DoublePrecisionPlugin()
+    model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
+    pickle.dumps(model)

0 commit comments
