 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -238,15 +237,13 @@ def pre_configure_ddp(self): |
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
         # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible.
         self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
-        # todo: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
-        if (
-            _TORCH_GREATER_EQUAL_1_7
-            and not self.lightning_module.automatic_optimization
-            and not self._ddp_kwargs.get("find_unused_parameters", False)
+        if not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
+            "find_unused_parameters", False
         ):
+            # TODO: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
             rank_zero_warn(
-                "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` "
-                "to properly work with DDP."
+                "From PyTorch 1.7.0, Lightning `manual_optimization` needs to set `find_unused_parameters=True` to"
+                " properly work with DDP. Using `find_unused_parameters=True`."
             )
             self._ddp_kwargs["find_unused_parameters"] = True
 
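With the `_TORCH_GREATER_EQUAL_1_7` guard gone, manual optimization now unconditionally forces `find_unused_parameters=True` and emits the warning above. As a hedged, minimal sketch (not part of this commit) of how that kwarg normally reaches `self._ddp_kwargs`: this assumes the `DDPPlugin(**kwargs)` path, and the exact `Trainer` arguments vary between Lightning releases.

```python
# Minimal sketch (not part of this commit): DDPPlugin forwards its kwargs into
# the plugin's `self._ddp_kwargs`, which `pre_configure_ddp` then inspects.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# With automatic optimization, the unused-parameter scan can be disabled to
# avoid its per-iteration overhead.
trainer = Trainer(gpus=2, plugins=DDPPlugin(find_unused_parameters=False))

# With manual optimization, `pre_configure_ddp` overrides this back to True
# and emits the warning shown in the hunk above.
```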
@@ -323,7 +320,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: |
         obj = [obj]
         if self.global_rank != src:
             obj = [None]
-        broadcast_object_list(obj, src, group=_group.WORLD)
+        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
     def model_to_device(self):
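With the local `broadcast_object_list` shim removed, `broadcast` calls `torch.distributed.broadcast_object_list` directly, which exists from PyTorch 1.8 onward (consistent with dropping the 1.7 compatibility flag above). Below is a hedged, standalone sketch of that collective outside Lightning; the backend, launcher invocation, file name, and payload are illustrative assumptions.

```python
# Hedged sketch of torch.distributed.broadcast_object_list (PyTorch >= 1.8), the
# collective the patched `broadcast` now uses. Launch with a multi-process
# runner, e.g. `torchrun --nproc_per_node=2 demo.py` (file name is a placeholder).
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # backend choice is an assumption
rank = dist.get_rank()

# Every rank passes a list of the same length; the contents from `src` overwrite
# the entries on all other ranks, pickling/unpickling arbitrary Python objects.
payload = [{"epoch": 3, "score": 0.9}] if rank == 0 else [None]
dist.broadcast_object_list(payload, src=0)
assert payload[0] == {"epoch": 3, "score": 0.9}

dist.destroy_process_group()
```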