import pandas as pd
import sympy
import torch
+ import torch.nn as nn
import torch.utils.benchmark as benchmark
import tqdm
from torch.profiler import ProfilerActivity, profile
)

from torchao.float8 import (
+     Float8LinearConfig,
    convert_to_float8_training,
)
+ from torchao.prototype.mx_formats.config import MXLinearConfig
+ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
    return measurement.mean


- def get_gpu_kernel_time(m, x):
+ def get_gpu_kernel_time(m, x, grad_output):
    # warm up
    for _ in range(2):
-         m(x).sum().backward()
+         y = m(x)
+         y.backward(grad_output)

    # capture a profiling run
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    n_iter = 5
    with profile(activities=activities) as prof:
        for _ in range(n_iter):
-             m(x).sum().backward()
+             y = m(x)
+             y.backward(grad_output)
            torch.cuda.synchronize()
    # get the gpu kernel time and aggregate it
    num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,7 +120,22 @@ def get_gpu_kernel_time(m, x):
    return total_time_s


- def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+ def get_gemm_times(
+     M,
+     K,
+     N,
+     fast_accum,
+     bf16_memory_formats,
+     float8_recipe_name,
+     mx_recipe_name,
+     cache_filename=None,
+ ):
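+     # `bf16_memory_formats` encodes "<first operand>:<second operand>" memory layout
+     # for the bf16 `torch.mm` benchmarked below; a col_major operand is emulated
+     # via `.t().contiguous().t()`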
+     assert bf16_memory_formats in (
+         "row_major:col_major",
+         "row_major:row_major",
+         "col_major:row_major",
+     ), "unsupported"
+
    # Note: this is definitely not the best way to build a cache,
    # but it will do for now.
    if cache_filename is not None:
@@ -127,23 +148,38 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
            cache = dict()
    else:
        cache = dict()
-     key = f"{M},{K},{N},{fast_accum}"
+     key = f"{M},{K},{N},{fast_accum}, {bf16_memory_formats}"
    if key in cache:
        return cache[key]

    device = torch.device("cuda")

    # bf16 time
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+     # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+     if bf16_memory_formats == "row_major:col_major":
+         w_bf16 = w_bf16.t().contiguous().t()
+     elif bf16_memory_formats == "col_major:row_major":
+         x_bf16 = x_bf16.t().contiguous().t()
+
    bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

    # f8 time
    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
    A = torch.zeros(M, K, device=device, dtype=d1)
    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-     scale_a = torch.tensor([1.0], device=device)
-     scale_b = torch.tensor([1.0], device=device)
+     if float8_recipe_name == "tensorwise":
+         scale_a = torch.tensor([1.0], device=device)
+         scale_b = torch.tensor([1.0], device=device)
+     elif float8_recipe_name == "rowwise":
+         scale_a = torch.ones(M, 1, device=device)
+         scale_b = torch.ones(1, N, device=device)
+     else:
+         assert False, "TODO add mx gemm here"
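+     # Scale shapes follow what `torch._scaled_mm` consumes: a single scalar per
+     # operand for tensorwise scaling, and per-row (M, 1) / per-column (1, N)
+     # scales for rowwise scaling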

    def do_matmul(A, B):
        return torch._scaled_mm(
@@ -164,33 +200,52 @@ def do_matmul(A, B):
def run(
    outfile: str,
    do_benchmarks: bool = True,
-     shape_gen_name: str = "square",
+     shape_gen_name: str = "pow2",
    gemm_cache_filename: Optional[str] = None,
    n_limit: Optional[int] = None,
+     float8_recipe_name: Optional[str] = None,
+     mx_recipe_name: Optional[str] = None,
+     enable_fusion_modeling: bool = False,
):
    """
    Args:
    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-     * `shape_gen_name`: `llama`, `square`, or `sweep`
+     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
    * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
    * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+     * `enable_fusion_modeling`: if False, uses `Linear`; if True, uses `LNLinearSigmoid` and models the fusion of the float8 overhead
177217 """
178218
219+ assert not (
220+ (float8_recipe_name is not None ) and (mx_recipe_name is not None )
221+ ), "unsupported"
222+ if float8_recipe_name is None and mx_recipe_name is None :
223+ float8_recipe_name = "tensorwise"
224+
225+ print (f"GPU: { torch .cuda .get_device_name (0 )} " )
179226 print (f"do_benchmarks: { do_benchmarks } " )
180227 print (f"shape_gen_name: { shape_gen_name } " )
228+ print (f"float8_recipe_name: { float8_recipe_name } " )
229+ print (f"mx_recipe_name: { mx_recipe_name } " )
230+ print (f"enable_fusion_modeling: { enable_fusion_modeling } " )
181231
182232 M , K , N = sympy .symbols ("M K N" )
183233
-     fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+     fp8_ovhd_time_sympy = get_float8_mem_sympy(
        M,
        K,
        N,
+         float8_recipe_name,
+         mx_recipe_name,
+         enable_fusion_modeling,
+     )
+     bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+     fp8_gemm_time_sympy = get_gemm_time_sympy(
+         M, K, N, torch.float8_e4m3fn, mx_recipe_name
    )
-
-     bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-     fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
    print()

    headers = [
@@ -217,6 +272,9 @@ def run(
        # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
        # we don't break them out and don't have a roofline for them.
        "b_fp8_e2e_spdp",
+         # how well benchmarked gemms match roofline predicted gemms
+         "rb_bf16_gemm_ratio",
+         "rb_fp8_gemm_ratio",
    ]
    results = []

@@ -237,43 +295,93 @@ def run(

        # if enabled, also measured observed gemm time
        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+         rb_bf16_gemm_ratio = -1
+         rb_fp8_gemm_ratio = -1
+
        if do_benchmarks:
+             # TODO(future): make the bf16 gemm times exactly match the e2e
+             # benchmarks, there is a slight deviation, probably related to gemm
+             # operand memory formats/transpositions below not exactly matching
+             # what PyTorch core is doing for `torch.mm`
+             # input @ weight_t = output
            bf16_g1, f8_g1 = get_gemm_times(
-                 M_val, K_val, N_val, True, gemm_cache_filename
+                 M_val,
+                 K_val,
+                 N_val,
+                 True,
+                 "row_major:col_major",
+                 float8_recipe_name,
+                 mx_recipe_name,
+                 gemm_cache_filename,
            )
+             # grad_output @ weight = grad_input
            bf16_g2, f8_g2 = get_gemm_times(
-                 M_val, N_val, K_val, False, gemm_cache_filename
+                 M_val,
+                 N_val,
+                 K_val,
+                 False,
+                 "row_major:row_major",
+                 float8_recipe_name,
+                 mx_recipe_name,
+                 gemm_cache_filename,
            )
+             # input_t @ grad_output = grad_weight
            bf16_g3, f8_g3 = get_gemm_times(
-                 K_val, M_val, N_val, False, gemm_cache_filename
+                 K_val,
+                 M_val,
+                 N_val,
+                 False,
+                 "col_major:row_major",
+                 float8_recipe_name,
+                 mx_recipe_name,
+                 gemm_cache_filename,
            )
            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+             rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+             rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
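+             # rb_* is the roofline-predicted gemm time (r_*) divided by the
+             # benchmarked gemm time (b_*); values near 1.0 mean the roofline
+             # gemm model closely matches the measured kernels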

        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
        r_fp8_ovhd_time_s = float(
-             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+             fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        )

        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
        if do_benchmarks:
            # create the model
-             m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+             if enable_fusion_modeling:
+                 m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+             else:
+                 m_orig = (
+                     nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                 )
            x = torch.randn(
                M_val, K_val, dtype=torch.bfloat16, device="cuda"
            ).requires_grad_()

+             # get the gradient of the right shape
+             grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
+
            # get the bf16 gpu kernel time
            torch._dynamo.reset()
            m_bf16 = torch.compile(copy.deepcopy(m_orig))
-             b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+             b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

            # get the float8 dynamic scaling gpu kernel time

            torch._dynamo.reset()
-             m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+             if float8_recipe_name is not None:
+                 config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                 m_fp8_dyn = convert_to_float8_training(
+                     copy.deepcopy(m_orig), config=config
+                 )
+             else:
+                 assert mx_recipe_name is not None
+                 config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                 m_fp8_dyn = copy.deepcopy(m_orig)
+                 swap_linear_with_mx_linear(m_fp8_dyn, config=config)
            m_fp8_dyn = torch.compile(m_fp8_dyn)
-             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

        results.append(
            [
@@ -295,6 +403,9 @@ def run(
                b_bf16_e2e_time_s,
                b_fp8_e2e_time_s,
                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                 # gemm ratios
+                 rb_bf16_gemm_ratio,
+                 rb_fp8_gemm_ratio,
            ]
        )

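For reference, a minimal sketch of driving the updated `run()` entry point directly from Python, based only on the signature shown in this diff (the import path and output filename are placeholders, not part of the PR):

    # hypothetical import path for the benchmark script
    from float8_roofline import run

    # roofline + e2e benchmarks for the float8 rowwise recipe over pow2 shapes
    run(
        "roofline_results.csv",
        do_benchmarks=True,
        shape_gen_name="pow2",
        float8_recipe_name="rowwise",
        enable_fusion_modeling=False,
    )

Passing `mx_recipe_name` instead of `float8_recipe_name` exercises the MX path; the two are mutually exclusive, and leaving both unset falls back to the `tensorwise` float8 recipe.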