Commit 1ce6d11

danielvegamyhre authored and cleonard530 committed
MXFP8 grouped GEMM support for torch._scaled_grouped_mm + submodule bump (pytorch#162209)
## Summary

- We just landed 2d-2d support for mxfp8 grouped gemm in FBGEMM: pytorch/FBGEMM#4816
- This is needed for the backward pass of mxfp8 MoE training with grouped gemms
- Changes:
  - Add dispatching + input validation for mxfp8 grouped gemm in `torch._scaled_grouped_mm`
  - Add meta registration input validation for mxfp8 grouped gemm, for composability with compile
  - Add unit tests exercising `torch._scaled_grouped_mm` with mxfp8 inputs
  - Bump FBGEMM third party submodule to include:
    - pytorch/FBGEMM#4816
    - pytorch/FBGEM M#4820
    - pytorch/FBGEMM#4821
    - pytorch/FBGEMM#4823

#### How the fbgemm dependency was bumped

Documenting this since I haven't found it documented elsewhere:

- `cd ~/pytorch/third_party/fbgemm`
- `git fetch`
- `git checkout <hash>`
- `cd ~/pytorch`
- `git add third_party/fbgemm`

## Test plan

#### Test build

```
USE_FBGEMM_GENAI=1 python -m pip install --no-build-isolation -v -e .
...
Successfully installed torch-2.9.0a0+gitf5070f3
```

[full build log](https://www.internalfb.com/phabricator/paste/view/P1933787581)

#### Unit tests

```
pytest test/test_matmul_cuda.py -k test_mxfp8_scaled_grouped_mm_
...
test/test_matmul_cuda.py .........                                   [100%]
==================== 9 passed, 1668 deselected in 5.34s ====================
```

Pull Request resolved: pytorch#162209
Approved by: https://github.com/ngimel
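As a usage illustration (not code from this PR), here is a minimal sketch of how the new MXFP8 2d-2d path of `torch._scaled_grouped_mm` can be exercised, assuming a CUDA build with `USE_FBGEMM_GENAI=1` (and `TORCH_CUDA_ARCH_LIST` covering `10.0a`) running on an SM100 device. The shapes, group split, and unit e8m0 scales are illustrative stand-ins for a real MX quantization step; the `test_mxfp8_scaled_grouped_mm_*` unit tests cover the real cases.

```python
# Hypothetical sketch, not part of the PR: drive the MXFP8 grouped GEMM dispatch path.
import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0):
    device = "cuda"
    G, M, N, K = 4, 128, 256, 512  # multiples of 128/32, so no scale padding is needed

    # Group boundaries along K (int32, as required by the op): [128, 256, 384, 512].
    offs = torch.arange(K // G, K + 1, K // G, device=device, dtype=torch.int32)

    # FP8 e4m3 operands; mat_b must be transposed in memory (see the stride checks in Blas.cpp).
    mat_a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    mat_b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

    # E8M0 block scales, one per 32 elements along K. A stored value of 127 is the biased
    # exponent for 2^0 = 1.0, so these are unit scales that merely satisfy the validation.
    scale_a = torch.full((M, K // 32), 127, device=device, dtype=torch.uint8).view(torch.float8_e8m0fnu)
    scale_b = torch.full((N, K // 32), 127, device=device, dtype=torch.uint8).view(torch.float8_e8m0fnu)

    out = torch._scaled_grouped_mm(
        mat_a, mat_b, scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16)
    print(out.shape)  # one (M, N) output per K-group, i.e. (G, M, N)
```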
1 parent 416efed commit 1ce6d11

File tree

9 files changed: +534 -73 lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
```diff
@@ -889,6 +889,12 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
   set(USE_FBGEMM_GENAI off)
 endif()
 
+# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100
+if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a")
+  message(WARNING "Setting USE_FBGEMM_GENAI to ON for CUDA build on SM100")
+  set(USE_FBGEMM_GENAI ON)
+endif()
+
 # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
 # Eff Attention won't
 cmake_dependent_option(
```

aten/src/ATen/CMakeLists.txt

Lines changed: 73 additions & 37 deletions
```diff
@@ -255,48 +255,77 @@ endif()
 # FBGEMM GenAI
 IF(USE_FBGEMM_GENAI)
   set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
-  set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
-
-  if(USE_ROCM)
-    # Only include the kernels we want to build to avoid increasing binary size.
-    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
-      "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
-      "${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
-    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
-
-    # Add additional HIPCC compiler flags for performance
-    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-      -mllvm
-      -amdgpu-coerce-illegal-types=1
-      -mllvm
-      -enable-post-misched=0
-      -mllvm
-      -greedy-reverse-local-assignment=1
-      -fhip-new-launch-api)
-
-    # Only compile for gfx942 for now.
-    # This is rather hacky, I could not figure out a clean solution :(
-    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
-    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
-    list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
-    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
-
-    hip_add_library(
-      fbgemm_genai STATIC
-      ${fbgemm_genai_native_rocm_hip}
-      HIPCC_OPTIONS ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
-    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
+  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
+  if(USE_CUDA)
+    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
+    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
+    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
+    file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
+      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
+      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
+    list(FILTER fbgemm_genai_native_cuda_cu INCLUDE REGEX ${FBGEMM_CUTLASS_KERNELS_REGEX})
+
+    file(GLOB_RECURSE fbgemm_genai_native_cuda_cpp
+      "${FBGEMM_GENAI_SRCS}/common/*.cpp"
+    )
+
+    # Combine all source files into a single list
+    list(APPEND fbgemm_genai_all_sources
+      ${fbgemm_genai_native_cuda_cu}
+      ${fbgemm_genai_native_cuda_cpp}
+    )
+
+    # Now, create the library and provide the sources at the same time
+    add_library(fbgemm_genai OBJECT ${fbgemm_genai_all_sources})
 
     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
-    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
+
+    set(fbgemm_genai_mx8mx8bf16_grouped
+      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
+    )
 
     target_include_directories(fbgemm_genai PUBLIC
-      # FBGEMM version of Composable Kernel is used due to some customizations
-      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
-      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
-      ${FBGEMM_GENAI_DIR}/include/
-      ${FBGEMM_GENAI_DIR}/common/include/
+      ${FBGEMM_THIRD_PARTY}/cutlass/include
+      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
+      ${fbgemm_genai_mx8mx8bf16_grouped}
+      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
+      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
     )
+  else()
+    if(USE_ROCM)
+      # Only include the kernels we want to build to avoid increasing binary size.
+      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
+        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
+        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
+      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
+
+      # Add additional HIPCC compiler flags for performance
+      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
+        -mllvm
+        -amdgpu-coerce-illegal-types=1
+        -mllvm
+        -enable-post-misched=0
+        -mllvm
+        -greedy-reverse-local-assignment=1
+        -fhip-new-launch-api)
+
+      hip_add_library(
+        fbgemm_genai STATIC
+        ${fbgemm_genai_native_rocm_hip}
+        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
+      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
+      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
+
+      target_include_directories(fbgemm_genai PUBLIC
+        # FBGEMM version of Composable Kernel is used due to some customizations
+        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
+        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
+        ${FBGEMM_THIRD_PARTY}/cutlass/include
+        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
+        ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
+        ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
+      )
+    endif()
   endif()
 endif()
 
@@ -639,6 +668,13 @@ if(USE_CUDA AND NOT USE_ROCM)
   add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
+
+  # Add FBGEMM_GENAI include directories for torch_ops.h
+  if(USE_FBGEMM_GENAI)
+    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
+    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
+  endif()
+
   if($ENV{ATEN_STATIC_CUDA})
     if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
       list(APPEND ATen_CUDA_DEPENDENCY_LIBS
```

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 94 additions & 7 deletions
```diff
@@ -1551,7 +1551,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 }
 
 namespace {
-void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
+void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
+  // Checks scales for 2d or 3d target tensors (`mat`).
   if (mat.dim() == 2) {
     TORCH_CHECK(
         scale.dim() == 1,
@@ -1585,9 +1586,66 @@ namespace {
         "scale must have the same first dimension as mat for arg ",
         arg_idx);
   }
-}
+}
 
+void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
+  // Checks scales for 2d or 3d target tensors (`mat`).
+  if (mat.dim() == 2) {
+    // For MXFP8, 2d tensors have variable size groups represented as subtensors,
+    // that are converted to blocked padded format individually,
+    // so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
+    TORCH_CHECK(
+        scale.dim() == mat.dim(),
+        "for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
+
+    // LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
+    // RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
+    // * weight is transposed prior to the call, scale stays non-transposed.
+    bool LHS = arg_idx == 0;
+    int scale_dim_to_check = 0;
+    int mat_dim_to_check = LHS ? 0 : 1;
+    TORCH_CHECK(
+        scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
+        "for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
+        "must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
+  } else {
+    // For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
+    // so we can check the exact expected scale sizes here without a d2h sync.
+    auto round_up = [](auto x, auto y) {
+      return ((x + y - 1) / y) * y;
+    };
+
+    // TODO: this is for 3d tensor in 2d-3d case specifically.
+    // We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
+    int64_t G = mat.size(0);
+    int64_t K = mat.size(1);
+    int64_t N = mat.size(2);
+    int64_t blocked_scale_K = round_up(K/32, 4);
+    int64_t blocked_scale_N = round_up(N, 128);
+
+    // fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
+    TORCH_CHECK(
+        scale.dim() == mat.dim() - 1,
+        "for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
+    );
+    TORCH_CHECK(
+        scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
+        "for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
+    );
+  }
+}
 
+void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
+  bool using_fp8_rowwise = scale.scalar_type() == kFloat;
+  bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
+  if (using_fp8_rowwise) {
+    _check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
+  } else if (using_mxfp8) {
+    _check_scales_mxfp8(mat, scale, dim, arg_idx);
+  } else {
+    TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
+  }
+}
 }
 
 Tensor
@@ -1612,8 +1670,8 @@ const std::optional<at::Tensor>& bias,
 const std::optional<at::Tensor>& scale_result,
 std::optional<c10::ScalarType> out_dtype,
 bool use_fast_accum) {
-  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/false);
-  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
 
   TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
   TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
@@ -1646,10 +1704,12 @@ bool use_fast_accum) {
     TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
   }
 
-  // Both Per-Tensor and Row-wise scaling expect fp32 tensors
+  // FP8 per-tensor and per-row scaling expect fp32 scales.
+  // MXFP8 expects float8_e8m0fnu scales.
   TORCH_CHECK(
-      scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
-      "Both scale_a and scale_b must be float (fp32) tensors.");
+      (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) ||
+      (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu),
+      "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors.");
 
   const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
   check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
@@ -1660,6 +1720,32 @@ bool use_fast_accum) {
 
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 
+#if defined(USE_FBGEMM_GENAI) && defined(USE_CUDA) && !defined(USE_ROCM)
+  // MXFP8 grouped GEMM dispatching
+  bool is_mx8mx8bf16 = (
+    mat_a.scalar_type() == at::kFloat8_e4m3fn && mat_b.scalar_type() == at::kFloat8_e4m3fn &&
+    scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu
+  );
+  TORCH_CHECK(out_dtype == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
+
+  if (is_mx8mx8bf16) {
+    bool b_is_3d = mat_b.dim() == 3;
+    bool is_2d_2d = a_is_2d && b_is_2d;
+    bool is_2d_3d = a_is_2d && b_is_3d;
+    TORCH_CHECK(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
+    TORCH_CHECK(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
+
+    fbgemm_gpu::mx8mx8bf16_grouped_mm(
+        mat_a,
+        mat_b,
+        scale_a,
+        scale_b,
+        offs.value(),
+        out);
+    return out;
+  }
+#endif
+
 #ifndef USE_ROCM
   TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
   TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
@@ -1691,6 +1777,7 @@ bool use_fast_accum) {
 #else
   TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
 #endif
+
 #endif
 
 }
```
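To make the new scale validation concrete, the following standalone sketch reproduces the shape arithmetic that the 3d branch of `_check_scales_mxfp8` enforces for the expert-weight operand in the 2d-3d case (the sizes are made-up examples, not values from the PR):

```python
# Illustration of the expected MXFP8 blocked-scale shape for a 3d (G, K, N) operand.
def round_up(x: int, y: int) -> int:
    # Same rounding the C++ lambda performs: ((x + y - 1) / y) * y with integer division.
    return ((x + y - 1) // y) * y

G, K, N = 8, 5120, 8192                 # example stack of expert weights
blocked_scale_K = round_up(K // 32, 4)  # one e8m0 scale per 32 K-elements, padded to a multiple of 4
blocked_scale_N = round_up(N, 128)      # N padded to a multiple of 128
expected_scale_shape = (G, blocked_scale_K * blocked_scale_N)
print(expected_scale_shape)             # fbgemm expects a stack of flattened blocked scales
```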

caffe2/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
```diff
@@ -1638,6 +1638,10 @@ if(USE_CUDA)
   # order of the libraries in the linker call matters here when statically
   # linking; libculibos and cublas must be last.
   target_link_libraries(torch_cuda PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
+  if(USE_FBGEMM_GENAI)
+    # Link fbgemm_genai to torch_cuda (only for (1) CUDA build for SM100).
+    target_link_libraries(torch_cuda PRIVATE fbgemm_genai)
+  endif()
 endif()
 
 # ---[ XPU library.
@@ -1759,9 +1763,10 @@ if(USE_ROCM)
   target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
 
   if(USE_FBGEMM_GENAI)
-    target_link_libraries(torch_hip PRIVATE fbgemm_genai)
+    if(USE_ROCM)
+      target_link_libraries(torch_hip PRIVATE fbgemm_genai)
+    endif()
   endif()
-
   # Since PyTorch files contain HIP headers, this is also needed to capture the includes.
   # ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
   target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
```
