
Commit 692f77b

awaelchli, github-actions[bot], and Borda authored
Refactor LightningDataParallel (#5670)
* module
* fix model access
* scalar conversion
* refactor
* kwargs
* auto unsqueeze
* refactor code duplication
* clean up
* docs
* update dp docs
* changelog
* generalize test
* test
* rename
* warning cache
* isort
* unsqueezing test
* device
* device
* scalar test
* device
* device
* include coverage of overrides
* clear
* add deprecation test
* docs
* improve coverage
* increase coverage
* fix merge
* extend test
* rename base class
* mention the predict method in docs
* combine iteration over collection
* remove override
* move
* line
* Apply suggestions from code review
* fix running stage
* f401
* fix cyclic import

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 5d239cc commit 692f77b

File tree

12 files changed (+349, -349 lines)


CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645))


+- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
+
+
+- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670))
+
+
 ### Removed

 - Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
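For callers, the deprecation notes above mean that code which used to build the Lightning DP wrapper directly now composes the stock torch.nn.DataParallel with the new wrapper module, exactly as the accelerator change below does. A minimal migration sketch; the ToyModel class and the device ids [0, 1] are illustrative assumptions, not part of the commit:

import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


class ToyModel(LightningModule):
    """Illustrative stand-in for any user LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()


model = ToyModel()

# Before (now deprecated):
#   dp_model = LightningDataParallel(model, device_ids=[0, 1])
# After: wrap the LightningModule first, then hand it to the stock torch.nn.DataParallel,
# mirroring the accelerator change in the diff below.
dp_model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=[0, 1])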

pytorch_lightning/accelerators/legacy/dp_accelerator.py

Lines changed: 6 additions & 4 deletions

@@ -21,7 +21,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed import LightningDistributed
-from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -74,7 +74,7 @@ def __init_torch_data_parallel(self, model):

         # set dp device
         torch.cuda.set_device(self.trainer.root_gpu)
-        model = LightningDataParallel(model, device_ids=device_ids)
+        model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
         return model

     def __init_half_precision(self, model):
@@ -181,8 +181,10 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
             scheduler.load_state_dict(state)

     def get_reference_model(self, model) -> LightningModule:
-        if isinstance(model, LightningDataParallel):
-            return model.module
+        if isinstance(model, torch.nn.DataParallel):
+            model = model.module
+        if isinstance(model, LightningParallelModule):
+            model = model.module
         return model

     @property
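Because the accelerator now receives a doubly wrapped model, get_reference_model peels the layers off in the order they were applied: first torch.nn.DataParallel, then the Lightning wrapper. The same logic as a standalone sketch; the helper name unwrap_lightning_module is a hypothetical name used only for illustration:

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


def unwrap_lightning_module(model) -> LightningModule:
    # Hypothetical helper mirroring get_reference_model above:
    # torch.nn.DataParallel keeps the wrapped module on `.module`, and
    # LightningParallelModule keeps the original LightningModule on `.module`.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    if isinstance(model, LightningParallelModule):
        model = model.module
    return model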
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule  # noqa: F401
+from pytorch_lightning.overrides.distributed import LightningDistributedModule  # noqa: F401
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
+
+
+class _LightningModuleWrapperBase(torch.nn.Module):
+
+    def __init__(self, pl_module: LightningModule):
+        """
+        Wraps the user's LightningModule and redirects the forward call to the appropriate
+        method, either ``training_step``, ``validation_step`` or ``test_step``.
+        If the LightningModule is in none of the states `training`, `testing` or `validation`,
+        the inputs will be redirected to the
+        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
+        Inheriting classes may also modify the inputs or outputs of forward.
+
+        Args:
+            pl_module: the model to wrap
+        """
+        super().__init__()
+        self.module = pl_module
+
+    def forward(self, *inputs, **kwargs):
+        running_stage = self.module.running_stage
+
+        if running_stage == RunningStage.TRAINING:
+            output = self.module.training_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "training_step")
+        elif running_stage == RunningStage.TESTING:
+            output = self.module.test_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "test_step")
+        elif running_stage == RunningStage.EVALUATING:
+            output = self.module.validation_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "validation_step")
+        else:
+            output = self.module.predict(*inputs, **kwargs)
+
+        return output
+
+
+def warn_if_output_is_none(output: Any, method_name: str) -> None:
+    """ Warns user about which method returned None. """
+    if output is None:
+        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
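The docstring above notes that "inheriting classes may also modify the inputs or outputs of forward", which is the extension point the new wrapper modules build on; the commit bullets ("auto unsqueeze", "scalar conversion") suggest LightningParallelModule does something along these lines for DataParallel gathering. A hypothetical subclass sketch, not the commit's actual implementation; the import path pytorch_lightning.overrides.base is an assumption, since this capture does not show the new file's location:

import torch

# Assumed path for the new base module shown above (not visible in this capture).
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class UnsqueezeScalarsWrapper(_LightningModuleWrapperBase):
    """Hypothetical subclass: post-processes outputs, as the base docstring allows."""

    def forward(self, *inputs, **kwargs):
        # Dispatch to training_step/validation_step/test_step/predict via the base class.
        output = super().forward(*inputs, **kwargs)
        # Example output tweak: turn 0-dim tensors into 1-dim so they can be gathered.
        if isinstance(output, torch.Tensor) and output.dim() == 0:
            output = output.unsqueeze(0)
        return output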
