Rocm mx gemm #5
base: main
Conversation
…dation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm function to accommodate new kernel options.
…ASLT kernel choice for mxfp8 gemm. Enhance documentation on end-to-end performance optimization efforts for AMD GPUs.
…py to include HIPBLASLT as a valid kernel choice for MX FP8 operations.
bugbot run
Bug: Outdated Assertion Message for MX FP8 Operations
The assertion error message "CUBLAS is the only supported kernel choice for MX FP8 operations" is incorrect. The code's assertion logic now accepts both CUBLAS and HIPBLASLT kernel choices for MX FP8 operations. This outdated message misleads users about supported kernel choices and should be updated.
torchao/prototype/mx_formats/mx_ops.py, lines 109 to 113 (commit 129a6d6):
    assert b._elem_dtype == torch.float8_e4m3fn
    assert gemm_choice in (
        MXGemmKernelChoice.CUBLAS,
        MXGemmKernelChoice.HIPBLASLT,
    ), "CUBLAS is the only supported kernel choice for MX FP8 operations"
… HIPBLASLT are supported kernel choices for MX FP8 operations.
…l choices for MX FP8 operations.
Co-authored-by: Copilot <[email protected]>
- Introduced `is_ROCm_mx_supported` function to verify ROCm environment compatibility for MX operations.
- Added `test_hipblaslt_fp8` to validate FP8 operations using the HIPBLASLT backend, including SQNR verification for output accuracy.
- Updated imports in `test_mx_mm.py` to include necessary utilities for the new test.
- Replaced `compute_sqnr` with `compute_error` for improved accuracy in error measurement.
- Updated assertion to ensure output accuracy meets the specified threshold.
- Updated the function to ensure `torch.version.hip` is not None before checking the version, improving robustness against potential NoneType errors.
- Reformatted the return statement to enhance clarity and maintainability of the code.
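Taken together, these commits suggest a ROCm gate roughly like the sketch below. It is a minimal sketch only: it guards `torch.version.hip` against `None` (per the robustness fix above) and checks for ROCm 6.5+ as stated in the PR description; any additional GPU-architecture checks in the real `is_ROCm_mx_supported` are not shown.

```python
import torch

def is_ROCm_mx_supported() -> bool:
    # Not a ROCm build of PyTorch (torch.version.hip is None on CUDA/CPU builds),
    # so bail out before touching the version string.
    if torch.version.hip is None:
        return False
    if not torch.cuda.is_available():
        return False
    # torch.version.hip looks like "6.5.<patch>"; compare major.minor against 6.5.
    major, minor = (int(part) for part in torch.version.hip.split(".")[:2])
    return (major, minor) >= (6, 5)
```

`test_hipblaslt_fp8` would then be skipped when this returns False; its accuracy check presumably compares the HIPBLASLT output against a reference result via `compute_error` (SQNR) and asserts the result exceeds the chosen threshold.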
This pull request introduces support for AMD MI355x GPUs with ROCm 6.5+ and HIPBLASLT for MX gemm operations, alongside updates to the documentation and validation logic. The changes expand the framework's compatibility and functionality for AMD hardware while maintaining support for existing NVIDIA configurations.
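For users, opting into the new path would look roughly like the snippet below. This is a hypothetical usage sketch: it assumes `from_recipe_name` is exposed on the existing `MXLinearConfig` class in `config.py` and accepts the new recipe enum member; the surrounding model-conversion code is omitted.

```python
from torchao.prototype.mx_formats.config import MXLinearConfig, MXLinearRecipeName

# Select the new HIPBLASLT-backed MX FP8 recipe added in this PR
# (assumes from_recipe_name lives on MXLinearConfig and takes the enum member).
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
```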
AMD Support Enhancements:
- `torchao/prototype/mx_formats/config.py`: Added `HIPBLASLT` as a new kernel choice in `MXGemmKernelChoice` and `MXFP8_HIPBLASLT` in `MXLinearRecipeName`. Extended `_validate_gemm_kernel_choice` to include validation for `HIPBLASLT`, checking the required block size, data types, and ROCm availability, and updated `from_recipe_name` to handle `MXFP8_HIPBLASLT` (sketched below).
- `torchao/prototype/mx_formats/mx_ops.py`: Updated `_addmm_mx_dispatch` to support `HIPBLASLT` for matrix-matrix operations in addition to the existing kernel choices. [1] [2]
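A rough sketch of the shape of those `config.py` additions; the enum values, the specific block-size requirement, and the error messages are assumptions rather than the PR's exact code:

```python
from enum import Enum

import torch

class MXGemmKernelChoice(Enum):
    EMULATED = "emulated"
    CUBLAS = "cublas"
    HIPBLASLT = "hipblaslt"  # new ROCm-backed kernel choice

class MXLinearRecipeName(Enum):
    MXFP8_CUBLAS = "mxfp8_cublas"
    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"  # new recipe name

def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
    if gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
        # HIPBLASLT is only usable on a ROCm build of PyTorch.
        assert torch.version.hip is not None, (
            "MXGemmKernelChoice.HIPBLASLT requires ROCm"
        )
        # Assumed constraints: the same block size and FP8 dtype the CUBLAS path uses.
        assert block_size == 32, (
            f"block_size must be 32 to use HIPBLASLT, got {block_size}"
        )
        assert elem_dtype == torch.float8_e4m3fn, (
            f"elem_dtype must be torch.float8_e4m3fn to use HIPBLASLT, got {elem_dtype}"
        )
```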
Documentation Updates:
- `torchao/prototype/mx_formats/README.md`: