 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
 )

 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
             torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,10 +120,28 @@ def get_gpu_kernel_time(m, x):
     return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    gemm_role: str,
+    M: int,
+    K: int,
+    N: int,
+    fast_accum: bool,
+    bf16_memory_formats: str,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
+    cache_filename=None,
+):
+    assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
+        assert False, "TODO retest this for new arguments"
         if os.path.isfile(cache_filename):
             # cache already exists, use it
             with open(cache_filename, "r") as f:
@@ -127,30 +151,48 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]

     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+    # note: "row_major:row_major" needs no transformation, since both
+    # operands are already in row-major memory format
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
-    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
-    A = torch.zeros(M, K, device=device, dtype=d1)
-    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    scale_a = torch.tensor([1.0], device=device)
-    scale_b = torch.tensor([1.0], device=device)
-
-    def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+    if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
+        f8_time_s = bf16_time_s
+    else:
+        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        A = torch.zeros(M, K, device=device, dtype=d1)
+        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        if float8_recipe_name == "tensorwise":
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
+        else:
+            assert False, "TODO add mx gemm here"
+
+        def do_matmul(A, B):
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )

-    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+        f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

     # save to cache if needed
     if cache_filename is not None:
@@ -164,33 +206,52 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
     )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()

     headers = [
@@ -217,6 +278,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []

@@ -237,43 +301,96 @@ def run(

         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                "output",
+                M_val,
+                K_val,
+                N_val,
+                True,
+                "row_major:col_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                "grad_input",
+                M_val,
+                N_val,
+                K_val,
+                False,
+                "row_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                "grad_weight",
+                K_val,
+                M_val,
+                N_val,
+                False,
+                "col_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()

+            # get a gradient of the same shape as the model output, (M, N)
+            grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

             # get the float8 dynamic scaling gpu kernel time

             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

         results.append(
             [
@@ -295,6 +412,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )

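A note on the `bf16_memory_formats` convention added to `get_gemm_times`: from the call sites above it appears to be read as "layout of the first operand : layout of the second operand" for the benchmarked `torch.mm(x_bf16, w_bf16)`, and the `.t().contiguous().t()` idiom flips a tensor to column-major storage without changing its logical shape. A minimal illustration of that idiom, not part of the diff:

    import torch

    w = torch.randn(4, 8)                 # row-major storage
    w_col_major = w.t().contiguous().t()  # same logical shape, column-major storage
    print(w.shape, w.stride())                      # torch.Size([4, 8]) (8, 1)
    print(w_col_major.shape, w_col_major.stride())  # torch.Size([4, 8]) (1, 4)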
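For context, a hedged sketch of how the updated `run()` entry point might be invoked with the new arguments; the argument names and accepted values are taken from this diff, while the output filename and the particular recipe chosen are made up for illustration:

    # hypothetical invocation sketch; only the parameters shown in the diff are assumed
    run(
        outfile="float8_rowwise_roofline.csv",  # illustrative output path
        do_benchmarks=True,
        shape_gen_name="pow2",           # one of: llama, pow2, pow2_extended, sweep
        float8_recipe_name="rowwise",    # mutually exclusive with mx_recipe_name
        mx_recipe_name=None,
        enable_fusion_modeling=False,    # False: plain nn.Linear, True: LNLinearSigmoid
    )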