|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | from contextlib import contextmanager |
15 | | -from typing import Dict, Generator, List, Optional, Tuple, Union |
| 15 | +from typing import Dict, Generator, Optional |
16 | 16 |
17 | 17 | import torch |
18 | | -from torch.nn import Module |
19 | | -from torch.optim import Optimizer |
20 | 18 |
21 | 19 | import pytorch_lightning as pl |
22 | 20 | from pytorch_lightning.core.optimizer import LightningOptimizer |
|
35 | 33 | class DDPShardedPlugin(DDPPlugin): |
36 | 34 | """Optimizer and gradient sharded training provided by FairScale.""" |
37 | 35 |
38 | | - _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M |
39 | | - |
40 | | - def __init__(self, *args, **kwargs): |
41 | | - super().__init__(*args, **kwargs) |
42 | | - self._precision = None |
| 36 | + _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M |
43 | 37 |
44 | 38 | def configure_ddp(self) -> None: |
45 | | - trainer = self.lightning_module.trainer |
| 39 | + self._wrap_optimizers() |
| 40 | + |
46 | 41 | if "reduce_buffer_size" not in self._ddp_kwargs: |
47 | 42 | # For multi-node training, enabling bucketing will improve performance. |
48 | 43 | self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 |
49 | 44 |
50 | | - [self._model], optimizers = self._setup_models_and_optimizers( |
51 | | - models=[LightningShardedDataParallel(self.model)], |
52 | | - optimizers=trainer.optimizers, |
| 45 | + self._model = ShardedDataParallel( |
| 46 | + LightningShardedDataParallel(self.model), |
| 47 | + sharded_optimizer=self.lightning_module.trainer.optimizers, |
| 48 | + **self._ddp_kwargs |
53 | 49 | ) |
54 | | - trainer.optimizers = optimizers |
55 | | - trainer.convert_to_lightning_optimizers() |
56 | | - |
57 | | - def _setup_models_and_optimizers( |
58 | | - self, models: List[Module], optimizers: List[Optimizer] |
59 | | - ) -> Tuple[List[Module], List[Optimizer]]: |
60 | | - """Wraps the model and optimizers with fairscale components. |
| 50 | + setattr(self._model, "require_backward_grad_sync", False) |
61 | 51 |
62 | | - Currently only one model can be setup at once. |
63 | | - |
64 | | - Return: |
65 | | - A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module |
66 | | - and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`. |
67 | | - """ |
68 | | - if len(models) > 1: |
69 | | - raise ValueError( |
70 | | - "DDPSharded only supports setting up a single model with one or several optimizers." |
71 | | - f" Got {len(models)} models." |
72 | | - ) |
73 | | - |
74 | | - optimizers = self._wrap_optimizers(optimizers) |
75 | | - model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs) |
76 | | - setattr(model, "require_backward_grad_sync", False) # TODO: needed? |
77 | | - return [model], optimizers |
78 | | - |
79 | | - def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]: |
| 52 | + def _reinit_optimizers_with_oss(self): |
| 53 | + optimizers = self.lightning_module.trainer.optimizers |
80 | 54 | for x, optimizer in enumerate(optimizers): |
81 | 55 | if isinstance(optimizer, LightningOptimizer): |
82 | 56 | optimizer = optimizer._optimizer |
83 | 57 | if not isinstance(optimizer, OSS): |
84 | 58 | optim_class = type(optimizer) |
85 | 59 | zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) |
86 | 60 | if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: |
87 | | - precision = self._precision or self.lightning_module.trainer.precision |
| 61 | + precision = self.lightning_module.trainer.precision |
88 | 62 | is_fp16 = precision in ("mixed", 16) |
89 | 63 | # For multi-node training, compressing the model shards in fp16 before broadcasting |
90 | 64 | # improves performance. When using PyTorch AMP, it will not degrade |
91 | 65 | # the model performance. |
92 | 66 | zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1 |
93 | 67 | optimizers[x] = zero_optimizer |
94 | 68 | del optimizer |
95 | | - return optimizers |
96 | | - |
97 | | - def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: |
98 | | - if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: |
99 | | - return optimizers |
| 69 | + trainer = self.lightning_module.trainer |
| 70 | + trainer.optimizers = optimizers |
| 71 | + trainer.convert_to_lightning_optimizers() |
100 | 72 |
101 | | - return self._reinit_optimizers_with_oss(optimizers) |
| 73 | + def _wrap_optimizers(self): |
| 74 | + if self.model.trainer.state.fn != TrainerFn.FITTING: |
| 75 | + return |
| 76 | + self._reinit_optimizers_with_oss() |
102 | 77 |
103 | 78 | def optimizer_state(self, optimizer: "OSS") -> Optional[dict]: |
104 | 79 | if isinstance(optimizer, LightningOptimizer): |
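For context on what the new `configure_ddp` / `_reinit_optimizers_with_oss` path is doing, here is a minimal standalone sketch (not part of this diff) that mirrors the same FairScale calls: a plain optimizer is re-created as a `fairscale.optim.OSS` instance from its own class, param groups, and defaults, and the module is then wrapped in `fairscale.nn.data_parallel.ShardedDataParallel` together with that sharded optimizer. The toy `nn.Linear` model, the Adam hyperparameters, and the single-process `gloo` group are illustrative assumptions, not anything taken from the plugin.

```python
import torch
from torch import nn

# Sketch of the wrapping done in configure_ddp() above. Assumes fairscale is
# installed; the toy model, optimizer settings, and single-process group are
# illustrative choices, not plugin defaults.
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def main() -> None:
    # OSS and ShardedDataParallel need an initialized process group; a
    # 1-process "gloo" group lets the sketch run without a distributed launcher.
    torch.distributed.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    model = nn.Linear(32, 2)
    base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Same re-creation as _reinit_optimizers_with_oss: keep the optimizer
    # class, its param groups, and its defaults, but shard the state via OSS.
    zero_optimizer = OSS(
        params=base_optimizer.param_groups,
        optim=type(base_optimizer),
        **base_optimizer.defaults,
    )

    # Same wrapping as configure_ddp: the module plus its sharded optimizer(s).
    sharded_model = ShardedDataParallel(model, sharded_optimizer=zero_optimizer)

    # One forward/backward/step to show the wrapped pieces working together.
    loss = sharded_model(torch.randn(4, 32)).sum()
    loss.backward()
    zero_optimizer.step()

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```

With more than one rank, OSS keeps only a shard of the optimizer state on each process, which is where the memory saving over plain DDP comes from; the `broadcast_fp16` flag set in the diff additionally compresses shard broadcasts, which the in-code comment notes helps multi-node performance.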