SimpleFSDP consists of two major components: (1) frontend composability with different parallelisms and distributed-training techniques; (2) backend optimizations in torch.compile to overlap communication with computation.
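The frontend idea is that FSDP is expressed as plain compute: each parameter is kept sharded across ranks and all-gathered before use, leaving the compiler backend to schedule and overlap the collectives. Below is a framework-free sketch of just the shard/all-gather layout, with plain Python lists standing in for tensors and a concatenation standing in for the collective; the function names are illustrative, not the actual API:

```python
import math

def shard(param, world_size, rank):
    """Split a flat parameter evenly across ranks, padding the tail so
    every rank holds an equal-size shard (mirrors a flat-parameter layout)."""
    shard_len = math.ceil(len(param) / world_size)
    padded = param + [0.0] * (shard_len * world_size - len(param))
    return padded[rank * shard_len : (rank + 1) * shard_len]

def all_gather(shards, orig_len):
    """Simulated all-gather: concatenate every rank's shard in rank order
    and drop the padding to recover the original parameter."""
    full = [x for s in shards for x in s]
    return full[:orig_len]
```

In the real implementation the all-gather is a `torch.distributed` collective issued in forward (with a reduce-scatter of gradients in backward), which is what gives torch.compile the opportunity to reorder and overlap it.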
Frontend Composability
Dense model (llama3)
- [Done] Parallelisms: TP/PP/CP
- [Done] Other techniques: distributed checkpointing / mixed-precision training / meta initialization / activation checkpointing
- [Done] Float8 training (@pianpwk): numeric differences appear in inductor mode because of the Triton kernel implementations, but we get bit-wise numeric equivalence in aot_eager.
MoE model (DSV3)
- [Done] Parallelisms: TP/EP/ETP
- [Need PoC] Activation checkpointing composability: graph breaks occur when AC is applied (related issue: SimpleFSDP AC HOP mutation issue when tracing token dispatch #1935)
- [In progress] Parallelism: PP (Interleaved 1F1B + TP): dynamic-shape errors (related issue: add support for simplefsdp+ep #1529 (comment)) @laithsakka @aorenste
- [Done] Zero2-style sharding + AC composability (related PR: [simplefsdp] fix region ac in zero2-style FSDP #1970) @ruisizhang123
Backend Optimization
Manual bucketing & reordering
- [In progress] Get results on DSV3 models & merge the PR (related PR: [SimpleFSDP] add manual bucketing pass #1881) @ruisizhang123
- [In progress] Numeric debugging recipe (@pianpwk @ezyang @yushangdi @ruisizhang123)
- [Not started] Allow users to specify module reordering positions via mode annotation (@ruisizhang123)
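The core of any bucketing pass, manual or auto, is deciding which per-parameter collectives to fuse into one larger collective while respecting execution order. A minimal sketch of that grouping step, under the assumption of a simple greedy size-cap policy (the actual pass in #1881 operates on the compiled graph and may use different criteria):

```python
def bucket_params(param_sizes, cap_bytes):
    """Greedily group parameter sizes (in bytes) into ordered buckets no
    larger than cap_bytes; each bucket would become one fused collective.
    A single oversized parameter gets a bucket of its own."""
    buckets, current, current_size = [], [], 0
    for size in param_sizes:
        # Flush the current bucket before it would exceed the cap.
        if current and current_size + size > cap_bytes:
            buckets.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    if current:
        buckets.append(current)
    return buckets
```

Preserving order matters here: buckets must match the order in which the compute consumes the parameters, otherwise reordering the fused collectives for overlap can stall the first consumer.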
Auto bucketing & reordering
- [In progress] @eellison @IvanKobzarev (related PR: add auto_eager_graph_pass #1813)