diff --git a/benchmark/benchmark_bert_tokenizer.py b/benchmark/benchmark_bert_tokenizer.py
new file mode 100644
index 0000000000..597ff8e12d
--- /dev/null
+++ b/benchmark/benchmark_bert_tokenizer.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+from benchmark.utils import Timer
+from tokenizers import Tokenizer as hf_tokenizer_lib
+from torchtext.datasets import EnWik9
+from torchtext.transforms import BERTTokenizer as tt_bert_tokenizer
+from transformers import BertTokenizer as hf_bert_tokenizer_slow
+
+
+VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
+
+
+def benchmark_bert_tokenizer(args):
+    tt_tokenizer = tt_bert_tokenizer(VOCAB_FILE, return_tokens=True)
+    hf_tokenizer_slow = hf_bert_tokenizer_slow.from_pretrained("bert-base-uncased")
+    hf_tokenizer_fast = hf_tokenizer_lib.from_pretrained("bert-base-uncased")
+    dp = EnWik9().header(args.num_samples)
+    samples = list(dp)
+
+    with Timer("Running TorchText BERT Tokenizer on non-batched input"):
+        for s in samples:
+            tt_tokenizer(s)
+
+    with Timer("Running HF BERT Tokenizer (slow) on non-batched input"):
+        for s in samples:
+            hf_tokenizer_slow.tokenize(s)
+
+    with Timer("Running HF BERT Tokenizer (fast) on non-batched input"):
+        for s in samples:
+            hf_tokenizer_fast.encode(s)
+
+    with Timer("Running TorchText BERT Tokenizer on batched input"):
+        tt_tokenizer(samples)
+
+    with Timer("Running HF BERT Tokenizer (fast) on batched input"):
+        hf_tokenizer_fast.encode_batch(samples)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--num-samples", default=1000, type=int)
+    benchmark_bert_tokenizer(parser.parse_args())
diff --git a/torchtext/csrc/bert_tokenizer.cpp b/torchtext/csrc/bert_tokenizer.cpp
index 4bade1153a..e4e205dcc2 100644
--- a/torchtext/csrc/bert_tokenizer.cpp
+++ b/torchtext/csrc/bert_tokenizer.cpp
@@ -292,6 +292,24 @@ std::vector<int64_t> BERTEncoder::Encode(std::string text) {
   return indices;
 }
 
+std::vector<std::vector<std::string>> BERTEncoder::BatchTokenize(
+    std::vector<std::string> text) {
+  std::vector<std::vector<std::string>> output;
+  for (const auto& t : text) {
+    output.push_back(Tokenize(t));
+  }
+  return output;
+}
+
+std::vector<std::vector<int64_t>> BERTEncoder::BatchEncode(
+    std::vector<std::string> text) {
+  std::vector<std::vector<int64_t>> output;
+  for (const auto& t : text) {
+    output.push_back(Encode(t));
+  }
+  return output;
+}
+
 BERTEncoderStates _serialize_bert_encoder(
     const c10::intrusive_ptr<BERTEncoder>& self) {
   return std::make_tuple(
diff --git a/torchtext/csrc/bert_tokenizer.h b/torchtext/csrc/bert_tokenizer.h
index 145fc262c6..37313ac3be 100644
--- a/torchtext/csrc/bert_tokenizer.h
+++ b/torchtext/csrc/bert_tokenizer.h
@@ -21,6 +21,10 @@ struct BERTEncoder : torch::CustomClassHolder {
       c10::optional<bool> strip_accents);
   std::vector<std::string> Tokenize(std::string text);
   std::vector<int64_t> Encode(std::string text);
+  std::vector<std::vector<std::string>> BatchTokenize(
+      std::vector<std::string> text);
+  std::vector<std::vector<int64_t>> BatchEncode(std::vector<std::string> text);
+
   Vocab vocab_;
   bool do_lower_case_;
   c10::optional<bool> strip_accents_ = {};
diff --git a/torchtext/csrc/register_pybindings.cpp b/torchtext/csrc/register_pybindings.cpp
index 87c44d40e9..ce8a1297a7 100644
--- a/torchtext/csrc/register_pybindings.cpp
+++ b/torchtext/csrc/register_pybindings.cpp
@@ -220,6 +220,30 @@ PYBIND11_MODULE(_torchtext, m) {
       .def(py::init<const std::string, const bool, c10::optional<bool>>())
       .def("encode", &BERTEncoder::Encode)
       .def("tokenize", &BERTEncoder::Tokenize)
+      .def(
+          "batch_encode",
+          [](const c10::intrusive_ptr<BERTEncoder>& self,
+             const py::list& items) {
+            std::vector<std::string> input;
+            for (const auto& item : items) {
+              Py_ssize_t length;
+              const char* buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+              input.push_back(std::string(buffer, length));
+            }
+            return self->BatchEncode(input);
+          })
+      .def(
+          "batch_tokenize",
+          [](const c10::intrusive_ptr<BERTEncoder>& self,
+             const py::list& items) {
+            std::vector<std::string> input;
+            for (const auto& item : items) {
+              Py_ssize_t length;
+              const char* buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+              input.push_back(std::string(buffer, length));
+            }
+            return self->BatchTokenize(input);
+          })
       .def(py::pickle(
           // __getstate__
           [](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp
index 57b2d8bcd6..6b03cfea53 100644
--- a/torchtext/csrc/register_torchbindings.cpp
+++ b/torchtext/csrc/register_torchbindings.cpp
@@ -177,6 +177,18 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def(torch::init<const std::string, const bool, c10::optional<bool>>())
       .def("encode", &BERTEncoder::Encode)
       .def("tokenize", &BERTEncoder::Tokenize)
+      .def(
+          "batch_encode",
+          [](const c10::intrusive_ptr<BERTEncoder>& self,
+             const std::vector<std::string>& items) {
+            return self->BatchEncode(items);
+          })
+      .def(
+          "batch_tokenize",
+          [](const c10::intrusive_ptr<BERTEncoder>& self,
+             const std::vector<std::string>& items) {
+            return self->BatchTokenize(items);
+          })
       .def_pickle(
           // __getstate__
           [](const c10::intrusive_ptr<BERTEncoder>& self) -> BERTEncoderStates {
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 5879e44a2d..b5f879fe8d 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -597,6 +597,13 @@ def _encode(self, text: str) -> List[str]:
         tokens_ids_str: List[str] = [str(token_id) for token_id in token_ids]
         return tokens_ids_str
 
+    @torch.jit.export
+    def _batch_encode(self, text: List[str]) -> List[List[str]]:
+        """Batch version of _encode, i.e., operates on a list of strings."""
+        token_ids: List[List[int]] = self.bert_model.batch_encode([t.strip() for t in text])
+        tokens_ids_str: List[List[str]] = [[str(t) for t in token_id] for token_id in token_ids]
+        return tokens_ids_str
+
     @torch.jit.export
     def _tokenize(self, text: str) -> List[str]:
         """Tokenize text into a list of tokens
@@ -612,6 +619,11 @@ def _tokenize(self, text: str) -> List[str]:
         """
         return self.bert_model.tokenize(text.strip())
 
+    @torch.jit.export
+    def _batch_tokenize(self, text: List[str]) -> List[List[str]]:
+        """Batch version of _tokenize, i.e., operates on a list of strings."""
+        return self.bert_model.batch_tokenize([t.strip() for t in text])
+
     def forward(self, input: Any) -> Any:
         """
         :param input: Input sentence or list of sentences on which to apply tokenizer.
@@ -621,11 +633,10 @@ def forward(self, input: Any) -> Any:
         """
         if torch.jit.isinstance(input, List[str]):
             tokens: List[List[str]] = []
-            for text in input:
-                if self._return_tokens:
-                    tokens.append(self._tokenize(text))
-                else:
-                    tokens.append(self._encode(text))
+            if self._return_tokens:
+                tokens = self._batch_tokenize(input)
+            else:
+                tokens = self._batch_encode(input)
             return tokens
         elif torch.jit.isinstance(input, str):
             if self._return_tokens:
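Usage sketch (not part of the diff): with this change, passing a `List[str]` to `torchtext.transforms.BERTTokenizer` dispatches to a single batched C++ call (`batch_tokenize`/`batch_encode`) instead of looping over samples in Python. A minimal example, assuming the hosted vocab file used in the benchmark above; the tokenized output in the comments is illustrative rather than copied from a real run:

```python
from torchtext.transforms import BERTTokenizer

VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"

# return_tokens=True yields string tokens; with the default
# return_tokens=False the transform returns vocab indices (as strings).
tokenizer = BERTTokenizer(VOCAB_FILE, return_tokens=True)

# A list input now takes the batched path added in this PR.
batch = tokenizer(["Hello World!", "How are you?"])
# e.g. [['hello', 'world', '!'], ['how', 'are', 'you', '?']]

# A single string still returns a flat list of tokens.
single = tokenizer("Hello World!")
```

Note that `_batch_encode`/`_batch_tokenize` keep the per-sample `strip()` and `str()` conversions, so the batched path produces the same output as the existing single-sample `_encode`/`_tokenize`.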