diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 63ac7f5105945..9712b5356091f 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -73,7 +73,6 @@ def _setup_models_and_optimizers(
 
         optimizers = self._wrap_optimizers(optimizers)
         model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
         return [model], optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 5d48c489a37e8..9503ffb951abb 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -64,7 +64,6 @@ def _setup_models_and_optimizers(
 
         optimizers = self._wrap_optimizers(optimizers)
         model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
         return [model], optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
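
The removed setattr forced `require_backward_grad_sync` to False on the wrapper; that attribute is what `torch.nn.parallel.DistributedDataParallel` toggles (via its `no_sync()` context manager) to skip gradient reduction, whereas fairscale's `ShardedDataParallel` exposes its own `no_sync()` context manager for the same purpose, so pinning the attribute on the sharded wrapper is not required. Below is a minimal sketch of that pattern, not code from this PR: it assumes fairscale is installed and uses a single-process gloo group, a toy `Linear` model, and an arbitrary port purely so it can run standalone.

import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process process group so the sketch runs standalone (gloo backend, CPU).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
sharded_model = ShardedDataParallel(model, sharded_optimizer=optimizer)

# Gradient accumulation step: skip the reduce through the wrapper's own
# context manager instead of setting `require_backward_grad_sync` by hand.
with sharded_model.no_sync():
    sharded_model(torch.randn(2, 4)).sum().backward()

# Synchronizing step: gradients are reduced, then the sharded optimizer steps.
sharded_model(torch.randn(2, 4)).sum().backward()
optimizer.step()

dist.destroy_process_group()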