 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 
 BYTES_PER_EL_FLOAT8 = 1
@@ -55,29 +57,67 @@ def get_specs():
 def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     fuse_with_prev=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
-    # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-    # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+
+        if fuse_with_prev:
+            kernel_1_rw = 0
+        else:
+            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel
+
+        # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+
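+        # illustrative totals, assuming BYTES_PER_EL_BF16 == 2: unfused traffic
+        # is 2 + 2 + 2*1 = 6 bytes per element; fusing kernel 1 with the
+        # previous op drops this to 4 bytes per element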
+        return kernel_1_rw + kernel_3_rw
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_float8_dim0
+        # kernel 2: x_bf16 -> x_float8_dim1
+
+        # assume that we can't fuse 1 and 2 because that would require loading
+        # the entire tensor to shared memory
+
+        if fuse_with_prev:
+            # assume we can fuse one of the reads with previous op
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
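+        # illustrative totals, assuming BYTES_PER_EL_BF16 == 2 and ignoring
+        # scale traffic: unfused is (2 + 1) + (2 + 1) = 6 bytes per element;
+        # fusing the first read drops this to 4 bytes per element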
+        return kernel_1_rw + kernel_2_rw
 
-    if fuse_with_prev:
-        kernel_1_rw = 0
     else:
-        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
 
-    # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-    kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1
 
-    return kernel_1_rw + kernel_3_rw
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2
+
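+        # illustrative totals, assuming BYTES_PER_EL_BF16 == 2 and ignoring
+        # scale traffic: unfused is 2 + 2*1 = 4 bytes per element; fusing the
+        # read with the previous op drops this to 2 bytes per element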
+        return kernel_1_rw
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
+    # currently this assumes gemm is compute bound
+    # TODO(future): maybe make more accurate for small shapes by taking max of
+    # time to read/write and time to do the dot product, this might also
+    # slightly differ for MX since scales are larger
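+    # note: the three equal terms in gemm_ops below model the output,
+    # grad_input, and grad_weight gemms of a linear layer, each costing
+    # 2 * M * K * N FLOPs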
     specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
@@ -89,9 +129,7 @@ def get_gemm_time_sympy(M, K, N, dtype):
 
 
 def get_float8_mem_sympy(
-    M,
-    K,
-    N,
+    M, K, N, float8_recipe_name: Optional[str], mx_recipe_name: Optional[str]
 ):
     specs = get_specs()
 
@@ -112,11 +150,15 @@ def get_float8_mem_sympy(
     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
         M,
         K,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=False,
     )
     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
@@ -127,6 +169,8 @@ def get_float8_mem_sympy(
     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
         M,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     # already casted, assuming that we save weight from fw to bw
@@ -158,12 +202,20 @@ def get_float8_mem_sympy(
     # kernel overhead in the units of seconds, and the per-gemm-input memory
     # estimations are in the units of bytes.
     num_extra_kernels = 0
-    # second stage of max-abs reduction for input
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for weight
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for grad_output
-    num_extra_kernels += 1
+    if float8_recipe_name == "tensorwise":
+        # second stage of max-abs reduction for input
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for weight
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for grad_output
+        num_extra_kernels += 1
+    elif float8_recipe_name == "rowwise":
+        # for simplicity, assume all rowwise kernels are large and bandwidth bound
+        pass
+    else:
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
+        # for simplicity, assume all mxfp8 kernels are large and bandwidth bound
+        pass
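+    # each extra kernel is charged a fixed per-kernel overhead
+    # (TRITON_KERNEL_1_ELEMENT_TIME_SEC below); bandwidth-bound kernels are
+    # already accounted for by the memory traffic estimates above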
 
 
     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC