@@ -32,7 +32,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides import LightningDistributedModule, _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.overrides import _LightningPrecisionModuleWrapperBase, LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -271,7 +271,6 @@ def _enable_model_averaging(self) -> None:
                 )
 
         assert self._ddp_comm_state is not None
-        # assert isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
         self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
             period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
         )
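
Context on this hunk: `PeriodicModelAverager` comes from PyTorch's post-local SGD utilities, and `start_localSGD_iter` is read off the communication-hook state so the averager's warm-up matches the iteration at which the hook switches from global all-reduce to local gradient averaging. A minimal standalone sketch of the same pattern, assuming an initialized process group and a DDP-wrapped `model` (the `model`/`dataloader`/`optimizer` names and the hyperparameter values are illustrative, not Lightning's defaults):

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

# Assumes dist.init_process_group(...) has already run and `model` is a
# torch.nn.parallel.DistributedDataParallel instance (hypothetical here).
state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
model.register_comm_hook(state, post_localSGD_hook)

# Match the averager's warm-up to the hook state, as the diff does.
averager = PeriodicModelAverager(period=4, warmup_steps=state.start_localSGD_iter)

for batch in dataloader:  # hypothetical loop: dataloader/optimizer not shown
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # After warm-up, synchronize parameters across ranks every `period` steps.
    averager.average_parameters(model.parameters())
```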
@@ -341,7 +340,9 @@ def model_to_device(self) -> None:
         assert self.model is not None
         self.model.to(self.root_device)
 
-    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
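
The signature change in this hunk is purely a reflow for line length; behavior is unchanged: with the default `reduce_op="mean"`, the tensor is aggregated across processes and averaged. A rough sketch of that semantics in raw `torch.distributed` (Lightning itself delegates to an internal sync helper rather than this exact code):

```python
import torch
import torch.distributed as dist

def mean_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """Illustrative equivalent of reduce_op="mean": sum across ranks, then average.

    Assumes the default process group has been initialized.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= dist.get_world_size()
    return tensor
```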
|