diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
index 2effdcc7be..169522a7d1 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h
@@ -27,6 +27,17 @@ at::Tensor f8f8bf16_rowwise_grouped_mm(
     std::optional<at::Tensor> offsets,
     at::Tensor& output);
 
+#else
+
+// Torch compliant MXFP8 grouped GEMM only on CUDA for now.
+at::Tensor mx8mx8bf16_grouped_mm(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor offsets,
+    std::optional<at::Tensor> output = std::nullopt);
+
 #endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 316bfd8a1b..2de66a8ddb 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -64,13 +64,6 @@ at::Tensor f4f4bf16_grouped_stacked(
     std::optional<at::Tensor> global_scale = std::nullopt,
     std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
     bool use_mx = true);
-at::Tensor mx8mx8bf16_grouped_mm(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    at::Tensor offsets,
-    std::optional<at::Tensor> output = std::nullopt);
 at::Tensor f8f8bf16(
     at::Tensor XQ,
     at::Tensor WQ,
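For context, below is a minimal, hypothetical call-site sketch for the mx8mx8bf16_grouped_mm declaration that this diff moves into torch_ops.h. It only illustrates the call pattern of the declared signature; the function name mxfp8_grouped_mm_example and the tensor shapes, dtypes, and scale layouts noted in the comments (FP8 e4m3 inputs with MX block scales and integer group offsets) are assumptions for illustration, not guarantees about what the CUDA kernel requires.

// Hypothetical call-site sketch; shapes/dtypes in the comments are assumptions.
#include <optional>
#include <ATen/ATen.h>
#include "fbgemm_gpu/torch_ops.h"

at::Tensor mxfp8_grouped_mm_example(
    const at::Tensor& XQ,       // quantized activations, e.g. FP8 [total_M, K]
    const at::Tensor& WQ,       // quantized weights, e.g. FP8 [G, N, K]
    const at::Tensor& x_scale,  // MX block scales for XQ
    const at::Tensor& w_scale,  // MX block scales for WQ
    const at::Tensor& offsets)  // per-group row offsets
{
  // Passing std::nullopt for `output` lets the op allocate the bf16 result.
  return fbgemm_gpu::mx8mx8bf16_grouped_mm(
      XQ, WQ, x_scale, w_scale, offsets, /*output=*/std::nullopt);
}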