Commit 667229d

roofline estimator: add float8 rowwise and mxfp8 recipe support

Summary:

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --float8_recipe_name rowwise
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --mx_recipe_name mxfp8_emulated
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 19c2e7c
ghstack-comment-id: 2686473047
Pull Request resolved: #1789

1 parent b9c51b7 commit 667229d
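
For context, here is a minimal sketch of how the two new recipe paths convert a model, condensed from the float8_roofline.py diff below. The layer size and the specific recipe strings are illustrative, not prescribed by this commit, and a CUDA device is assumed:

```python
import copy

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# toy model; 4096 x 4096 is an arbitrary example size
m_orig = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()

float8_recipe_name = "rowwise"  # or "tensorwise"; set to None to use an mx recipe
mx_recipe_name = None           # e.g. "mxfp8_emulated"

if float8_recipe_name is not None:
    # float8 path: build the config from the recipe name, then swap the Linear modules
    config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
    m_lowp = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
else:
    # mx path: swap Linear modules for MXLinear in place
    config = MXLinearConfig.from_recipe_name(mx_recipe_name)
    m_lowp = copy.deepcopy(m_orig)
    swap_linear_with_mx_linear(m_lowp, config=config)

m_lowp = torch.compile(m_lowp)
```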

3 files changed: +249 -74 lines changed


benchmarks/float8/float8_roofline.py (99 additions, 19 deletions)

```diff
@@ -47,6 +47,7 @@
 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
 )

 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
         torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,7 +120,20 @@ def get_gpu_kernel_time(m, x):
     return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
@@ -127,15 +146,24 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]

     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
@@ -164,33 +192,50 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
     )
-
     bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()

     headers = [
@@ -217,6 +262,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []

@@ -237,43 +285,72 @@ def run(

         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val, K_val, N_val, True, "row_major:col_major", gemm_cache_filename
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val, N_val, K_val, False, "row_major:row_major", gemm_cache_filename
             )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val, M_val, N_val, False, "col_major:row_major", gemm_cache_filename
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()

+            # get the gradient of the right shape
+            grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

             # get the float8 dynamic scaling gpu kernel time

             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

         results.append(
             [
@@ -295,6 +372,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )

```
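
The new `bf16_memory_formats` argument mirrors the three matmuls of a linear layer's forward plus backward (output, grad_input, grad_weight), whose operands arrive in different memory layouts. A small sketch of the layout trick used in `get_gemm_times`; the sizes are made up, only the `.t().contiguous().t()` pattern and the three format strings come from the diff above:

```python
import torch

# illustrative sizes only; the benchmark sweeps M, K, N from a shape generator
M, K, N = 256, 512, 128

x = torch.randn(M, K, dtype=torch.bfloat16)
print(x.stride())  # (512, 1) == (K, 1): row major

# same (K, N) shape, but column major: transpose, make contiguous, transpose back
w = torch.randn(K, N, dtype=torch.bfloat16)
w_col_major = w.t().contiguous().t()
print(w_col_major.shape, w_col_major.stride())  # (512, 128) with stride (1, 512) == (1, K): column major

# the three gemms benchmarked per linear layer fwd + bwd:
#   output      = input    @ weight_t     -> (M, K) x (K, N), "row_major:col_major"
#   grad_input  = grad_out @ weight       -> (M, N) x (N, K), "row_major:row_major"
#   grad_weight = input_t  @ grad_output  -> (K, M) x (M, N), "col_major:row_major"
```
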
benchmarks/float8/utils.py (17 additions, 3 deletions)

```diff
@@ -152,18 +152,32 @@ def get_name_to_shapes_iter(
         }
         return name_to_shapes_70b.items()

-    elif shape_gen_name == "square":
+    elif shape_gen_name == "pow2":
         assert (
             M == K == N == None
         ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
         name_to_shapes = {}
-        min_power_of_2 = 8  # 256
-        max_power_of_2 = 15  # 32,768
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
         for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
             val = 2**power_of_2
             name_to_shapes[idx] = val, val, val
         return name_to_shapes.items()

+    elif shape_gen_name == "pow2_extended":
+        assert (
+            M == K == N == None
+        ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
+        name_to_shapes = {}
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            name_to_shapes[idx * 2] = val1, val1, val1
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            name_to_shapes[idx * 2 + 1] = val2, val2, val2
+        return name_to_shapes.items()
+
     elif shape_gen_name == "sweep":
         assert (
             M == K == N == None
```
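
For reference, a compact restatement of what the two new shape generators enumerate; the loop bounds 10 through 14 are copied from the hunk above, the rest is only for illustration:

```python
# pow2: square M = K = N at powers of two from 1024 to 16384
pow2 = [(2**p,) * 3 for p in range(10, 15)]
# [(1024, 1024, 1024), (2048, ...), (4096, ...), (8192, ...), (16384, ...)]

# pow2_extended: adds the midpoint 1.5 * 2**p after each power of two
pow2_extended = []
for p in range(10, 15):
    pow2_extended.append((2**p,) * 3)
    pow2_extended.append((2**p + 2 ** (p - 1),) * 3)
# M = K = N in 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576
```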
