@@ -55,14 +55,12 @@ def get_specs():
5555def get_tensor_memory_traffic_bytes (
5656 dim0 ,
5757 dim1 ,
58- scaling_type : str ,
5958 fuse_with_prev = False ,
6059 model_torch_compile_limitations = False ,
6160):
6261 # assumes input bf16, output f8
6362 numel = dim0 * dim1
6463
65- assert scaling_type == "dynamic" , "unsupported"
6664 # x_bf16 = ...
6765 # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
6866 # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
@@ -104,14 +102,7 @@ def get_float8_mem_sympy(
104102 K ,
105103 N ,
106104 model_torch_compile_limitations : bool = False ,
107- scaling_type_input : str = "dynamic" ,
108- scaling_type_weight : str = "dynamic" ,
109- scaling_type_grad_output : str = "dynamic" ,
110105):
111- assert scaling_type_input in ("dynamic" ,), "unsupported"
112- assert scaling_type_weight in ("dynamic" ,), "unsupported"
113- assert scaling_type_grad_output in ("dynamic" ,), "unsupported"
114-
115106 specs = get_specs ()
116107
117108 # there are three gemms in the fwd/bwd of a linear:
@@ -131,14 +122,12 @@ def get_float8_mem_sympy(
131122 fwd_fp8_input_mem = get_tensor_memory_traffic_bytes (
132123 M ,
133124 K ,
134- scaling_type_input ,
135125 fuse_with_prev = True ,
136126 model_torch_compile_limitations = model_torch_compile_limitations ,
137127 )
138128 fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes (
139129 K ,
140130 N ,
141- scaling_type_weight ,
142131 fuse_with_prev = False ,
143132 model_torch_compile_limitations = model_torch_compile_limitations ,
144133 )
@@ -150,7 +139,6 @@ def get_float8_mem_sympy(
150139 gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes (
151140 M ,
152141 N ,
153- scaling_type_grad_output ,
154142 fuse_with_prev = True ,
155143 model_torch_compile_limitations = model_torch_compile_limitations ,
156144 )
@@ -183,15 +171,12 @@ def get_float8_mem_sympy(
183171 # kernel overhead in the units of seconds, and the per-gemm-input memory
184172 # estimations are in the units of bytes.
185173 num_extra_kernels = 0
186- if scaling_type_input == "dynamic" :
187- # second stage of max-abs reduction
188- num_extra_kernels += 1
189- if scaling_type_weight == "dynamic" :
190- # second stage of max-abs reduction
191- num_extra_kernels += 1
192- if scaling_type_grad_output == "dynamic" :
193- # second stage of max-abs reduction
194- num_extra_kernels += 1
174+ # second stage of max-abs reduction for input
175+ num_extra_kernels += 1
176+ # second stage of max-abs reduction for weight
177+ num_extra_kernels += 1
178+ # second stage of max-abs reduction for grad_output
179+ num_extra_kernels += 1
195180
196181 extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
197182
0 commit comments