Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8eb0561

Browse files
authored
Make BERT benchmark code more robust (#1871)
1 parent 255f4f7 commit 8eb0561

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

benchmark/benchmark_bert_tokenizer.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
def benchmark_bert_tokenizer(args):
    """Benchmark TorchText vs. Hugging Face BERT tokenizers on EnWik9 text.

    Reads ``args.num_samples`` lines from EnWik9, groups them into batches of
    ``args.batch_size``, and times each tokenizer in two modes:

    * non-batched — tokenize one string at a time (inner loop over each batch);
    * batched — hand a whole batch (list of strings) to the tokenizer at once.

    Args:
        args: parsed CLI namespace with ``num_samples`` (int) and
            ``batch_size`` (int) attributes.

    Side effects: prints one timing line per ``Timer`` context; no return value.
    """
    # NOTE(review): VOCAB_FILE, tt_bert_tokenizer, hf_bert_tokenizer_slow,
    # hf_tokenizer_lib, EnWik9 and Timer are defined/imported elsewhere in
    # this file — not visible in this chunk.
    tt_tokenizer = tt_bert_tokenizer(VOCAB_FILE, return_tokens=True)
    hf_tokenizer_slow = hf_bert_tokenizer_slow.from_pretrained("bert-base-uncased")
    hf_tokenizer_fast = hf_tokenizer_lib.from_pretrained("bert-base-uncased")

    # Batch the datapipe so each element of `samples` is a list of strings.
    # Materialize once up front so dataset I/O is excluded from the timings.
    dp = EnWik9().header(args.num_samples).batch(args.batch_size)
    samples = list(dp)

    with Timer("Running TorchText BERT Tokenizer on non-batched input"):
        for batch in samples:
            for s in batch:
                tt_tokenizer(s)

    with Timer("Running HF BERT Tokenizer (slow) on non-batched input"):
        for batch in samples:
            for s in batch:
                hf_tokenizer_slow.tokenize(s)

    with Timer("Running HF BERT Tokenizer (fast) on non-batched input"):
        for batch in samples:
            for s in batch:
                hf_tokenizer_fast.encode(s)

    # Batched mode: the tokenizer receives the whole list of strings at once.
    with Timer("Running TorchText BERT Tokenizer on batched input"):
        for batch in samples:
            tt_tokenizer(batch)

    with Timer("Running HF BERT Tokenizer (fast) on batched input"):
        for batch in samples:
            hf_tokenizer_fast.encode_batch(batch)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--num-samples", default=10000, type=int)
    parser.add_argument("--batch-size", default=100, type=int)

    benchmark_bert_tokenizer(parser.parse_args())

0 commit comments

Comments
 (0)