
Commit a56882d

cthi authored and facebook-github-bot committed
Enable USE_FBGEMM_GENAI (#4703)
Summary:
X-link: pytorch/pytorch#160676
Pull Request resolved: #4703
X-link: facebookresearch/FBGEMM#1728

In this diff we enable support for the new FBGEMM-backed FP8 `torch._scaled_grouped_mm` on ROCm. For now we only enable support for `gfx942`, as that is the architecture on which we have thoroughly tested performance and correctness.

Reviewed By: drisspg

Differential Revision: D79564024

fbshipit-source-id: bf2aa1a3eee43d0e47e9ba1e5514152e502da35f
1 parent 17ee9d0 commit a56882d
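For orientation, here is a minimal, untested sketch of the FP8 rowwise grouped GEMM path this commit enables. The call mirrors the compute() method added in quantize_ops.py below; the tensor shapes, the FP8 dtype, the offsets layout, and the quantization itself are illustrative assumptions rather than part of this diff (for example, on gfx942 the FP8 format may be float8_e4m3fnuz rather than float8_e4m3fn).

import torch

# Illustrative sizes (assumptions): G groups stacked along the M dimension.
G, M, K, N = 2, 64, 128, 256
device = "cuda"  # on ROCm this maps to a gfx942 (MI300-class) device

x = torch.randn(G * M, K, device=device)  # stacked activations
w = torch.randn(G, N, K, device=device)   # per-group weights

# Rowwise FP8 quantization sketch; dtype and scale layout are assumptions.
fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max
x_scale = x.abs().amax(dim=1).clamp(min=1e-4) / fp8_max  # shape (G*M,)
w_scale = w.abs().amax(dim=2).clamp(min=1e-4) / fp8_max  # shape (G, N)
xq = (x / x_scale[:, None]).to(fp8)
wq = (w / w_scale[:, :, None]).to(fp8)

# Cumulative row offsets marking where each group ends in the stacked M dim.
offsets = torch.arange(1, G + 1, dtype=torch.int32, device=device) * M

# Same call pattern as ScaledGroupedMMRowwise.compute() in the diff below.
out = torch._scaled_grouped_mm(
    xq,
    wq.transpose(-2, -1),
    offs=offsets,
    out_dtype=torch.bfloat16,
    scale_a=x_scale,
    scale_b=w_scale,
    scale_result=None,
    use_fast_accum=True,
)
print(out.shape)  # expected: (G * M, N) in bfloat16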

File tree

2 files changed: +43, -1 lines

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 43 additions & 0 deletions
@@ -1263,6 +1263,7 @@ def quantize(self, x, wq, w_scale, m_sizes):
         out = torch.empty(
             (xq.shape[0], wq.shape[1]), dtype=torch.bfloat16, device=xq.device
         )
+        x_scale = x_scale.view(x_scale.shape[0])
         return xq, wq, x_scale, w_scale, offsets, out
 
     def compute(self, xq, wq, x_scale, w_scale, offsets, out):
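The one functional change in this hunk is the added x_scale = x_scale.view(x_scale.shape[0]) line: the rowwise activation scale is flattened to a 1-D vector before being returned. Presumably the FBGEMM-backed torch._scaled_grouped_mm path expects a flat per-row scale_a rather than an (M, 1) column; that reading is an inference from this change, not something stated in the commit. A hypothetical illustration:

import torch

x_scale = torch.ones(1024, 1)             # assumed (M, 1) rowwise scale from quantization
x_scale = x_scale.view(x_scale.shape[0])  # -> shape (1024,), one scale per row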
@@ -1287,6 +1288,48 @@ def cuda(self) -> bool:
         return False
 
 
+@register_quantize_op
+class ScaledGroupedMMRowwise(FP8StackedGroupedGemmTorch):
+    def __init__(self):
+        self.fast_accum = True
+        self.torch_compile = False
+
+    def compute(self, xq, wq, x_scale, w_scale, offsets, _):
+        if self.torch_compile:
+            f = torch.compile(
+                torch._scaled_grouped_mm,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
+                },
+            )
+        else:
+            f = torch._scaled_grouped_mm
+
+        return f(
+            xq,
+            wq.transpose(-2, -1),
+            offs=offsets,
+            out_dtype=torch.bfloat16,
+            scale_a=x_scale,
+            scale_b=w_scale,
+            scale_result=None,
+            use_fast_accum=self.fast_accum,
+        )
+
+    @property
+    def name(self) -> str:
+        return "scaled_grouped_mm_rowwise"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8StackedGroupwiseGroupedGemm(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
 #include <tuple>
 
 #include <ATen/core/Tensor.h>
-#include <ATen/hip/HIPContext.h>
 #include <c10/hip/HIPStream.h>
 
 #include "ck/ck.hpp"
