
Commit ecbcf6e

Add -fa option (#14)

* Add -fa option
* Limit the accepted value for -fa to 0 or 1

Co-authored-by: davidz-ampere <>

1 parent aa0a5d7 · commit ecbcf6e

File tree: 4 files changed, +30 −15 lines

benchmarks/README.md
Lines changed: 2 additions & 1 deletion

@@ -27,8 +27,9 @@ Provide run.py Python script with following arguments:
 - -b, batch size(s) to benchmark, meaning separate token generation streams handled as a single batch; multiple batch sizes can be provided and they will be treated as separate cases to benchmark
 - -p, prompt size(s) to benchmark, size of an input prompt; multiple prompt sizes can be provided and they will be treated as separate cases to benchmark
 - -r, thread-range, e.g., on an 80-thread system, it should be input as 0-79, unless user wants to use just a subset of available threads, say 16-63 (48 threads indexed 16<>63)
+- -fa, 0/1, disable/enable flash attention, default: 0
 ```bash
-python3 run.py -m Meta-Llama-3-8B-Instruct.Q8_0.gguf -t 10 16 32 40 64 80 -b 1 2 4 8 16 32 64 -p 512 -r 0-79
+python3 run.py -m Meta-Llama-3-8B-Instruct.Q8_0.gguf -t 10 16 32 40 64 80 -b 1 2 4 8 16 32 64 -p 512 -r 0-79 -fa 1
 ```
 
 ## Quick run on 80t OCI A1 system

benchmarks/run.py
Lines changed: 8 additions & 1 deletion

@@ -12,7 +12,7 @@ def get_file_dir():
 
 
 def docker_init(node):
-    tag = "amperecomputingai/llama.cpp:2.2.1"
+    tag = "amperecomputingai/llama.cpp:3.1.2"
     if subprocess.run(
             ["docker", "pull", tag]).returncode != 0:
         print("Docker pull process failed!")
@@ -66,6 +66,10 @@ def benchmark(docker_container_name, args):
 
     cmd = (f"cd /runner; python3 utils/benchmark.py -m models/{model} -n {str(num_processes)} "
           f"-t {str(num_threads)} -b {str(batch_size)} -p {str(prompt_size)} -r {args.threads_range}")
+
+    if args.fa != 0:
+        cmd += " -fa 1"
+
     cmd = ["docker", "exec", "-i", docker_container_name, "bash", "-c", cmd]
 
     print(f"Executing: {' '.join(cmd)}")
@@ -109,6 +113,9 @@ def parse_args():
     parser.add_argument("-n", "--numa",
                         type=int, default=0,
                         help="numa mode of the docker container")
+    parser.add_argument("-fa",
+                        type=int, default=0, choices=range(0, 2),
+                        help="enable flash attention")
 
     return parser.parse_args()
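The `choices=range(0, 2)` constraint is what enforces the "limit the accepted value for -fa to 0 or 1" part of the commit message: argparse converts the argument with `type=int` and then rejects anything outside {0, 1}. A minimal standalone sketch of that behavior, mirroring the option definition above:

```python
import argparse

# Mirrors the -fa option added in this commit: type=int parses the
# string, then choices=range(0, 2) accepts only the values 0 and 1.
parser = argparse.ArgumentParser()
parser.add_argument("-fa",
                    type=int, default=0, choices=range(0, 2),
                    help="enable flash attention")

print(parser.parse_args(["-fa", "1"]).fa)  # -> 1
print(parser.parse_args([]).fa)            # -> 0 (the default)
# parser.parse_args(["-fa", "2"]) exits with "invalid choice: 2".
```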

benchmarks/run.sh
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 set -e
 
-python3 run.py -m Meta-Llama-3-8B-Instruct.Q8_0.gguf -t 10 16 32 40 64 80 -b 1 2 4 8 16 32 64 -p 512 -r 0-79
+python3 run.py -m Meta-Llama-3-8B-Instruct.Q8_0.gguf -t 10 16 32 40 64 80 -b 1 2 4 8 16 32 64 -p 512 -r 0-79 -fa 1
 rm -f /tmp/log_power

benchmarks/utils/benchmark.py
Lines changed: 19 additions & 12 deletions

@@ -6,8 +6,6 @@
 import argparse
 import subprocess
 
-TOKENS = 256
-
 online_threads = None
 
 
@@ -22,6 +20,9 @@ def parse_args():
     parser.add_argument("-p", "--prompt_size",
                         type=int, required=True,
                         help="prompt size to feed the model with")
+    parser.add_argument("-tg", "--tg_size",
+                        type=int, default=256,
+                        help="output tokens generated from the model")
     parser.add_argument("-r", "--threads_range",
                         type=str, required=True,
                         help="range of threads to use, e.g. '0-63,128-191', threads will be divided between processes "
@@ -38,6 +39,9 @@ def parse_args():
     parser.add_argument("--mp",
                         type=str, default="local",
                         help="memory placement policy, 'local', 'interleave' or 'none'")
+    parser.add_argument("-fa",
+                        type=int, default=0, choices=range(0, 2),
+                        help="enable flash attention")
     return parser.parse_args()
 
 
@@ -71,20 +75,20 @@ def summarize_results(logs_dir, args, start, finish):
         prompt_size = int(results[1])
         assert prompt_size == args.prompt_size
         tokens_generated = int(results[2])
-        assert tokens_generated == TOKENS
+        assert tokens_generated == args.tg_size
         batch_size = int(results[3])
         assert batch_size == args.batch_size
         ttfts.append(float(results[5]))
         tg_lats.append(float(results[7]))
 
     pp_throughput = sum([args.batch_size * args.prompt_size / ttft for ttft in ttfts])
     avg_pp_latency = sum(ttfts) / len(ttfts)
-    tg_throughput = sum([args.batch_size * TOKENS / lat for lat in tg_lats])
-    tg_per_token_lats = [lat / TOKENS for lat in tg_lats]
+    tg_throughput = sum([args.batch_size * args.tg_size / lat for lat in tg_lats])
+    tg_per_token_lats = [lat / args.tg_size for lat in tg_lats]
     avg_tg_latency = sum(tg_per_token_lats) / len(tg_per_token_lats)
-    avg_total_speed = args.num_processes * args.batch_size * (args.prompt_size + TOKENS) / max([ttft + tg_lat for ttft, tg_lat in zip(ttfts, tg_lats)])
+    avg_total_speed = args.num_processes * args.batch_size * (args.prompt_size + args.tg_size) / max([ttft + tg_lat for ttft, tg_lat in zip(ttfts, tg_lats)])
 
-    results_filename = f"{args.model.split('/')[-1]}@PP{str(args.prompt_size)}@TG{str(TOKENS)}.csv"
+    results_filename = f"{args.model.split('/')[-1]}@PP{str(args.prompt_size)}@TG{str(args.tg_size)}@fa{str(args.fa)}@ctx{str(args.kv_cache)}.csv"
     if os.path.exists(results_filename):
         first_write = False
     else:
@@ -96,7 +100,7 @@ def summarize_results(logs_dir, args, start, finish):
                 ["n_proc", "n_threads", "batch_size", "prompt_size", "output_tokens", "pp_throughput_tps",
                  "pp_avg_latency_sec", "tg_throughput_tps", "tg_avg_latency_sec", "pp+tg_throughput_tps", "concurrency", "start", "finish"])
         writer.writerow(
-            [args.num_processes, args.num_threads, args.batch_size, args.prompt_size, TOKENS, f"{pp_throughput:.3f}",
+            [args.num_processes, args.num_threads, args.batch_size, args.prompt_size, args.tg_size, f"{pp_throughput:.3f}",
             f"{avg_pp_latency:.3f}", f"{tg_throughput:.3f}", f"{avg_tg_latency:.3f}", f"{avg_total_speed:.3f}", args.batch_size * args.num_processes, f"{start:.3f}", f"{finish:.3f}"])
 
     print(f"Result saved in {results_filename}")
@@ -131,23 +135,26 @@ def main():
         # command-line for v1
         if mem_place == "none":
             cmd = ["numactl", f"--physcpubind={gen_threads_config(args.num_threads, n)}",
-                   "/llm/batched-bench", args.model, str(args.kv_cache), "2048", "512", "0", "0", "0", str(args.prompt_size), str(TOKENS),
+                   "/llm/batched-bench", args.model, str(args.kv_cache), "2048", "512", "0", "0", "0", str(args.prompt_size), str(args.tg_size),
                    str(args.batch_size), str(args.num_threads)]
         else:
             cmd = ["numactl", f"--physcpubind={gen_threads_config(args.num_threads, n)}", str(mem_place),
-                   "/llm/batched-bench", args.model, str(args.kv_cache), "2048", "512", "0", "0", "0", str(args.prompt_size), str(TOKENS),
+                   "/llm/batched-bench", args.model, str(args.kv_cache), "2048", "512", "0", "0", "0", str(args.prompt_size), str(args.tg_size),
                    str(args.batch_size), str(args.num_threads)]
     elif os.path.exists("/llm/llama-batched-bench"):
         # command-line for v2
         if mem_place == "none":
             cmd = ["numactl", f"--physcpubind={gen_threads_config(args.num_threads, n)}",
-                   "/llm/llama-batched-bench", "-m", args.model, "-c", str(args.kv_cache), "-b", "2048", "-ub", "512", "-npp", str(args.prompt_size), "-ntg", str(TOKENS),
+                   "/llm/llama-batched-bench", "-m", args.model, "-c", str(args.kv_cache), "-b", "2048", "-ub", "512", "-npp", str(args.prompt_size), "-ntg", str(args.tg_size),
                    "-npl", str(args.batch_size), "-t", str(args.num_threads), "-tb", str(args.num_threads), "--no-mmap"]
         else:
             cmd = ["numactl", f"--physcpubind={gen_threads_config(args.num_threads, n)}", str(mem_place),
-                   "/llm/llama-batched-bench", "-m", args.model, "-c", str(args.kv_cache), "-b", "2048", "-ub", "512", "-npp", str(args.prompt_size), "-ntg", str(TOKENS),
+                   "/llm/llama-batched-bench", "-m", args.model, "-c", str(args.kv_cache), "-b", "2048", "-ub", "512", "-npp", str(args.prompt_size), "-ntg", str(args.tg_size),
                    "-npl", str(args.batch_size), "-t", str(args.num_threads), "-tb", str(args.num_threads), "--no-mmap"]
 
+        if args.fa != 0:
+            cmd.append("--flash-attn")
+
     else:
         print("FAIL: batched-bench not found!")
         sys.exit(1)
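Note that the new `cmd.append("--flash-attn")` sits inside the `elif` branch, so only the v2 `llama-batched-bench` invocation gains the flag; the v1 `batched-bench` path is unchanged. A sketch of how the final v2 argument list comes together (the numactl binding and numeric values are illustrative stand-ins):

```python
# Stand-in values for the thread binding and sizes used above.
fa = 1
cmd = ["numactl", "--physcpubind=0-9",
       "/llm/llama-batched-bench", "-m", "models/Meta-Llama-3-8B-Instruct.Q8_0.gguf",
       "-c", "2048", "-b", "2048", "-ub", "512", "-npp", "512", "-ntg", "256",
       "-npl", "1", "-t", "10", "-tb", "10", "--no-mmap"]

if fa != 0:
    cmd.append("--flash-attn")  # added as its own list element for subprocess

print(" ".join(cmd))
```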
