
Commit f8271e3

Expose mxfp8 grouped gemm in torch ops
Summary: Expose mxfp8 grouped gemm by moving function declaration to torch_ops.h

Reviewed By: q10

Differential Revision: D81690096
1 parent c6a8daf · commit f8271e3

File tree: 2 files changed (+11 −7 lines)


fbgemm_gpu/experimental/gen_ai/src/quantize/include/fbgemm_gpu/torch_ops.h

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,17 @@ at::Tensor f8f8bf16_rowwise_grouped_mm(
     std::optional<at::Tensor> offsets,
     at::Tensor& output);

+#else
+
+// Torch compliant MXFP8 grouped GEMM only on CUDA for now.
+at::Tensor mx8mx8bf16_grouped_mm(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor offsets,
+    std::optional<at::Tensor> output = std::nullopt);
+
 #endif

 } // namespace fbgemm_gpu
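The hunk shows only the tail of the preprocessor block, so the opening directive is not visible here. For orientation, a minimal sketch of how the new declaration plausibly sits inside the header, assuming the platform split is guarded by USE_ROCM; that macro choice and the surrounding structure are assumptions, not something shown in this diff.

// Hypothetical overall structure of torch_ops.h after this change.
// The USE_ROCM guard is assumed; only the #else branch contents below
// are taken verbatim from the diff.
#include <ATen/ATen.h>
#include <optional>

namespace fbgemm_gpu {

#ifdef USE_ROCM
// ROCm-side declarations, ending with the f8f8bf16_rowwise_grouped_mm
// overload visible in the hunk context above.
#else

// Torch compliant MXFP8 grouped GEMM only on CUDA for now.
at::Tensor mx8mx8bf16_grouped_mm(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor offsets,
    std::optional<at::Tensor> output = std::nullopt);

#endif

} // namespace fbgemm_gpu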

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 0 additions & 7 deletions
@@ -64,13 +64,6 @@ at::Tensor f4f4bf16_grouped_stacked(
     std::optional<at::Tensor> global_scale = std::nullopt,
     std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
     bool use_mx = true);
-at::Tensor mx8mx8bf16_grouped_mm(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    at::Tensor offsets,
-    std::optional<at::Tensor> output = std::nullopt);
 at::Tensor f8f8bf16(
     at::Tensor XQ,
     at::Tensor WQ,
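With the local declaration removed, quantize.cpp (and any other translation unit) picks up the signature from torch_ops.h instead. A minimal caller-side sketch follows; the include path, tensor dtypes, and the meaning of the scale and offsets layouts are illustrative assumptions, while the function name, namespace, and parameter list come from the declaration above.

// Caller-side sketch, assuming the CUDA branch of torch_ops.h is active
// and the header's include directory is on the include path.
#include <ATen/ATen.h>
#include <optional>

#include <fbgemm_gpu/torch_ops.h>  // include path assumed

at::Tensor run_mx8_grouped_mm(
    const at::Tensor& XQ,        // MXFP8 activations, stacked across groups (assumed)
    const at::Tensor& WQ,        // MXFP8 weights for all groups (assumed)
    const at::Tensor& x_scale,   // MX block scales for XQ (assumed layout)
    const at::Tensor& w_scale,   // MX block scales for WQ (assumed layout)
    const at::Tensor& offsets) { // per-group boundaries (assumed meaning)
  // Passing std::nullopt for the optional output lets the op allocate
  // its own BF16 result tensor.
  return fbgemm_gpu::mx8mx8bf16_grouped_mm(
      XQ, WQ, x_scale, w_scale, offsets, /*output=*/std::nullopt);
}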
