Commit f451be2

[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales
stack-info: PR: #2894, branch: danielvegamyhre/stack/63
1 parent: 327db2b
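For readers skimming the diff: the new triton_mx_block_rearrange_per_group_3d kernel converts per-expert (E, N, K // block_size) E8M0 weight scales into the blocked ("swizzled") 128x4 tiled layout that blocked-scale mxfp8 GEMMs consume, and the benchmark below compares it against a torch.compile'd reference. As a minimal sketch of the per-group idea only, assuming torchao's to_blocked helper from torchao.prototype.mx_formats.utils (the blocked_scales_3d_reference name is hypothetical, and the commit's actual torch_to_blocked_per_group_3d may differ in details such as output shape):

import torch
from torchao.prototype.mx_formats.utils import to_blocked

def blocked_scales_3d_reference(scales: torch.Tensor) -> torch.Tensor:
    # scales: (E, N, K // block_size) E8M0 scales, one (N, K // block_size)
    # slice per expert. Each slice is padded to multiples of (128, 4) and
    # rearranged into the tiled layout, then the slices are stacked back up.
    return torch.stack([to_blocked(scales[e]) for e in range(scales.shape[0])])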

4 files changed: +266 -98 lines
New file: 160 additions & 0 deletions

@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+from utils import benchmark_cuda_function_in_microseconds
+
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    torch_to_blocked_per_group_3d,
+    triton_mx_block_rearrange_per_group_3d,
+)
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    input_shape: tuple[int]
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    torch_time_us: float
+    triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes. Input activations are scaled along K dim.
+    block_size = 32
+    input_shapes = [
+        # w1, w3 scaled along K (fwd)
+        (1, 8192, 5120 // block_size),
+        (2, 8192, 5120 // block_size),
+        (4, 8192, 5120 // block_size),
+        (8, 8192, 5120 // block_size),
+        (16, 8192, 5120 // block_size),
+        # w2 scaled along K (fwd)
+        (1, 5120, 8192 // block_size),
+        (2, 5120, 8192 // block_size),
+        (4, 5120, 8192 // block_size),
+        (8, 5120, 8192 // block_size),
+        (16, 5120, 8192 // block_size),
+    ]
+    configs = []
+    for shape in input_shapes:
+        configs.append(
+            ExperimentConfig(
+                input_shape=shape,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    input_tensor = torch.randint(
+        low=0,
+        high=256,
+        size=config.input_shape,
+        dtype=torch.uint8,
+        device=device,
+    )
+
+    def warmup(fn, *args, **kwargs):
+        for _ in range(5):
+            fn(*args, **kwargs)
+
+    E, N, K = config.input_shape
+
+    # bench torch
+    compiled_run_torch = torch.compile(torch_to_blocked_per_group_3d)
+    warmup(compiled_run_torch, input_tensor)
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        compiled_run_torch,
+        input_tensor,
+    )
+
+    # bench triton
+    triton_out_scales = triton_mx_block_rearrange_per_group_3d(input_tensor)
+    warmup(triton_mx_block_rearrange_per_group_3d, input_tensor)
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        triton_mx_block_rearrange_per_group_3d,
+        input_tensor,
+    )
+
+    # mem bw calculations
+    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+
+    read_bytes = input_tensor.numel() * bytes_per_input_el
+    write_bytes = triton_out_scales.numel() * bytes_per_output_el
+
+    torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
+    triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
+
+    return ExperimentResult(
+        torch_time_us=torch_time_us,
+        triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "input_shape",
+        "torch_time_us",
+        "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
+        "triton_speedup",
+    ]
+    rows = []
+    for experiment in experiments:
+        input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]}, {experiment.config.input_shape[2]})"
+        rows.append(
+            [
+                input_shape,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
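
A back-of-envelope check of the bandwidth math in run_experiment, as a hedged example (the 100 us runtime below is hypothetical, not a measured result): both the E8M0 input elements and the e4m3fn-sized output elements are 1 byte, and for the largest w1/w3 shape (16, 8192, 160) the per-expert dimensions are already multiples of the (128, 4) tile, so the blocked output should need no padding and write_bytes matches read_bytes.

# Hypothetical worked example of the GB/s formula used above.
E, N, K_scaled = 16, 8192, 5120 // 32  # largest w1/w3 scale shape: (16, 8192, 160)
read_bytes = E * N * K_scaled          # 20,971,520 bytes (1 byte per E8M0 scale)
write_bytes = read_bytes               # assumes no (128, 4) padding for this shape
time_us = 100.0                        # hypothetical runtime, not a measurement
gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{gbps:.2f} GB/s")              # ~419.43 GB/s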

test/prototype/moe_training/test_kernels.py

Lines changed: 26 additions & 0 deletions
@@ -24,7 +24,9 @@
 from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
     compute_per_group_blocked_scale_offsets,
     torch_to_blocked_per_group_2d,
+    torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_2d,
+    triton_mx_block_rearrange_per_group_3d,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -240,3 +242,27 @@ def test_mxfp8_per_group_blocked_scales_2d(
     assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
         "blocked scales not equal"
     )
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
+def test_mxfp8_per_group_blocked_scales_3d(
+    e: int,
+    n: int,
+    k: int,
+):
+    device = "cuda"
+    block_size = 32
+    weights = torch.randn(e, n, k // block_size, device=device)
+    weight_scales, _ = to_mx(
+        weights, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+
+    # torch reference
+    ref_out_scales = torch_to_blocked_per_group_3d(weight_scales)
+
+    # triton kernel
+    triton_out_scales = triton_mx_block_rearrange_per_group_3d(weight_scales)
+    assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )
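
The test asserts exact equality (atol=0, rtol=0) because the rearrangement is a pure layout shuffle of scale bytes, so no values are recomputed. A minimal standalone sketch of the same check outside pytest, with reduced shapes (assuming to_mx is importable from torchao.prototype.mx_formats.mx_tensor, as in the test module's existing imports):

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
    torch_to_blocked_per_group_3d,
    triton_mx_block_rearrange_per_group_3d,
)

# Small (E, N, K) for a quick smoke test; the last dim must stay divisible
# by block_size so to_mx can scale along K.
weights = torch.randn(2, 256, 1024 // 32, device="cuda")
scales, _ = to_mx(weights, elem_dtype=torch.float8_e4m3fn, block_size=32)

ref = torch_to_blocked_per_group_3d(scales)            # torch reference
out = triton_mx_block_rearrange_per_group_3d(scales)   # triton kernel
assert torch.equal(ref, out), "blocked scales not equal"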
