 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
 )

 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
             torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
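The switch from `.sum().backward()` to an explicit `grad_output` keeps the profiled kernels limited to the model's own forward and backward work; the old pattern also profiled the extra `sum` reduction and its backward, and seeded the backward with an implicit gradient of ones. A minimal sketch of the difference, using a hypothetical small Linear model (not part of the diff):

# sketch: explicit grad_output vs. the previous sum().backward() pattern
import torch
import torch.nn as nn

m = nn.Linear(1024, 2048, bias=False).cuda().bfloat16()
x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_output = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)

y = m(x)
y.backward(grad_output)   # only the model's fwd + bwd gemms are launched
# m(x).sum().backward()   # old pattern: extra sum kernel, implicit gradient of ones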
@@ -114,7 +120,20 @@ def get_gpu_kernel_time(m, x):
     return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
@@ -127,15 +146,24 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]

     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
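The `bf16_memory_formats` string describes the (first operand : second operand) memory layouts for the gemm, and `.t().contiguous().t()` is the standard trick for producing a column-major operand while keeping its logical shape. A small sketch of that trick in isolation (not part of the diff):

import torch

w = torch.randn(1024, 2048)           # row-major: stride (2048, 1)
w_col_major = w.t().contiguous().t()  # same shape, but stride (1, 1024)

print(w.shape, w.stride())                      # torch.Size([1024, 2048]) (2048, 1)
print(w_col_major.shape, w_col_major.stride())  # torch.Size([1024, 2048]) (1, 1024)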
@@ -164,33 +192,50 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
     )
-
     bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()

     headers = [
@@ -217,6 +262,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []

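The two new `rb_*` columns record how closely the roofline gemm model tracks the benchmarked gemm times: values near 1.0 mean good agreement, and -1 means benchmarks were skipped. A hedged sketch for inspecting them afterwards, assuming the results table ends up as a CSV readable by pandas (the writer lives outside this diff):

import pandas as pd

# hypothetical output path; substitute whatever was passed as `outfile`
df = pd.read_csv("roofline_results.csv")
print(df[["rb_bf16_gemm_ratio", "rb_fp8_gemm_ratio"]].describe())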
@@ -237,43 +285,72 @@ def run(

         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val, K_val, N_val, True, "row_major:col_major", gemm_cache_filename
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val, N_val, K_val, False, "row_major:row_major", gemm_cache_filename
             )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val, M_val, N_val, False, "col_major:row_major", gemm_cache_filename
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()

+            # get the gradient of the right shape (the model output is (M_val, N_val))
+            grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

             # get the float8 dynamic scaling gpu kernel time

             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

         results.append(
             [
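For reference, the roofline numbers above come from substituting concrete shapes into sympy expressions. A short sketch of that evaluation pattern with a made-up expression (the real expressions come from `get_gemm_time_sympy` / `get_float8_mem_sympy` in `torchao.testing.float8.roofline_utils`):

import sympy

M, K, N = sympy.symbols("M K N")
# made-up stand-in: 2*M*K*N flops at an assumed 1 PFLOP/s peak
example_time_sympy = 2 * M * K * N / 1e15

M_val, K_val, N_val = 4096, 4096, 4096
example_time_s = float(
    example_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
print(example_time_s)  # ~1.37e-04 seconds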
@@ -295,6 +372,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )

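Finally, a hedged usage sketch of the updated `run()` signature, calling it directly from Python (the import path and output filename are assumptions; the repo's own CLI entry point is not shown in this diff):

from float8_roofline import run  # hypothetical module name for this script

run(
    outfile="roofline_results.csv",   # hypothetical output path
    shape_gen_name="pow2",
    float8_recipe_name="tensorwise",  # the default recipe chosen above
    enable_fusion_modeling=True,      # use LNLinearSigmoid and model fused fp8 overhead
)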