diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py index f3f2f50326..ee5fbea903 100644 --- a/test/torchtext_unittest/test_transforms.py +++ b/test/torchtext_unittest/test_transforms.py @@ -560,6 +560,140 @@ def _gpt2_bpe_decoder(self, tokenizer): for idx, ids in enumerate(sample_ids): self.assertEqual(tokenizer.decode(ids), expected_texts[idx]) + def _gpt2_bpe_decoder_with_special_tokens(self, tokenizer): + sample_ids = [ + [ + "27", + "91", + "437", + "1659", + "5239", + "91", + "29", + "290", + "1279", + "91", + "437", + "1659", + "5239", + "91", + "29", + "389", + "2041", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "318", + "407", + "0", + ], + [ + "9288", + "15859", + "8905", + "51", + "1279", + "615", + "603", + "62", + "4658", + "29", + "351", + "27196", + "24027", + "1279", + "91", + "437", + "1659", + "5239", + "91", + "29", + "290", + "8005", + "62", + "44710", + ], + ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + [ + "40", + "423", + "281", + "16882", + "1359", + "428", + "318", + "257", + "1332", + "1279", + "91", + "437", + "1659", + "5239", + "91", + "29", + ], + ] + + expected_texts = [ + "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!", + "test ACCEPT with DECLINE <|endoftext|> and NO_ACTION", + "Avdija Vršajević în", + "I have an inkling this is a test <|endoftext|>", + ] + + for idx, ids in enumerate(sample_ids): + self.assertEqual(tokenizer.decode(ids), expected_texts[idx]) + + newly_added = tokenizer.add_special_tokens( + special_tokens_dict={ + "unk_token": "<|endoftext|>", + "sep_token": "", + "additional_special_tokens": [ + "ACCEPT", + "DECLINE", + "inkling", + ], + } + ) + self.assertEqual(newly_added, 4) + + sample_ids = [ + [ + "50256", + "392", + "50256", + "533", + "2041", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "318", + "407", + "0", + ], + ["9288", "50258", "50257", 
"4480", "50259", "50256", "392", "8005", "62", "44710"], + ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + ["40", "423", "281", "50260", "5661", "318", "257", "1332", "50256"], + ] + + expected_texts = [ + "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!", + "test ACCEPT with DECLINE <|endoftext|> and NO_ACTION", + "Avdija Vršajević în", + "I have an inkling this is a test <|endoftext|>", + ] + + for idx, ids in enumerate(sample_ids): + self.assertEqual(tokenizer.decode(ids), expected_texts[idx]) + @nested_params([True, False], [True, False]) def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens): """test tokenization on single sentence input as well as batch on sentences""" @@ -568,6 +702,7 @@ def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens): def test_gpt2_bpe_decoder(self): """test string output returned by decoder given the token ids""" self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False)) + self._gpt2_bpe_decoder_with_special_tokens(self._load_tokenizer(test_scripting=False, return_tokens=False)) @nested_params([True, False]) def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens): diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 77ae0b4e13..5b10fe4a73 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -384,8 +384,8 @@ std::vector GPT2BPEEncoder::PreTokenize_(std::string input) { std::vector GPT2BPEEncoder::Encode(const std::string& text) { std::vector bpe_token_ids; for (const auto& token : PreTokenize_(text)) { - if (added_tokens_encoder.contains(token)) { - bpe_token_ids.push_back(added_tokens_encoder.at(token)); + if (added_tokens_encoder_.contains(token)) { + bpe_token_ids.push_back(added_tokens_encoder_.at(token)); continue; } bool is_never_split_token = @@ -400,19 +400,67 @@ std::vector GPT2BPEEncoder::Encode(const 
std::string& text) { std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { std::string text; + bool is_prev_special = false; + bool is_current_special = false; // setup converter for converting wide chars to/from chars using convert_type = std::codecvt_utf8; std::wstring_convert converter; - for (const auto token : tokens) { - // get unicode string for given integer key - const std::string str = bpe_decoder_.at(token); - const std::wstring ws = converter.from_bytes(str); - for (wchar_t wchr : ws) { - // get output character from byte decoder for each wide character - unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr)); - text.push_back(uchr); + for (int tok_idx = 0; tok_idx < tokens.size(); tok_idx++) { + const auto token = tokens[tok_idx]; + std::string decoded_token; + + if (added_tokens_decoder_.contains(token)) { + // string is a special token from extended vocab + decoded_token = added_tokens_decoder_.at(token); + is_current_special = true; + } else { + const std::string str = bpe_decoder_.at(token); + if (bpe_never_split_set_.find(str) != bpe_never_split_set_.end()) { + // string is a special token from known vocab + decoded_token = str; + is_current_special = true; + } else { + // string is a regular token from known vocab + is_current_special = false; + const std::wstring ws = converter.from_bytes(str); + for (wchar_t wchr : ws) { + // get output character from byte decoder for each wide character + unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr)); + decoded_token.push_back(uchr); + } + } + } + + /* Fixing leading/trailing space(s) + + We need to ensure spaces before and after special tokens are removed + appropriately. Assuming <|endoftext|> and HELLO are special tokens: + string input: "<|endoftext|> <|endoftext|> and HELLO world !" 
+ is to be tokenized as: + ['<|endoftext|>', '<|endoftext|>', 'and', 'HELLO', 'world', 'Ġ!'] + whereas an input like: + "<|endoftext|> and anything else!", gets tokenized as: + ['<|endoftext|>', 'and', 'Ġanything', 'Ġelse', '!'] + + Hence while decoding the corresponding string tokens back to + the original string text, we will have to insert those spaces back again. + - Add empty space before a special token if it is not at the beginning of the + sentence and if it is not following another special token. + - Add empty space after a special token if it is not at the end of the + sentence. + */ + + // fix left space(s) for special tokens + if (is_current_special && (tok_idx > 0 && !is_prev_special)) { + text.push_back(' '); + } + text.append(decoded_token); + // fix right space(s) for special tokens + if (is_current_special && tok_idx != tokens.size() - 1) { + text.push_back(' '); + } + is_prev_special = is_current_special; } return text; } @@ -436,31 +484,35 @@ int64_t GPT2BPEEncoder::AddSpecialTokens( int64_t newly_added = 0; /* All special tokens get added to `bpe_never_split_set_` set to avoid being - * split during tokenization. Tokens are added to `added_tokens_encoder_` only - * if they are not already known (i.e. present in `bpe_encoder_`). + * split during tokenization. Tokens are added to `added_tokens_encoder_` only + * if they are not already known (i.e. not already present in `bpe_encoder_`). */ // Loop for standard tokens such as "bos_token", "eos_token", etc. 
for (auto const& token : standard_special_tokens_dict) { - if (added_tokens_encoder.contains(token.value())) { + if (added_tokens_encoder_.contains(token.value())) { continue; } bpe_never_split_set_.insert(token.value()); if (!bpe_encoder_.contains(token.value())) { - added_tokens_encoder.insert( - token.value(), bpe_encoder_.size() + added_tokens_encoder.size()); + added_tokens_encoder_.insert( + token.value(), bpe_encoder_.size() + added_tokens_encoder_.size()); + added_tokens_decoder_.insert( + bpe_decoder_.size() + added_tokens_decoder_.size(), token.value()); newly_added++; } } // Loop for any additional tokens for (auto const& token : additional_special_tokens) { - if (added_tokens_encoder.contains(token)) + if (added_tokens_encoder_.contains(token)) continue; bpe_never_split_set_.insert(token); if (!bpe_encoder_.contains(token)) { - added_tokens_encoder.insert( - token, bpe_encoder_.size() + added_tokens_encoder.size()); + added_tokens_encoder_.insert( + token, bpe_encoder_.size() + added_tokens_encoder_.size()); + added_tokens_decoder_.insert( + bpe_decoder_.size() + added_tokens_decoder_.size(), token); newly_added++; } } diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h index c13a15b202..8d7de4d6fc 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.h +++ b/torchtext/csrc/gpt2_bpe_tokenizer.h @@ -62,7 +62,8 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { std::string token, bool is_never_split_token); int64_t GetBPEMergeRank_(std::string pair); - c10::Dict added_tokens_encoder; + c10::Dict added_tokens_encoder_; + c10::Dict added_tokens_decoder_; protected: c10::Dict> cache_;