From bcb330fd278b0624cd21f1240e201fc01f63e3a5 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Fri, 30 Sep 2022 22:18:42 -0700 Subject: [PATCH 1/4] add decoding capability to GPT2BPE tokenizer --- test/torchtext_unittest/test_transforms.py | 22 +++++++++++++++ torchtext/csrc/gpt2_bpe_tokenizer.cpp | 31 +++++++++++++++++++++- torchtext/csrc/gpt2_bpe_tokenizer.h | 3 +++ torchtext/csrc/register_pybindings.cpp | 1 + torchtext/csrc/register_torchbindings.cpp | 1 + torchtext/transforms.py | 10 +++++++ 6 files changed, 67 insertions(+), 1 deletion(-) diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py index 995f3585d6..d514cc9701 100644 --- a/test/torchtext_unittest/test_transforms.py +++ b/test/torchtext_unittest/test_transforms.py @@ -364,11 +364,33 @@ def _gpt2_bpe_tokenizer(self, tokenizer): else: self.assertEqual(tokenizer(txt), expected_token_ids[idx]) + def _gpt2_bpe_decoder(self, tokenizer): + sample_ids = [ + ["15496", "2159", "28265", "703", "389", "345", "30"], + ["39", "2634", "297", "10205", "220", "22173", "129", "243", "75", "41585", "232", "126", "123"], + ["4965", "11377", "64", "2208", "72", "29625"], + ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + ] + + expected_texts = [ + "Hello World!, how are you?", + "Hélló WoŕlḊ¿", + "Respublica superiorem", + "Avdija Vršajević în", + ] + + for idx, ids in enumerate(sample_ids): + self.assertEqual(tokenizer.decode(ids), expected_texts[idx]) + @nested_params([True, False], [True, False]) def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens): """test tokenization on single sentence input as well as batch on sentences""" self._gpt2_bpe_tokenizer(self._load_tokenizer(test_scripting=test_scripting, return_tokens=return_tokens)) + def test_gpt2_bpe_decoder(self): + """test string output returned by decoder given the token ids""" + self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, 
return_tokens=False)) + def test_gpt2_bpe_tokenizer_save_load_pybind(self) -> None: tokenizer = self._load_tokenizer(test_scripting=False, return_tokens=False) tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt") diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 7b85f5ba19..b1e2bdd910 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -2,6 +2,8 @@ #include // @manual #include +#include +#include #include #include #include @@ -147,7 +149,13 @@ GPT2BPEEncoder::GPT2BPEEncoder( bpe_merge_ranks_(std::move(bpe_merge_ranks)), byte_encoder_(std::move(byte_encoder)), seperator_(std::move(seperator)), - caching_enabled_(caching_enabled) {} + caching_enabled_(caching_enabled) { + for (auto const& x : bpe_encoder_) + bpe_decoder_.insert(x.value(), x.key()); + + for (auto const& x : byte_encoder_) + byte_decoder_.insert(x.value(), x.key()); +} GPT2BPEEncoder::GPT2BPEEncoder( const std::unordered_map& bpe_encoder, @@ -279,6 +287,27 @@ std::vector GPT2BPEEncoder::Encode(const std::string& text) { return bpe_token_ids; } +std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { + std::string text; + for (const auto token : tokens) { + // get unicode string for given integer key + const std::string str = bpe_decoder_.at(token); + // convert unicode string to wide string + std::wstring ws(str.size(), L' '); // Overestimate number of code points. + ws.resize(std::mbstowcs(&ws[0], str.c_str(), str.size())); // Shrink to fit. 
+ // stup converter for wide char back to string + using convert_type = std::codecvt_utf8; + std::wstring_convert converter; + + for (wchar_t wchr : ws) { + // get character from byte decoder for each wide character + unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr)); + text += uchr; + } + } + return text; +} + std::vector GPT2BPEEncoder::Tokenize(const std::string& text) { std::vector bpe_tokens; for (const auto& token : PreTokenize_(text)) { diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h index 49777803a9..2d6e5dfbc9 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.h +++ b/torchtext/csrc/gpt2_bpe_tokenizer.h @@ -69,8 +69,10 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { public: const c10::Dict bpe_encoder_; + const c10::Dict bpe_decoder_; const c10::Dict bpe_merge_ranks_; const c10::Dict byte_encoder_; + const c10::Dict byte_decoder_; const std::string seperator_; const bool caching_enabled_; explicit GPT2BPEEncoder( @@ -99,6 +101,7 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { // --> result --> [707, 5927, 11, 707, 68] // TORCHTEXT_API std::vector Encode(const std::string& text); + TORCHTEXT_API std::string Decode(const std::vector& tokens); TORCHTEXT_API std::vector Tokenize(const std::string& text); TORCHTEXT_API std::unordered_map GetBPEEncoder() const; diff --git a/torchtext/csrc/register_pybindings.cpp b/torchtext/csrc/register_pybindings.cpp index 80f3591cf3..9da63a2311 100644 --- a/torchtext/csrc/register_pybindings.cpp +++ b/torchtext/csrc/register_pybindings.cpp @@ -179,6 +179,7 @@ PYBIND11_MODULE(_torchtext, m) { .def_property_readonly("byte_encoder_", &GPT2BPEEncoder::GetByteEncoder) .def("encode", &GPT2BPEEncoder::Encode) .def("tokenize", &GPT2BPEEncoder::Tokenize) + .def("decode", &GPT2BPEEncoder::Decode) .def(py::pickle( // __getstate__ [](const c10::intrusive_ptr& self) diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp index 
64427f12e4..51c626b880 100644 --- a/torchtext/csrc/register_torchbindings.cpp +++ b/torchtext/csrc/register_torchbindings.cpp @@ -139,6 +139,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { c10::Dict, bool>()) .def("encode", &GPT2BPEEncoder::Encode) + .def("decode", &GPT2BPEEncoder::Decode) .def("tokenize", &GPT2BPEEncoder::Tokenize) .def_pickle( // __getstate__ diff --git a/torchtext/transforms.py b/torchtext/transforms.py index b2d90c88ac..b917c67ce9 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -383,6 +383,16 @@ def __prepare_scriptable__(self): return tokenizer_copy return self + def decode(self, tokens: List[str]) -> str: + """Return a decoded string given a list of string token ids. + + :param tokens: A list of strings, each string corresponds to token ids. + :type tokens: List[str] + :return: decoded text + :rtype: str + """ + return self.bpe.decode([int(token) for token in tokens]) + class CLIPTokenizer(Module): """ From 684acb4a6329e8271cda54a02cd4486ce16148ef Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Sat, 1 Oct 2022 17:17:48 -0700 Subject: [PATCH 2/4] use wstring_convert for all conversions --- torchtext/csrc/gpt2_bpe_tokenizer.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index b1e2bdd910..fa8117ffac 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -292,15 +292,13 @@ std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { for (const auto token : tokens) { // get unicode string for given integer key const std::string str = bpe_decoder_.at(token); - // convert unicode string to wide string - std::wstring ws(str.size(), L' '); // Overestimate number of code points. - ws.resize(std::mbstowcs(&ws[0], str.c_str(), str.size())); // Shrink to fit. 
- // stup converter for wide char back to string + // setup converter for converting wide char/string back to char/string using convert_type = std::codecvt_utf8; std::wstring_convert converter; + const std::wstring ws = converter.from_bytes(str); for (wchar_t wchr : ws) { - // get character from byte decoder for each wide character + // get output character from byte decoder for each wide character unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr)); text += uchr; } From 634f4262f34df6dafb6ca01416ac085d7db3c975 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Sat, 1 Oct 2022 17:22:44 -0700 Subject: [PATCH 3/4] minor update to comment and string creation logic --- torchtext/csrc/gpt2_bpe_tokenizer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index fa8117ffac..1a6bf4dce4 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -292,7 +292,7 @@ std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { for (const auto token : tokens) { // get unicode string for given integer key const std::string str = bpe_decoder_.at(token); - // setup converter for converting wide char/string back to char/string + // setup converter for converting wide chars to/from chars using convert_type = std::codecvt_utf8; std::wstring_convert converter; @@ -300,7 +300,7 @@ std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { for (wchar_t wchr : ws) { // get output character from byte decoder for each wide character unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr)); - text += uchr; + text.push_back(uchr); } } return text; From b64389ac69d9c5760a98e6ccbd5f1936b1791c44 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Sat, 1 Oct 2022 17:57:53 -0700 Subject: [PATCH 4/4] move converter definition outside of for loop --- torchtext/csrc/gpt2_bpe_tokenizer.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 1a6bf4dce4..7a722d8aef 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -289,13 +289,13 @@ std::vector GPT2BPEEncoder::Encode(const std::string& text) { std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { std::string text; + // setup converter for converting wide chars to/from chars + using convert_type = std::codecvt_utf8; + std::wstring_convert converter; + for (const auto token : tokens) { // get unicode string for given integer key const std::string str = bpe_decoder_.at(token); - // setup converter for converting wide chars to/from chars - using convert_type = std::codecvt_utf8; - std::wstring_convert converter; - const std::wstring ws = converter.from_bytes(str); for (wchar_t wchr : ws) { // get output character from byte decoder for each wide character