@@ -27,6 +27,17 @@ at::Tensor f8f8bf16_rowwise_grouped_mm(
     std::optional<at::Tensor> offsets,
     at::Tensor& output);
 
+#else
+
+// Torch compliant MXFP8 grouped GEMM only on CUDA for now.
+at::Tensor mx8mx8bf16_grouped_mm(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor offsets,
+    std::optional<at::Tensor> output = std::nullopt);
+
 #endif
 
 } // namespace fbgemm_gpu
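
For context, below is a minimal usage sketch of mx8mx8bf16_grouped_mm as declared above; it is not part of this PR. The tensor shapes, the FP8 e4m3 input and e8m0 scale dtypes, the 32-element scaling-block size, and the int32 cumulative-row offsets layout are all assumptions about the kernel's contract, not taken from the diff.

#include <ATen/ATen.h>
#include <optional>

namespace fbgemm_gpu {
// Declaration as added by this PR (see the hunk above); real callers would
// include the fbgemm_gpu quantize header instead of redeclaring it.
at::Tensor mx8mx8bf16_grouped_mm(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor offsets,
    std::optional<at::Tensor> output = std::nullopt);
} // namespace fbgemm_gpu

at::Tensor grouped_mm_example() {
  const int64_t G = 2, M = 128, N = 256, K = 512;
  auto bytes = at::TensorOptions().dtype(at::kByte).device(at::kCUDA);
  // Byte storage reinterpreted as FP8 so the sketch avoids FP8 fill kernels;
  // real callers would produce XQ/WQ via an MXFP8 quantization routine.
  at::Tensor XQ = at::zeros({G * M, K}, bytes).view(at::kFloat8_e4m3fn);
  at::Tensor WQ = at::zeros({G, N, K}, bytes).view(at::kFloat8_e4m3fn);
  // e8m0 byte 127 encodes 2^0 == 1.0; one scale per 32-element K block
  // (the block size is an assumption, per the MX microscaling convention).
  at::Tensor x_scale =
      at::full({G * M, K / 32}, 127, bytes).view(at::kFloat8_e8m0fnu);
  at::Tensor w_scale =
      at::full({G, N, K / 32}, 127, bytes).view(at::kFloat8_e8m0fnu);
  // Assumed layout: int32 cumulative row ends of each group in stacked XQ.
  at::Tensor offsets =
      at::arange(1, G + 1, at::dtype(at::kInt).device(at::kCUDA)) * M;
  // `output` is optional; the kernel allocates the BF16 result when omitted.
  return fbgemm_gpu::mx8mx8bf16_grouped_mm(XQ, WQ, x_scale, w_scale, offsets);
}

Note that after this move the declaration sits in the CUDA-only branch of the header (the comment says "only on CUDA for now"), so builds taking the other preprocessor branch will not see it.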
fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp (7 changes: 0 additions & 7 deletions)
@@ -64,13 +64,6 @@ at::Tensor f4f4bf16_grouped_stacked(
     std::optional<at::Tensor> global_scale = std::nullopt,
     std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
     bool use_mx = true);
-at::Tensor mx8mx8bf16_grouped_mm(
-    at::Tensor XQ,
-    at::Tensor WQ,
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    at::Tensor offsets,
-    std::optional<at::Tensor> output = std::nullopt);
 at::Tensor f8f8bf16(
     at::Tensor XQ,
     at::Tensor WQ,