Description
🐛 Bug
When training in parallel with the fsdp strategy, the saved checkpoints are corrupted in some way. When I try to resume training from one of them, the epoch number is restored correctly, but the loss spikes dramatically, as if the weights had gone back to an initial/random state. When I do the same train/checkpoint/resume loop with ddp_sharded I do not have this issue, and the checkpoint resumes with a loss similar to where it left off.
I further saw that when I point a model using strategy fsdp at a checkpoint saved with ddp_sharded, it also resumes with a reasonable loss, roughly at the previous level. This suggests that fsdp loads checkpoints fine, but that something goes wrong in how it saves them in parallel. Conversely, when I resume with ddp_sharded from an fsdp-saved checkpoint, the loss is dramatically worse, as if the weights were randomly initialized, which again points to how the weights are saved under fsdp. Knowing all of this, I can simply switch to ddp_sharded, but this seems like a really nasty bug that could cause other people headaches, so I wanted to make sure it was known.
The fix seems to be to make sure the FullyShardedDataParallel wrapper is unwrapped. One key difference between the fsdp strategy implementation and the ddp_sharded strategy implementation is that ddp_sharded overrides self.lightning_module and calls a custom unwrap_... function which removes the ShardedDataParallel layer before calling the shared unwrap_lightning_module(...) function. fsdp doesn't do any of this; it falls back to the method implemented in ParallelStrategy.lightning_module, which only calls unwrap_lightning_module(...).
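To see this on a live run, something like the following can help (describe_wrappers and WrapperLogger are just illustrations I wrote, not part of Lightning): it prints which wrappers are still stacked on the model the strategy holds. With fsdp I would expect FullyShardedDataParallel to still show up in the chain, whereas ddp_sharded peels ShardedDataParallel off via its lightning_module property.

```python
import pytorch_lightning as pl


def describe_wrappers(module):
    """Walk nested `.module` attributes and report each wrapper type,
    e.g. 'FullyShardedDataParallel -> _LightningModuleWrapperBase -> MyModule'."""
    chain = [type(module).__name__]
    while hasattr(module, "module"):
        module = module.module
        chain.append(type(module).__name__)
    return " -> ".join(chain)


class WrapperLogger(pl.Callback):
    """Hypothetical callback that prints the wrapper chain once training starts."""

    def on_train_start(self, trainer, pl_module):
        print(describe_wrappers(trainer.strategy.model))
```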
I am going to open a PR and link it here that makes unwrap_lightning_module(...) aware of FullyShardedDataParallel (both flavors) as well as ShardedDataParallel, so that every strategy using one of those wrappers benefits. Hopefully it will also become the obvious place to update as new wrappers are added in the future.
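To make the intent concrete, here is a rough sketch of what a wrapper-aware unwrap could look like (assuming fairscale is installed; _PARALLEL_WRAPPERS and unwrap_all are names I made up, and the actual PR will modify Lightning's own unwrap_lightning_module(...) rather than add a new function):

```python
import pytorch_lightning as pl
from fairscale.nn.data_parallel import FullyShardedDataParallel, ShardedDataParallel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from torch.nn.parallel import DataParallel, DistributedDataParallel

# Wrappers we know how to peel off. The torch-native FSDP class could be
# added here as well to cover the fsdp native strategy.
_PARALLEL_WRAPPERS = (
    DataParallel,
    DistributedDataParallel,
    ShardedDataParallel,
    FullyShardedDataParallel,
    _LightningModuleWrapperBase,
)


def unwrap_all(wrapped):
    """Strip nested parallel wrappers until the LightningModule is exposed.

    The `hasattr` fallback also removes internal shims such as fairscale's
    flatten-params wrapper that FSDP inserts between itself and the module."""
    module = wrapped
    while isinstance(module, _PARALLEL_WRAPPERS) or (
        not isinstance(module, pl.LightningModule) and hasattr(module, "module")
    ):
        module = module.module
    return module
```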
To Reproduce
1. Train a model in parallel that saves checkpoints for a few epochs, using --strategy fsdp. Note the loss at the beginning and make sure it drops. (A minimal script is sketched after this list.)
2. Resume a model using any strategy from one of those saved checkpoints and note that the loss is similar to the beginning of training. Based on the code I would guess that the fsdp native strategy, whatever that is called, is also broken. Maybe others.
3. Repeat steps 1 and 2, but this time use --strategy ddp_sharded, and note that the loss resumes from where it left off.
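Here is a minimal, self-contained sketch along the lines of the steps above (TinyModel, the run helper, and the paths are placeholders, not my actual setup):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # The Lightning fully-sharded docs recommend taking parameters from the
        # wrapped model; this also works for the other strategies.
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)


def run(strategy, ckpt_path=None):
    data = DataLoader(
        TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,))),
        batch_size=32,
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=strategy,  # "fsdp" reproduces the bug, "ddp_sharded" does not
        max_epochs=3,
        default_root_dir="./repro",
    )
    trainer.fit(TinyModel(), data, ckpt_path=ckpt_path)


if __name__ == "__main__":
    run("fsdp")  # step 1: train and save checkpoints, note the loss drops
    # step 2: resume from one of the saved checkpoints and watch the loss spike
    # run("fsdp", ckpt_path="<path to saved .ckpt>")
```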
Expected behavior
Model training continues when resuming.
Environment
- CUDA:
  - GPU:
    - NVIDIA A100 80GB PCIe
    - NVIDIA A10
    - NVIDIA A10
  - available: True
  - version: 11.5
- Packages:
  - numpy: 1.22.3
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0
  - pytorch-lightning: 1.6.3
  - tqdm: 4.64.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.9.12
  - version: #100-Ubuntu SMP Fri Sep 24 14:50:10 UTC 2021