Bug description
After enabling SimpleFSDP on DeepSeek models, I found that the backward all-gather required for reshard-after-forward behavior is missing. This happens with SimpleFSDP+EP/TP, whereas standalone SimpleFSDP does issue backward all-gathers. More details can be found in the description of PR #1529.
In SimpleFSDP, reshard-after-forward is achieved using checkpoint, so something is probably wrong with how the checkpoint API is used in SimpleFSDP+DeepSeek.
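For reference, here is a minimal single-process sketch of the behavior I would expect from checkpoint-based reshard-after-forward. It is not SimpleFSDP code: `fake_all_gather` and `sharded_linear` are hypothetical names, and `torch.cat` stands in for the real collective. The assumption is that the all-gather sits inside the checkpointed region, so it runs once in forward and once more when the region is recomputed in backward.

```python
import torch
from torch.utils.checkpoint import checkpoint

gather_calls = 0  # counts how often the stand-in all-gather runs


def fake_all_gather(shard: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    """Hypothetical stand-in for an all-gather: concatenates copies of the local shard."""
    global gather_calls
    gather_calls += 1
    return torch.cat([shard] * world_size, dim=0)


def sharded_linear(x: torch.Tensor, shard: torch.Tensor) -> torch.Tensor:
    # The gathered weight only lives inside the checkpointed region,
    # so it is not kept alive between forward and backward.
    full_weight = fake_all_gather(shard)
    return x @ full_weight


shard = torch.randn(2, 3, requires_grad=True)  # local shard of a (4, 3) weight
x = torch.randn(5, 4)

# Forward: the region runs once, so one "all-gather".
out = checkpoint(sharded_linear, x, shard, use_reentrant=False)
calls_after_forward = gather_calls

# Backward: the region is recomputed, so the "all-gather" runs again.
out.sum().backward()
calls_after_backward = gather_calls

print(calls_after_forward, calls_after_backward)
```

If the same kind of counter stays unchanged across backward in the real model, that would suggest the checkpointed region containing the all-gather is not being recomputed.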
In eager mode, there is no backward all-gather. In compile mode, the all-gathers do show up in tlparse, but I get the following warning at the end of training:
[rank0]:[W808 15:00:07.207032687 ProcessGroup.cpp:364] Warning: At the time of process termination, there are still 20 unwaited collective calls. Please review your program to ensure that:
[rank0]:1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
[rank0]:2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under `with allow_inflight_collective_as_graph_input_ctx():`,
[rank0]:before the output tensors of the collective are used. (function ~WorkRegistry)
cc. @tianyu-l @anijain2305 @wconstab
Versions
See above