@@ -9,19 +9,43 @@
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
 
-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
+gpu_name_to_specs = {
+    "NVIDIA H100": {
+        # https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
+        "bf16_peak_tops": 989e12,
+        "fp8_peak_tops": 1979e12,
+        # 2.4 TB per second, custom to Meta's H100 variant
+        "peak_mem_bw_bytes_sec": 2.4e12,
+        # based on quick experimental observation with sample large inputs
+        "pct_achievable_gemm_tops": 0.6,
+        # based on previous experience looking at pointwise triton kernels with large inputs,
+        # which would hit about 2.2k GBPS on Meta's H100 variant
+        "pct_achievable_mem_bw": 0.92,
+    },
+    "NVIDIA B200": {
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 19,
+        # divide by 2 because no sparsity
+        "bf16_peak_tops": 2.25e15,
+        "fp8_peak_tops": 4.5e15,
+        "fp4_peak_tops": 9.0e15,
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 20
+        # 8.0 TB per second
+        "peak_mem_bw_bytes_sec": 8.0e12,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_gemm_tops": 0.6,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_mem_bw": 0.92,
+    },
+    # TODO(future): more GPU names
+}
+
+
+def get_specs():
+    gpu_name = torch.cuda.get_device_name(0)
+    return gpu_name_to_specs[gpu_name]
 
-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92
 
 # Source: run a triton kernel with a single element read/write on an H100 and
 # measure GPU time from the trace
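
For intuition, a minimal sketch of how these specs are meant to be consumed: the datasheet peak is derated by the empirically observed achievable fraction. The helper name `effective_gemm_tops` below is hypothetical, for illustration only, and is not part of this diff.

# Hypothetical helper (not in this diff): derate datasheet peak TOPS by the
# fraction achievable in practice on large inputs.
def effective_gemm_tops(specs, dtype_key):
    return specs[dtype_key] * specs["pct_achievable_gemm_tops"]

# e.g. on H100 in bf16: 989e12 * 0.6 ≈ 5.9e14 achievable ops/sec
print(effective_gemm_tops(gpu_name_to_specs["NVIDIA H100"], "bf16_peak_tops"))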
@@ -65,12 +89,13 @@ def get_tensor_memory_traffic_bytes(
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
+    specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
+        peak_tops = specs["bf16_peak_tops"]
     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
+        peak_tops = specs["fp8_peak_tops"]
+    gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
     return gemm_time_s
 
 
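To sanity-check the formula above, a worked example with concrete shapes, assuming the H100 entry from `gpu_name_to_specs` (illustrative numbers, not output from this diff):

# Worked example of the gemm roofline estimate for M = K = N = 4096 in bf16.
M = K = N = 4096
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N  # = 6 * 4096**3 ≈ 4.12e11
gemm_time_s = gemm_ops / 989e12 / 0.6  # bf16_peak_tops, pct_achievable_gemm_tops
print(gemm_time_s)  # ≈ 6.9e-4 s, i.e. roughly 0.7 ms across the fwd/bwd gemms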
@@ -87,6 +112,8 @@ def get_float8_mem_sympy(
     assert scaling_type_weight in ("dynamic",), "unsupported"
     assert scaling_type_grad_output in ("dynamic",), "unsupported"
 
+    specs = get_specs()
+
     # there are three gemms in the fwd/bwd of a linear:
     #
     # input @ weight_t = output
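The hunk cuts the comment off after the first gemm; for reference, the three matmuls of a linear layer's forward and backward in the same notation (standard autograd algebra, spelled out here rather than copied from the diff):

# fwd: output      = input @ weight_t
# bwd: grad_input  = grad_output @ weight
# bwd: grad_weight = input_t @ grad_output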
@@ -148,7 +175,7 @@ def get_float8_mem_sympy(
     )
     fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
     fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
+        fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
 
     # Adjust final estimate for small kernel launches
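
The memory term mirrors the gemm term: bytes moved divided by derated bandwidth. A worked example, assuming the H100 entry above (the traffic value is a hypothetical placeholder):

# Worked example of the memory roofline term.
fp8_total_mem = 1e9  # bytes of fp8 scaling/casting traffic (hypothetical placeholder)
fp8_mem_time_s = fp8_total_mem / 2.4e12 / 0.92  # peak_mem_bw_bytes_sec, pct_achievable_mem_bw
print(fp8_mem_time_s)  # ≈ 4.5e-4 s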