Merged
Changes from all commits
31 commits
e697ca6
Create optimizers in DDP strategy after moving model to device
ananthsub Feb 17, 2022
a6cf5db
Update CHANGELOG.md
ananthsub Feb 17, 2022
5d362fd
Update sharded.py
ananthsub Mar 5, 2022
f2b6f81
update
ananthsub Mar 5, 2022
f3f39af
ordering
ananthsub Mar 5, 2022
f23eb99
update bagua
ananthsub Mar 8, 2022
d15378e
Update fully_sharded.py
ananthsub Mar 8, 2022
bfbbf75
Update bagua.py
ananthsub Mar 8, 2022
05c12df
update
ananthsub Mar 8, 2022
b4520fa
Update sharded.py
ananthsub Mar 8, 2022
2823d63
sharded opt
ananthsub Mar 10, 2022
f9843c9
Update ddp.py
ananthsub Mar 10, 2022
917e14a
update test
ananthsub Mar 10, 2022
ec53da6
rebase
ananthsub Mar 22, 2022
6fe6eaf
Update test_sharded_strategy.py
ananthsub Mar 22, 2022
2650067
Update test_sharded_strategy.py
ananthsub Mar 22, 2022
a98b070
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
56c3170
Merge branch 'master' into fix/ddp-delay-optimizer-creation
rohitgr7 May 4, 2022
e600a7b
bagua fix
rohitgr7 May 4, 2022
bd09a3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2022
2d8cd03
fix
rohitgr7 May 4, 2022
9dcf463
fix tests
rohitgr7 May 5, 2022
bbf5547
keep DDP as top wrapper
rohitgr7 May 5, 2022
7aa6777
fixes
rohitgr7 May 9, 2022
ad47fae
Merge branch 'master' into fix/ddp-delay-optimizer-creation
carmocca May 10, 2022
cbd35e2
Merge branch 'master' into fix/ddp-delay-optimizer-creation
rohitgr7 May 25, 2022
fe50c3e
Merge branch 'master' into fix/ddp-delay-optimizer-creation
carmocca May 31, 2022
7cd8545
Bad merge
carmocca May 31, 2022
713d692
Remove extra argument
carmocca May 31, 2022
9a6a0ef
Merge branch 'master' into fix/ddp-delay-optimizer-creation
carmocca May 31, 2022
5d65ad3
Merge branch 'master' into fix/ddp-delay-optimizer-creation
carmocca Jun 1, 2022
9 changes: 6 additions & 3 deletions CHANGELOG.md
@@ -60,6 +60,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))


- Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))


- Added all DDP params to be exposed through hpu parallel strategy ([#13067](https://github.com/PyTorchLightning/pytorch-lightning/pull/13067))


@@ -223,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))


- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11952](https://github.com/PyTorchLightning/pytorch-lightning/pull/11952))


- Fixed epoch logging on train epoch end ([#13025](https://github.com/PyTorchLightning/pytorch-lightning/pull/13025))


@@ -652,9 +658,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with resuming from a checkpoint trained with QAT ([#11346](https://github.com/PyTorchLightning/pytorch-lightning/pull/11346))


- Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))


## [1.5.10] - 2022-02-08

### Fixed
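The fix entry above is the core of this PR: the DDP strategies now move the module to its device first and only then create the optimizers, wrapping the model last. A minimal sketch of that ordering in plain PyTorch, not Lightning's own code; it assumes `torch.distributed.init_process_group()` has already been called:

```python
# Minimal sketch of the ordering this fix enforces (plain PyTorch, not
# Lightning's code). Assumes a torch.distributed process group is initialized.
import torch
from torch.nn.parallel import DistributedDataParallel

def setup(model: torch.nn.Module, device: torch.device):
    model = model.to(device)                                  # 1. move the module to its device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # 2. build the optimizer on the moved parameters
    ddp_model = DistributedDataParallel(model)                # 3. wrap last
    return ddp_model, optimizer
```

This is the same move-then-optimize-then-wrap order the strategy `setup()` methods follow in the diffs below.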
35 changes: 31 additions & 4 deletions pytorch_lightning/strategies/bagua.py
@@ -16,9 +16,11 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed

if _BAGUA_AVAILABLE:
@@ -147,6 +149,33 @@ def _set_node_environment_variables(self) -> None:
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)

def setup(self, trainer: "pl.Trainer") -> None:
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

trainer_fn = trainer.state.fn

if trainer_fn == TrainerFn.FITTING:
if self._layer_sync and self.model:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
# set up optimizers after the module has been moved to the device
# but before the module has been wrapped
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
self._configure_bagua_model(trainer)

def _check_qadam_optimizer(self) -> None:
has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])

@@ -155,14 +184,12 @@ def _check_qadam_optimizer(self) -> None:

self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]

def configure_ddp(self) -> None:
def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
model = LightningBaguaModule(self.model) # type: ignore[arg-type]
self._model = self._setup_model(model)

# start the background communication for async algorithm
assert self.lightning_module is not None
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
if trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.resume(self.model) # type: ignore

def _setup_model(self, model: Module) -> "BaguaDistributedDataParallel":
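The new Bagua `setup()` above finishes optimizer creation with `optimizers_to_device(self.optimizers, self.root_device)`. As a rough idea of what that step is responsible for, here is a simplified stand-in, an assumption about its intent rather than the actual utility's implementation: any tensors already held in the optimizer state should live on the same device as the parameters.

```python
# Simplified stand-in for optimizers_to_device (not the real implementation):
# push any optimizer state tensors to the device the parameters now live on.
import torch

def optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                param_state[key] = value.to(device)
```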
32 changes: 20 additions & 12 deletions pytorch_lightning/strategies/ddp.py
@@ -56,6 +56,7 @@
_TORCH_GREATER_EQUAL_1_10,
_TORCH_GREATER_EQUAL_1_11,
)
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -152,24 +153,37 @@ def setup_environment(self) -> None:
super().setup_environment()

def setup(self, trainer: "pl.Trainer") -> None:
super().setup(trainer)
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
return

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

self.configure_ddp()
# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._enable_model_averaging()

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -223,12 +237,6 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._enable_model_averaging()

def _enable_model_averaging(self) -> None:
# Only called when PyTorch version >= 1.10
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
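With `setup_optimizers(trainer)` now running after `model_to_device()` in the reordered `setup()` above, `configure_optimizers()` is called on a module that already sits on the strategy's root device. A hedged illustration of that guarantee with a hypothetical LightningModule, not taken from this PR:

```python
# Hypothetical LightningModule (not from this PR) relying on the new guarantee.
import torch
import pytorch_lightning as pl

class ExampleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # DDPStrategy / DDPSpawnStrategy now call this only after model_to_device(),
        # so the parameters handed to the optimizer are already on self.device.
        assert self.layer.weight.device == self.device
        return torch.optim.SGD(self.parameters(), lr=0.1)
```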
23 changes: 15 additions & 8 deletions pytorch_lightning/strategies/ddp_spawn.py
@@ -42,6 +42,7 @@
sync_ddp_if_available,
)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -122,20 +123,22 @@ def _configure_launcher(self):

def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
super().setup(trainer)

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
return
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)
self.setup_precision_plugin()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
self.configure_ddp()
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -186,6 +189,10 @@ def configure_ddp(self) -> None:
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()

# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(self.lightning_module.trainer)
optimizers_to_device(self.optimizers, self.root_device)

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
return None
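The spawn variant has the same constraint, just per worker: each spawned process moves its own replica to its device, so the optimizer has to be built inside the worker after that move and before wrapping, which is why `setup_optimizers` now lives in `configure_ddp()` above. A small standalone sketch in plain PyTorch, not Lightning's code; CPU with the gloo backend is assumed and the address/port are illustrative:

```python
# Standalone sketch of the per-worker ordering (plain PyTorch, CPU + gloo
# assumed; the address/port below are illustrative only).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size, init_method="tcp://127.0.0.1:29500"
    )
    device = torch.device("cpu")
    model = torch.nn.Linear(8, 2).to(device)                  # 1. move inside the worker
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # 2. then build the optimizer
    ddp_model = DistributedDataParallel(model)                # 3. wrap last
    ddp_model(torch.randn(4, 8)).sum().backward()
    optimizer.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```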
15 changes: 10 additions & 5 deletions pytorch_lightning/strategies/fully_sharded.py
@@ -139,14 +139,16 @@ def setup_distributed(self) -> None:
def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)

if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
self.model = self._layer_sync.apply(self.model)
if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()
self.configure_ddp()
self.barrier()
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.setup_precision_plugin()

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
@@ -183,6 +185,9 @@ def configure_ddp(self) -> None:
# (TODO: need to figure out solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
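In the fully sharded strategy, `configure_ddp()` above re-runs `setup_optimizers()` once the module has been wrapped, since fully sharded wrapping typically replaces the module's parameters with flattened shards and an optimizer built against the unwrapped module would reference the wrong tensors. A minimal sketch of that ordering with fairscale, assuming fairscale is installed and a process group is already initialized:

```python
# Sketch with fairscale (assumes fairscale is installed and a process group is
# initialized): wrap first, then build the optimizer on the wrapped parameters.
import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

def build(model: torch.nn.Module, device: torch.device):
    wrapped = FSDP(model.to(device))                              # wrapping flattens/shards the parameters
    optimizer = torch.optim.AdamW(wrapped.parameters(), lr=1e-3)  # so the optimizer is built afterwards
    return wrapped, optimizer
```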
50 changes: 38 additions & 12 deletions pytorch_lightning/strategies/sharded.py
@@ -25,6 +25,7 @@
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -40,16 +41,41 @@ class DDPShardedStrategy(DDPStrategy):
strategy_name = "ddp_sharded"
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M

def configure_ddp(self) -> None:
trainer = self.lightning_module.trainer
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
def setup(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

self.accelerator.setup(trainer)

# move the model to the correct device
self.model_to_device()

# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()

if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

def configure_ddp(self) -> None:
self._set_ddp_kwargs()
self.setup_optimizers(self.model.trainer)
self.model, self.optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model),
optimizers=trainer.optimizers,
optimizers=self.optimizers,
)
optimizers_to_device(self.optimizers, self.root_device)

def _set_ddp_kwargs(self) -> None:
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
@@ -62,6 +88,12 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

return self._reinit_optimizers_with_oss(optimizers)

def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if isinstance(optimizer, LightningOptimizer):
@@ -79,12 +111,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
del optimizer
return optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
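In the sharded strategy, `configure_ddp()` above now creates the optimizers itself and wraps them, because fairscale's `ShardedDataParallel` needs the sharded (OSS) optimizers at construction time. A hedged sketch of that coupling in plain fairscale, outside of Lightning; it assumes fairscale is installed and a process group is already initialized:

```python
# Sketch with fairscale's OSS + ShardedDataParallel (assumes fairscale is
# installed and a process group is initialized).
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

def build(model: torch.nn.Module, device: torch.device):
    model = model.to(device)
    # The optimizer is sharded (OSS) and must exist before wrapping ...
    oss_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
    # ... because ShardedDataParallel takes it as a constructor argument.
    sharded_model = ShardedDataParallel(model, sharded_optimizer=oss_optimizer)
    return sharded_model, oss_optimizer
```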
4 changes: 4 additions & 0 deletions pytorch_lightning/strategies/sharded_spawn.py
@@ -23,6 +23,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -38,9 +39,12 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
strategy_name = "ddp_sharded_spawn"

def configure_ddp(self) -> None:
# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(self.lightning_module.trainer)
self.model, self.optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
)
optimizers_to_device(self.optimizers, self.root_device)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
11 changes: 6 additions & 5 deletions pytorch_lightning/strategies/tpu_spawn.py
@@ -26,6 +26,7 @@
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
@@ -126,9 +127,6 @@ def _configure_launcher(self):
def setup(self, trainer: "pl.Trainer") -> None:
self.start_method = "fork"
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
optimizers_to_device(self.optimizers, self.root_device)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
@@ -140,8 +138,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
else:
set_shared_parameters(self.model.module, shared_params)

self.setup_optimizers(trainer)
self.precision_plugin.connect(self.model, None, None)
self.setup_precision_plugin()

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

def _setup_model(self, model: Module) -> Module:
return model
6 changes: 3 additions & 3 deletions tests/strategies/test_bagua_strategy.py
@@ -87,9 +87,9 @@ def test_configuration(algorithm, tmpdir):
), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
if algorithm == "qadam":
with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)
else:
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)


@RunIf(min_cuda_gpus=1, bagua=True)
@@ -111,7 +111,7 @@ def test_qadam_configuration(tmpdir):
with mock.patch(
"bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
trainer.strategy.configure_ddp()
trainer.strategy._configure_bagua_model(trainer)


def test_bagua_not_available(monkeypatch):