Add 2d-2d support to MXFP8 Grouped GEMM #4816
Conversation
This pull request was exported from Phabricator. Differential Revision: D81362680
Summary:
Pull Request resolved: pytorch#4816
X-link: facebookresearch/FBGEMM#1846

## MXFP8 grouped GEMM updates to (1) handle the 2d-2d case and (2) provide a PyTorch-compliant API
- Add support for 2d-2d inputs with dynamic groups along the K dimension (see the reference sketch below)
- Add tests verifying correct numerics for both the 2d-2d and 2d-3d cases, with random group sizes
- Add benchmarks for both the 2d-3d and 2d-2d cases

Reviewed By: ngimel, cthi

Differential Revision: D81362680
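For readers unfamiliar with the 2d-2d layout, the sketch below illustrates the semantics the new kernel targets: both operands are 2D, group boundaries are given as offsets along the shared K dimension, and each group produces an independent [M, N] result. This is a minimal plain-PyTorch reference in bf16 that ignores MXFP8 quantization and block scaling entirely; the function name, the cumulative-offsets convention, and the shapes are illustrative assumptions, not the FBGEMM kernel API.

```python
import torch

def grouped_mm_2d2d_reference(A, B, k_offsets):
    """Illustrative reference only: grouped GEMM where both operands are 2D
    and groups partition the shared K dimension.

    A: [M, K_total], B: [K_total, N], k_offsets: [G] cumulative group ends.
    Returns: [G, M, N], one independent matmul per K-slice.
    """
    G = k_offsets.numel()
    M, N = A.shape[0], B.shape[1]
    out = torch.empty(G, M, N, dtype=torch.bfloat16, device=A.device)
    start = 0
    for g in range(G):
        end = int(k_offsets[g])
        # Each group contracts over its own K-slice (e.g., the tokens routed
        # to one expert when computing weight gradients in MoE training).
        out[g] = A[:, start:end] @ B[start:end, :]
        start = end
    return out

# Example with 3 groups of uneven (dynamic) sizes along K.
device = "cuda" if torch.cuda.is_available() else "cpu"
M, N = 64, 128
group_ks = [96, 32, 64]  # per-group K sizes (illustrative)
k_offsets = torch.tensor(group_ks).cumsum(0).to(device)
A = torch.randn(M, sum(group_ks), dtype=torch.bfloat16, device=device)
B = torch.randn(sum(group_ks), N, dtype=torch.bfloat16, device=device)
out = grouped_mm_2d2d_reference(A, B, k_offsets)
print(out.shape)  # torch.Size([3, 64, 128])
```

In MoE training, the backward pass for the expert weights maps onto exactly this shape: the token dimension becomes the contraction dimension and is partitioned per expert, which is why the 2d-2d case is needed for the mxfp8 MoE backward pass described in the PyTorch-side PR below.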
This pull request has been merged in c6a8daf.
…ump (#162209)

## Summary
- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: pytorch/FBGEMM#4816
- This is needed for the backward pass of mxfp8 MoE training with grouped gemms
- Changes:
  - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm` (see the sketch below)
  - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
  - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
  - Bump the FBGEMM third-party submodule to include:
    - pytorch/FBGEMM#4816
    - pytorch/FBGEMM#4820
    - pytorch/FBGEMM#4821
    - pytorch/FBGEMM#4823

#### How the fbgemm dependency was bumped
Documenting this since I haven't found it documented elsewhere:
```
cd ~/pytorch/third_party/fbgemm
git fetch
git checkout <hash>
cd ~/pytorch
git add third_party/fbgemm
```

## Test plan

#### Test build
```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```
[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests
```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...
test/test_matmul_cuda.py ......... [100%]
========== 9 passed, 1668 deselected in 5.34s ==========
```

Pull Request resolved: #162209
Approved by: https://github.com/ngimel
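The "dispatching + input validation" and "meta registration" items above boil down to checking mxfp8-specific invariants before (or without) launching the kernel. The snippet below is a hypothetical, standalone sketch of the kinds of checks involved for the 2d-2d case (e4m3 data, e8m0 block scales, K divisible by the 32-element MX block, 1-D int32 group offsets); the function name and the exact set of checks are assumptions, not the actual validation code in `torch._scaled_grouped_mm`.

```python
import torch

MX_BLOCK_SIZE = 32  # MXFP8 uses one e8m0 exponent per 32 elements along K

def check_mxfp8_grouped_mm_inputs(mat_a, mat_b, scale_a, scale_b, offs):
    """Hypothetical validation sketch for the 2d-2d mxfp8 grouped GEMM case.

    Mirrors the *kind* of checks a meta registration performs so that
    torch.compile can shape-check the op without running the CUDA kernel;
    not the actual checks in torch._scaled_grouped_mm.
    """
    assert mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn, \
        "mxfp8 grouped GEMM expects fp8 e4m3 data tensors"
    assert scale_a.dtype == torch.float8_e8m0fnu and scale_b.dtype == torch.float8_e8m0fnu, \
        "MX block scales are e8m0 exponents"
    assert mat_a.dim() == 2 and mat_b.dim() == 2, "2d-2d case: both operands are 2D"
    assert mat_a.shape[1] == mat_b.shape[0], "shared K dimension must match"
    assert offs is not None and offs.dtype == torch.int32 and offs.dim() == 1, \
        "group offsets must be a 1-D int32 tensor"
    assert mat_a.shape[1] % MX_BLOCK_SIZE == 0, \
        f"K must be divisible by the MX block size ({MX_BLOCK_SIZE})"
    # The 2d-2d case produces one [M, N] result per group.
    return (offs.numel(), mat_a.shape[0], mat_b.shape[1])
```

Shape checks of this kind in a meta registration let torch.compile trace the op and infer the [num_groups, M, N] output shape without touching the GPU.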