MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (pytorch#162209)
## Summary
- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: pytorch/FBGEMM#4816
- This is needed for the backward pass of mxfp8 MoE training with grouped GEMMs
- Changes:
  - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm` (see the scale-shape sketch after this list)
  - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
  - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
  - Bump FBGEMM third party submodule to include:
    - pytorch/FBGEMM#4816
    - pytorch/FBGEMM#4820
    - pytorch/FBGEMM#4821
    - pytorch/FBGEMM#4823
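For reference, here is a minimal Python sketch of the scale-shape contract the new validation enforces for the 3d operand in the 2d-3d case. The helper names are mine, not part of the PR; the rounding rules mirror the C++ validation excerpt at the end of this description, and the scales are stored as `torch.float8_e8m0fnu` exponents (one per 32-element block along K).

```
import torch

def round_up(x: int, y: int) -> int:
    # Round x up to the next multiple of y (same as the C++ lambda).
    return ((x + y - 1) // y) * y

def expected_3d_mxfp8_scale_shape(G: int, K: int, N: int) -> tuple[int, int]:
    # fbgemm expects a stack of flattened blocked scales for the 3d operand:
    # K/32 block scales padded to a multiple of 4, N padded to a multiple of 128,
    # flattened per group -> shape (G, blocked_scale_K * blocked_scale_N).
    blocked_scale_K = round_up(K // 32, 4)
    blocked_scale_N = round_up(N, 128)
    return (G, blocked_scale_K * blocked_scale_N)

# Example: a (G, K, N) = (8, 512, 1024) fp8 operand needs a (8, 16 * 1024) scale.
scale_shape = expected_3d_mxfp8_scale_shape(8, 512, 1024)
scale_b = torch.empty(scale_shape, dtype=torch.float8_e8m0fnu)
```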
#### How the fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`
## Test plan
#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)
#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...
test/test_matmul_cuda.py ......... [100%]
============================================================== 9 passed, 1668 deselected in 5.34s ===============================================================
```
Pull Request resolved: pytorch#162209
Approved by: https://github.com/ngimel
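Excerpt of the new scale-validation logic for mxfp8 in `torch._scaled_grouped_mm` (2d operands have variable-size groups; 3d operands are a static stack of 2d matrices):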
```
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
  // For MXFP8, 2d tensors have variable size groups represented as subtensors,
  // that are converted to blocked padded format individually,
  // so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
  TORCH_CHECK(
      scale.dim() == mat.dim(),
      "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
  // ...
      "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
  // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
  // so we can check the exact expected scale sizes here without a d2h sync.
  auto round_up = [](auto x, auto y) {
    return ((x + y - 1) / y) * y;
  };

  // TODO: this is for 3d tensor in 2d-3d case specifically.
  // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
  int64_t G = mat.size(0);
  int64_t K = mat.size(1);
  int64_t N = mat.size(2);
  int64_t blocked_scale_K = round_up(K/32, 4);
  int64_t blocked_scale_N = round_up(N, 128);

  // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
  TORCH_CHECK(
      scale.dim() == mat.dim() - 1,
      "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
  );
  TORCH_CHECK(
      scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
      "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
  );
}
```
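Only the 3d operand's scale shape can be checked eagerly on the host; the 2d operand's per-group scale layout depends on group sizes that live on the device, so validating it fully would require a d2h sync, and only its dimensionality is checked here.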