6
6
import argparse
7
7
import subprocess
8
8
9
- TOKENS = 256
10
-
11
9
online_threads = None
12
10
13
11
@@ -22,6 +20,9 @@ def parse_args():
22
20
parser .add_argument ("-p" , "--prompt_size" ,
23
21
type = int , required = True ,
24
22
help = "prompt size to feed the model with" )
23
+ parser .add_argument ("-tg" , "--tg_size" ,
24
+ type = int , default = 256 ,
25
+ help = "output token generated from the model" )
25
26
parser .add_argument ("-r" , "--threads_range" ,
26
27
type = str , required = True ,
27
28
help = "range of threads to use, e.g. '0-63,128-191', threads will be divided between processes "
@@ -38,6 +39,9 @@ def parse_args():
38
39
parser .add_argument ("--mp" ,
39
40
type = str , default = "local" ,
40
41
help = "memory placement policy, 'local','interleave' or 'none'" )
42
+ parser .add_argument ("-fa" ,
43
+ type = int , default = 0 , choices = range (0 ,2 ),
44
+ help = "enable flash attention" )
41
45
return parser .parse_args ()
42
46
43
47
@@ -71,20 +75,20 @@ def summarize_results(logs_dir, args, start, finish):
71
75
prompt_size = int (results [1 ])
72
76
assert prompt_size == args .prompt_size
73
77
tokens_generated = int (results [2 ])
74
- assert tokens_generated == TOKENS
78
+ assert tokens_generated == args . tg_size
75
79
batch_size = int (results [3 ])
76
80
assert batch_size == args .batch_size
77
81
ttfts .append (float (results [5 ]))
78
82
tg_lats .append (float (results [7 ]))
79
83
80
84
pp_throughput = sum ([args .batch_size * args .prompt_size / ttft for ttft in ttfts ])
81
85
avg_pp_latency = sum (ttfts ) / len (ttfts )
82
- tg_throughput = sum ([args .batch_size * TOKENS / lat for lat in tg_lats ])
83
- tg_per_token_lats = [lat / TOKENS for lat in tg_lats ]
86
+ tg_throughput = sum ([args .batch_size * args . tg_size / lat for lat in tg_lats ])
87
+ tg_per_token_lats = [lat / args . tg_size for lat in tg_lats ]
84
88
avg_tg_latency = sum (tg_per_token_lats ) / len (tg_per_token_lats )
85
- avg_total_speed = args .num_processes * args .batch_size * (args .prompt_size + TOKENS ) / max ([ttft + tg_lat for ttft , tg_lat in zip (ttfts , tg_lats )])
89
+ avg_total_speed = args .num_processes * args .batch_size * (args .prompt_size + args . tg_size ) / max ([ttft + tg_lat for ttft , tg_lat in zip (ttfts , tg_lats )])
86
90
87
- results_filename = f"{ args .model .split ('/' )[- 1 ]} @PP{ str (args .prompt_size )} @TG{ str (TOKENS )} .csv"
91
+ results_filename = f"{ args .model .split ('/' )[- 1 ]} @PP{ str (args .prompt_size )} @TG{ str (args . tg_size ) } @fa { str ( args . fa ) } @ctx { str ( args . kv_cache )} .csv"
88
92
if os .path .exists (results_filename ):
89
93
first_write = False
90
94
else :
@@ -96,7 +100,7 @@ def summarize_results(logs_dir, args, start, finish):
96
100
["n_proc" , "n_threads" , "batch_size" , "prompt_size" , "output_tokens" , "pp_throughput_tps" ,
97
101
"pp_avg_latency_sec" , "tg_throughput_tps" , "tg_avg_latency_sec" , "pp+tg_throughput_tps" , "concurrency" , "start" , "finish" ])
98
102
writer .writerow (
99
- [args .num_processes , args .num_threads , args .batch_size , args .prompt_size , TOKENS , f"{ pp_throughput :.3f} " ,
103
+ [args .num_processes , args .num_threads , args .batch_size , args .prompt_size , args . tg_size , f"{ pp_throughput :.3f} " ,
100
104
f"{ avg_pp_latency :.3f} " , f"{ tg_throughput :.3f} " , f"{ avg_tg_latency :.3f} " , f"{ avg_total_speed :.3f} " , args .batch_size * args .num_processes , f"{ start :.3f} " , f"{ finish :.3f} " ])
101
105
102
106
print (f"Result saved in { results_filename } " )
@@ -131,23 +135,26 @@ def main():
131
135
# command-line for v1
132
136
if mem_place == "none" :
133
137
cmd = ["numactl" , f"--physcpubind={ gen_threads_config (args .num_threads , n )} " ,
134
- "/llm/batched-bench" , args .model , str (args .kv_cache ), "2048" , "512" , "0" , "0" , "0" , str (args .prompt_size ), str (TOKENS ),
138
+ "/llm/batched-bench" , args .model , str (args .kv_cache ), "2048" , "512" , "0" , "0" , "0" , str (args .prompt_size ), str (args . tg_size ),
135
139
str (args .batch_size ), str (args .num_threads )]
136
140
else :
137
141
cmd = ["numactl" , f"--physcpubind={ gen_threads_config (args .num_threads , n )} " , str (mem_place ),
138
- "/llm/batched-bench" , args .model , str (args .kv_cache ), "2048" , "512" , "0" , "0" , "0" , str (args .prompt_size ), str (TOKENS ),
142
+ "/llm/batched-bench" , args .model , str (args .kv_cache ), "2048" , "512" , "0" , "0" , "0" , str (args .prompt_size ), str (args . tg_size ),
139
143
str (args .batch_size ), str (args .num_threads )]
140
144
elif os .path .exists ("/llm/llama-batched-bench" ):
141
145
# command-line for v2
142
146
if mem_place == "none" :
143
147
cmd = ["numactl" , f"--physcpubind={ gen_threads_config (args .num_threads , n )} " ,
144
- "/llm/llama-batched-bench" , "-m" , args .model , "-c" , str (args .kv_cache ), "-b" , "2048" , "-ub" , "512" , "-npp" , str (args .prompt_size ), "-ntg" , str (TOKENS ),
148
+ "/llm/llama-batched-bench" , "-m" , args .model , "-c" , str (args .kv_cache ), "-b" , "2048" , "-ub" , "512" , "-npp" , str (args .prompt_size ), "-ntg" , str (args . tg_size ),
145
149
"-npl" , str (args .batch_size ), "-t" , str (args .num_threads ), "-tb" , str (args .num_threads ), "--no-mmap" ]
146
150
else :
147
151
cmd = ["numactl" , f"--physcpubind={ gen_threads_config (args .num_threads , n )} " ,str (mem_place ),
148
- "/llm/llama-batched-bench" , "-m" , args .model , "-c" , str (args .kv_cache ), "-b" , "2048" , "-ub" , "512" , "-npp" , str (args .prompt_size ), "-ntg" , str (TOKENS ),
152
+ "/llm/llama-batched-bench" , "-m" , args .model , "-c" , str (args .kv_cache ), "-b" , "2048" , "-ub" , "512" , "-npp" , str (args .prompt_size ), "-ntg" , str (args . tg_size ),
149
153
"-npl" , str (args .batch_size ), "-t" , str (args .num_threads ), "-tb" , str (args .num_threads ), "--no-mmap" ]
150
154
155
+ if args .fa != 0 :
156
+ cmd .append ("--flash-attn" )
157
+
151
158
else :
152
159
print ("FAIL: batched-bench not found!" )
153
160
sys .exit (1 )
0 commit comments