
Conversation

@petrex (Owner) commented on Jun 9, 2025

This pull request introduces support for AMD MI355x GPUs (ROCm 6.5+) by adding HIPBLASLT as a kernel choice for MX GEMM operations, alongside documentation updates and new validation logic. The changes expand the framework's compatibility with AMD hardware while maintaining support for existing NVIDIA configurations.

AMD Support Enhancements:

  • torchao/prototype/mx_formats/config.py:

    • Added HIPBLASLT as a new kernel choice in MXGemmKernelChoice and MXFP8_HIPBLASLT as a new recipe in MXLinearRecipeName.
    • Updated _validate_gemm_kernel_choice to validate HIPBLASLT, checking the expected block size, element dtypes, and ROCm availability.
    • Extended from_recipe_name to handle MXFP8_HIPBLASLT.
  • torchao/prototype/mx_formats/mx_ops.py:

    • Modified _addmm_mx_dispatch to support HIPBLASLT for matrix-matrix operations in addition to the existing kernel choices (a usage sketch follows this list).
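
For orientation, here is a minimal usage sketch of the new kernel choice from the config side. The enum and recipe names (MXGemmKernelChoice.HIPBLASLT, MXLinearRecipeName.MXFP8_HIPBLASLT, from_recipe_name) come from this PR's description; MXLinearConfig and its field names are assumptions about the surrounding torchao API, so verify them against config.py on your checkout.

```python
# Minimal sketch, not part of this PR's diff. MXLinearConfig and its field
# names (block_size, elem_dtype, gemm_kernel_choice) are assumptions; the
# HIPBLASLT kernel choice and MXFP8_HIPBLASLT recipe come from this PR.
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# Option 1: build a config explicitly with the new kernel choice.
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT,
)

# Option 2: use the recipe name added by this PR.
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
```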

Documentation Updates:

  • torchao/prototype/mx_formats/README.md:
    • Updated to reflect support for AMD MI355x GPUs with ROCm 6.5+ and gfx950, including usage examples and performance optimization notes (an environment-check sketch follows below).
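
As a rough companion to those requirements, an environment check for the HIPBLASLT path might look like the sketch below. The ROCm 6.5+ and gfx950 targets come from the README update; the helper name and the gcnArchName-based check are illustrative rather than part of this PR.

```python
# Illustrative only: gate on the environment the README update targets
# (ROCm 6.5+ on gfx950 / MI355x). The helper name is hypothetical.
import torch


def is_gfx950_with_rocm_6_5() -> bool:
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    major, minor = (int(x) for x in torch.version.hip.split(".")[:2])
    # gcnArchName is exposed by ROCm builds of PyTorch, e.g. "gfx950:xnack-".
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return (major, minor) >= (6, 5) and arch.split(":")[0] == "gfx950"
```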

Peter Y. Yeh and others added 10 commits April 16, 2025 15:59
…dation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm function to accommodate new kernel options.
…ASLT kernel choice for mxfp8 gemm. Enhance documentation on end-to-end performance optimization efforts for AMD GPUs.
…py to include HIPBLASLT as a valid kernel choice for MX FP8 operations.
@petrex (Owner, Author) commented on Jun 9, 2025

bugbot run

@petrex added the enhancement (New feature or request) label on Jun 9, 2025
@cursor cursor bot left a comment

Bug: Outdated Assertion Message for MX FP8 Operations

The assertion error message "CUBLAS is the only supported kernel choice for MX FP8 operations" is incorrect. The code's assertion logic now accepts both CUBLAS and HIPBLASLT kernel choices for MX FP8 operations. This outdated message misleads users about supported kernel choices and should be updated.

torchao/prototype/mx_formats/mx_ops.py#L109-L113

```python
assert b._elem_dtype == torch.float8_e4m3fn
assert gemm_choice in (
    MXGemmKernelChoice.CUBLAS,
    MXGemmKernelChoice.HIPBLASLT,
), "CUBLAS is the only supported kernel choice for MX FP8 operations"
```


Peter Y. Yeh and others added 8 commits June 9, 2025 10:56
… HIPBLASLT are supported kernel choices for MX FP8 operations.
- Introduced `is_ROCm_mx_supported` function to verify ROCm environment compatibility for MX operations.
- Added `test_hipblaslt_fp8` to validate FP8 operations using the HIPBLASLT backend, including SQNR verification for output accuracy.
- Updated imports in `test_mx_mm.py` to include necessary utilities for the new test.
- Replaced `compute_sqnr` with `compute_error` for improved accuracy in error measurement.
- Updated assertion to ensure output accuracy meets the specified threshold.
- Updated the function to ensure `torch.version.hip` is not None before checking the version, improving robustness against potential NoneType errors.
- Reformatted the return statement to enhance clarity and maintainability of the code.
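
The commits above describe an is_ROCm_mx_supported helper, a test_hipblaslt_fp8 test, and a switch to compute_error (an SQNR metric) with a threshold assertion. A rough sketch of how those pieces fit together is below; the helper body, the local compute_error stand-in, and the threshold value are assumptions rather than the PR's exact code.

```python
# Illustrative sketch only; is_ROCm_mx_supported, test_hipblaslt_fp8, and
# compute_error are named in the commit messages above, everything else is assumed.
import torch


def is_ROCm_mx_supported() -> bool:
    # Return False early when torch.version.hip is None (CUDA/CPU builds),
    # avoiding the NoneType error the last commit message mentions.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    major, minor = (int(x) for x in torch.version.hip.split(".")[:2])
    return (major, minor) >= (6, 5)


def compute_error(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # SQNR in dB; a local stand-in for the compute_error utility the commits name.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))


# test_hipblaslt_fp8, per the commit messages, compares the HIPBLASLT FP8 output
# against a reference and asserts a minimum SQNR, roughly:
#   assert compute_error(ref_out, hipblaslt_out) >= 20.0  # threshold assumed
```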