 )
 
 from torchao.float8 import (
-    Float8LinearConfig,
     convert_to_float8_training,
 )
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
-from torchao.utils import is_sm_at_least_90, is_sm_at_least_100
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -155,21 +153,13 @@ def do_matmul(A, B):
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
-    if is_sm_at_least_90() and (not is_sm_at_least_100()):
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
-        fast_accum = True  # for axiswise
-        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
-    else:
-        f8_axs_time_s = -1.0
-
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
+        cache[key] = [bf16_time_s, f8_time_s]
         with open(cache_filename, "w") as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s, f8_axs_time_s
+    return bf16_time_s, f8_time_s
 
 
 def run(
@@ -229,18 +219,13 @@ def run(
         # gemm microbenchmarks
         "bf16_gemm_s",
         "fp8_gemm_s",
-        "fp8_axs_gemm_time_s",
         # roofline memory overhead estimates
-        "fp8_oh_dyn_limit",
-        "fp8_oh_dyn_nolimit",
+        "fp8_oh_estimated",
+        "fp8_oh_ideal",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_dyn_axs_s",
-        # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_dyn_axs_sp",
-        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -251,18 +236,17 @@ def run(
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(
+            bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
-            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(
+            bf16_g2, f8_g2 = get_gemm_times(
                 M_val, N_val, K_val, False, gemm_cache_filename
             )
-            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(
+            bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = (
@@ -271,8 +255,6 @@ def run(
             fp8_gemm_time_s = (
                 fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            # for now, assume axiswise gemm is similar to tensorwise
-            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = (
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -299,28 +281,6 @@ def run(
             m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-        # get the float8 dynamic axiswise scaling gpu kernel time, if supported
-        # on current hardware
-        if is_sm_at_least_90() and (not is_sm_at_least_100()):
-            torch._dynamo.reset()
-            config = Float8LinearConfig.from_recipe_name("rowwise")
-            m_fp8_dyn_axs = convert_to_float8_training(
-                copy.deepcopy(m_orig), config=config
-            )
-            m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-            fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
-        else:
-            fp8_dyn_axs_time_actual_s = -1.0
-
-        # get the lw recipe scaling gpu kernel time
-        # TODO(future PR): enable below once basic performance issues
-        # are fixed
-        # torch._dynamo.reset()
-        # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
-        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-        # m_fp8_lw = torch.compile(m_fp8_lw)
-        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
-
         results.append(
             [
                 M_val,
@@ -329,18 +289,13 @@ def run(
                 # gemm microbenchmarks
                 bf16_time_val,
                 fp8_gemm_time_s,
-                fp8_axs_gemm_time_s,
                 # roofline overhead estimates
                 fp8_mem_time_dyn_limit_s,
                 fp8_mem_time_dyn_nolimit_s,
                 # e2e numbers
                 bf16_time_actual_s,
                 fp8_dyn_time_actual_s,
-                fp8_dyn_axs_time_actual_s,
-                # fp8_lw_time_actual_s,
                 bf16_time_actual_s / fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
-                # bf16_time_actual_s / fp8_lw_time_actual_s,
             ]
         )
 
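
Note: a minimal sketch of the axiswise (rowwise) scaled conversion that this diff removes from the benchmark, for anyone who still wants to time that path separately. It mirrors the deleted lines above; the toy model, shapes, and variable names below are made-up stand-ins for the script's m_orig and x, and it assumes a CUDA GPU on which torchao's rowwise float8 recipe is supported.

import copy

import torch

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# toy stand-in for the benchmarked model and input
m_orig = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().bfloat16()
x = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16)

# convert a copy of the model to float8 training with rowwise (axiswise) scaling,
# then compile it, as the removed benchmark path did
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# run once to trigger compilation; time subsequent calls with the profiler of your choice
y = m_fp8_dyn_axs(x)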