diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py index d514cc9701..f3f2f50326 100644 --- a/test/torchtext_unittest/test_transforms.py +++ b/test/torchtext_unittest/test_transforms.py @@ -336,6 +336,7 @@ def _gpt2_bpe_tokenizer(self, tokenizer): "Hélló WoŕlḊ¿", "Respublica superiorem", "Avdija Vršajević în", + "multi space", ] expected_tokens = [ @@ -343,12 +344,189 @@ def _gpt2_bpe_tokenizer(self, tokenizer): ["H", "é", "ll", "ó", "Ġ", "ĠWo", "Å", "ķ", "l", "á¸", "Ĭ", "Â", "¿"], ["Res", "public", "a", "Ġsuper", "i", "orem"], ["Av", "d", "ija", "ĠV", "r", "Å¡", "aj", "ev", "i", "Äĩ", "ĠÃ", "®", "n"], + ["multi", "Ġ", "Ġ", "Ġ", "Ġ", "Ġ", "Ġspace"], ] expected_token_ids = [ ["15496", "2159", "28265", "703", "389", "345", "30"], ["39", "2634", "297", "10205", "220", "22173", "129", "243", "75", "41585", "232", "126", "123"], ["4965", "11377", "64", "2208", "72", "29625"], ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + ["41684", "220", "220", "220", "220", "220", "2272"], + ] + + # test batch of sentences + if tokenizer._return_tokens: + self.assertEqual(tokenizer(sample_texts), expected_tokens) + else: + self.assertEqual(tokenizer(sample_texts), expected_token_ids) + + # test individual sentences + for idx, txt in enumerate(sample_texts): + if tokenizer._return_tokens: + self.assertEqual(tokenizer(txt), expected_tokens[idx]) + else: + self.assertEqual(tokenizer(txt), expected_token_ids[idx]) + + def _gpt2_bpe_tokenizer_with_added_vocab(self, tokenizer): + sample_texts = [ + "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!", + "test ACCEPT with DECLINE <|endoftext|> and NO_ACTION", + "none in vocab: <|endofline|> WALK_60M WALK_10M ", + "Respublica Vršajević în", + "some in vocab: <|endofline|> WALK_60M WALK_10M ", + "<|endoftext|> WALK_60M WALK_10M ", + ] + + newly_added = tokenizer.add_special_tokens( + special_tokens_dict={ + "unk_token": "<|endoftext|>", + "additional_special_tokens": [ + "ACCEPT", + "DECLINE", + "NO_ACTION", + "WALK_10M", + "WALK_60M", + "", + ], + } + ) + self.assertEqual(newly_added, 6) + + newly_added = tokenizer.add_special_tokens( + special_tokens_dict={ + "unk_token": "<|endoftext|>", + "sep_token": "", + "additional_special_tokens": [ + "ACCEPT", + "DECLINE", + "NO_ACTION", + "WALK_10M", + "WALK_60M", + "", + ], + } + ) + self.assertEqual(newly_added, 1) + + expected_tokens = [ + [ + "<|endoftext|>", + "and", + "<|endoftext|>", + "are", + "Ġspecial", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "Ġis", + "Ġnot", + "!", + ], + ["test", "ACCEPT", "", "with", "DECLINE", "<|endoftext|>", "and", "NO_ACTION"], + [ + "none", + "Ġin", + "Ġvoc", + "ab", + ":", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "WALK_60M", + "WALK_10M", + "<", + "state", + ">", + ], + ["Res", "public", "a", "ĠV", "r", "Å¡", "aj", "ev", "i", "Äĩ", "ĠÃ", "®", "n"], + [ + "some", + "Ġin", + "Ġvoc", + "ab", + ":", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "WALK_60M", + "WALK_10M", + "<", + "state", + ">", + ], + ["<|endoftext|>", "WALK_60M", "WALK_10M", "", "<", "state", ">"], + ] + expected_token_ids = [ + [ + "50256", + "392", + "50256", + "533", + "2041", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "318", + "407", + "0", + ], + ["9288", "50257", "50263", "4480", "50258", "50256", "392", "50259"], + [ + "23108", + "287", + "12776", + "397", + "25", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "50261", + "50260", + "27", + "5219", + "29", + ], + ["4965", "11377", "64", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + [ + "11246", + "287", + "12776", + "397", + "25", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "50261", + "50260", + "27", + "5219", + "29", + ], + ["50256", "50261", "50260", "50262", "27", "5219", "29"], ] # test batch of sentences @@ -391,6 +569,12 @@ def test_gpt2_bpe_decoder(self): """test string output returned by decoder given the token ids""" self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False)) + @nested_params([True, False]) + def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens): + self._gpt2_bpe_tokenizer_with_added_vocab( + self._load_tokenizer(test_scripting=False, return_tokens=return_tokens) + ) + def test_gpt2_bpe_tokenizer_save_load_pybind(self) -> None: tokenizer = self._load_tokenizer(test_scripting=False, return_tokens=False) tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt") diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 7a722d8aef..c54eca516e 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include #include @@ -63,28 +65,122 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { // - ELSE, add token to return list std::string token; std::vector tokens; - re2::StringPiece inp(input); bool prepend_space = false; - while (kGPT2Regex.FindAndConsume(&inp, &token)) { - if (is_whitespace(token)) { - prepend_space = false; - if (inp.empty()) { // token is last token - tokens.push_back(token); - } else { - if (token.length() > 1) { - tokens.push_back(token.substr(0, token.length() - 1)); + std::vector index_matches; + + /* Notes on handling Special Tokens: + We use regex pattern to first identify the special tokens in the input text. + Other 'non-special' tokens go through pre-tokenization as usual, but special + tokens skip those steps. + + Steps: + * Loop over the set containing user-supplied strings that are to be treated as + special tokens. This set gets created through the calls to + `add_special_tokens` API. + - form a regex pattern that helps in extracting special tokens from the + input text. + * Create a vector that contains chunks of input text, such that each chunk is + either a sequence of non-special token or a single special token. For example, + assuming <|special_tok|> and [SEP] are special tokens, the following text + "This is an example with <|special_tok|> and [SEP] and [SPAM]." + will get converted to a vector of strings: + ["This is an example with", "<|special_tok|>", "and", "[SEP]", "and + [SPAM]."] + - if the input does not contain any special tokens, the vector will just + contain a single token that is the whole original input text. + * For all of the tokens in the above vector, we proceed with BPE tokenization + as usual while skipping over certain steps as appropriate for special tokens. + */ + + if (bpe_never_split_set_.size() > 0) { + std::string pattern = ""; + // Escape regex characters for matching special tokens. This is done to + // ensure that characters like '|' in certain special tokens such as + // <|endoftext|> don't get special regex treatment. + for (std::string token : bpe_never_split_set_) { + std::string::size_type pos = 0; + while ((pos = token.find_first_of("|[]", pos)) != std::string::npos) { + switch (token[pos]) { + case '|': + token.replace(pos, 1, "\\|"); + pos += 2; + break; + case '[': + token.replace(pos, 1, "\\["); + pos += 2; + break; + case ']': + token.replace(pos, 1, "\\]"); + pos += 2; + break; } - if (token[token.length() - 1] == ' ') { // last char is space - prepend_space = true; - } else { // push last whitespace char as a token if it is not a space - tokens.push_back(token.substr(token.length() - 1)); + } + if (pattern.length() != 0) + pattern += "|"; + pattern += token; + } + + // break input into non-special and special parts + std::regex rx(pattern); + int64_t last_idx = 0; + for (auto it = std::sregex_iterator(input.begin(), input.end(), rx); + it != std::sregex_iterator(); + ++it) { + if (it->position() > last_idx) { + if (isspace(input[it->position() - 1])) { + // strip space on the left of the special token + index_matches.push_back( + input.substr(last_idx, it->position() - last_idx - 1)); + } else { + index_matches.push_back( + input.substr(last_idx, it->position() - last_idx)); } } - } else if (prepend_space) { - tokens.push_back(" " + token); - prepend_space = false; - } else { - tokens.push_back(token); + index_matches.push_back(input.substr(it->position(), it->length())); + last_idx = it->position() + it->length() + 1; + if (isspace(input[last_idx])) { + // strip space on the right of the special token + last_idx++; + } + } + if (last_idx < input.length() - 1) + index_matches.push_back( + input.substr(last_idx, input.length() - last_idx)); + } else { + // input does not have any special tokens + index_matches.push_back(input); + } + + for (std::string index_token : index_matches) { + bool is_never_split_token = + bpe_never_split_set_.find(index_token) != bpe_never_split_set_.end(); + if (is_never_split_token) { + // skip the rest of pre-tokenization work for special tokens + tokens.push_back(index_token); + continue; + } + re2::StringPiece inp(index_token); + while (kGPT2Regex.FindAndConsume(&inp, &token)) { + if (is_whitespace(token)) { + prepend_space = false; + if (inp.empty()) { // token is last token + tokens.push_back(token); + } else { + if (token.length() > 1) { + tokens.push_back(token.substr(0, token.length() - 1)); + } + if (token[token.length() - 1] == ' ') { // last char is space + prepend_space = true; + } else { // push last whitespace char as a token if it is not a space + tokens.push_back(token.substr(token.length() - 1)); + } + } + } else if (prepend_space) { + tokens.push_back(" " + token); + prepend_space = false; + } else { + tokens.push_back(token); + } } } return tokens; @@ -170,11 +266,17 @@ GPT2BPEEncoder::GPT2BPEEncoder( _map_to_c10_dict(byte_encoder), caching_enabled) {} -std::vector GPT2BPEEncoder::ByteEncode_(std::string token) { +std::vector GPT2BPEEncoder::ByteEncode_( + std::string token, + bool is_never_split_token) { // Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8') std::vector encoded; - for (auto& ch : token) { - encoded.push_back(byte_encoder_.at((unsigned char)ch)); + if (is_never_split_token) { + encoded.push_back(token); + } else { + for (auto& ch : token) { + encoded.push_back(byte_encoder_.at((unsigned char)ch)); + } } return encoded; } @@ -279,7 +381,13 @@ std::vector GPT2BPEEncoder::PreTokenize_(std::string input) { std::vector GPT2BPEEncoder::Encode(const std::string& text) { std::vector bpe_token_ids; for (const auto& token : PreTokenize_(text)) { - auto byte_encoded_token = ByteEncode_(token); + if (added_tokens_encoder.contains(token)) { + bpe_token_ids.push_back(added_tokens_encoder.at(token)); + continue; + } + bool is_never_split_token = + bpe_never_split_set_.find(token) != bpe_never_split_set_.end(); + auto byte_encoded_token = ByteEncode_(token, is_never_split_token); for (const auto& bpe_token : BPE_(byte_encoded_token)) { bpe_token_ids.push_back(bpe_encoder_.at(bpe_token)); } @@ -309,7 +417,9 @@ std::string GPT2BPEEncoder::Decode(const std::vector& tokens) { std::vector GPT2BPEEncoder::Tokenize(const std::string& text) { std::vector bpe_tokens; for (const auto& token : PreTokenize_(text)) { - auto byte_encoded_token = ByteEncode_(token); + bool is_never_split_token = + bpe_never_split_set_.find(token) != bpe_never_split_set_.end(); + auto byte_encoded_token = ByteEncode_(token, is_never_split_token); for (const auto& bpe_token : BPE_(byte_encoded_token)) { bpe_tokens.push_back(bpe_token); } @@ -317,6 +427,43 @@ std::vector GPT2BPEEncoder::Tokenize(const std::string& text) { return bpe_tokens; } +int64_t GPT2BPEEncoder::AddSpecialTokens( + const c10::Dict& standard_special_tokens_dict, + const std::vector additional_special_tokens) { + int64_t newly_added = 0; + + /* All special tokens get added to `bpe_never_split_set_` set to avoid being + * split during tokenization. Tokens are added to `added_tokens_encoder` only + * if they are not already known (i.e. present in `bpe_encoder_`). + */ + + // Loop for standard tokens such as "bos_token", "eos_token", etc. + for (auto const& token : standard_special_tokens_dict) { + if (added_tokens_encoder.contains(token.value())) + continue; + bpe_never_split_set_.insert(token.value()); + if (!bpe_encoder_.contains(token.value())) { + added_tokens_encoder.insert( + token.value(), bpe_encoder_.size() + added_tokens_encoder.size()); + newly_added++; + } + } + + // Loop for any additional tokens + for (auto const& token : additional_special_tokens) { + if (added_tokens_encoder.contains(token)) + continue; + bpe_never_split_set_.insert(token); + if (!bpe_encoder_.contains(token)) { + added_tokens_encoder.insert( + token, bpe_encoder_.size() + added_tokens_encoder.size()); + newly_added++; + } + } + + return newly_added; +} + std::unordered_map GPT2BPEEncoder::GetBPEEncoder() const { return _c10_dict_to_map(bpe_encoder_); } diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h index 2d6e5dfbc9..5fc0197b74 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.h +++ b/torchtext/csrc/gpt2_bpe_tokenizer.h @@ -1,10 +1,10 @@ #ifndef GPT2_BPE_TOKENIZER_H_ #define GPT2_BPE_TOKENIZER_H_ - #include #include #include +#include #include #include #include @@ -12,6 +12,9 @@ namespace torchtext { +// set to store tokens that are not to be split +static std::set bpe_never_split_set_; + typedef std::tuple< std::unordered_map, std::unordered_map, @@ -55,8 +58,11 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { private: const int64_t inf_; // Encode byte into an unicode character. - std::vector ByteEncode_(std::string token); + std::vector ByteEncode_( + std::string token, + bool is_never_split_token); int64_t GetBPEMergeRank_(std::string pair); + c10::Dict added_tokens_encoder; protected: c10::Dict> cache_; @@ -103,6 +109,9 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { TORCHTEXT_API std::vector Encode(const std::string& text); TORCHTEXT_API std::string Decode(const std::vector& tokens); TORCHTEXT_API std::vector Tokenize(const std::string& text); + TORCHTEXT_API int64_t AddSpecialTokens( + const c10::Dict& standard_special_tokens_dict, + const std::vector additional_special_tokens); TORCHTEXT_API std::unordered_map GetBPEEncoder() const; TORCHTEXT_API std::unordered_map GetBPEMergeRanks() diff --git a/torchtext/csrc/register_pybindings.cpp b/torchtext/csrc/register_pybindings.cpp index 9da63a2311..cbca9c92a6 100644 --- a/torchtext/csrc/register_pybindings.cpp +++ b/torchtext/csrc/register_pybindings.cpp @@ -180,6 +180,16 @@ PYBIND11_MODULE(_torchtext, m) { .def("encode", &GPT2BPEEncoder::Encode) .def("tokenize", &GPT2BPEEncoder::Tokenize) .def("decode", &GPT2BPEEncoder::Decode) + .def( + "add_special_tokens", + [](const c10::intrusive_ptr& self, + const std::unordered_map& items, + const std::vector& additional) { + c10::Dict d; + for (const auto& item : items) + d.insert(item.first, item.second); + return (self->AddSpecialTokens(d, additional)); + }) .def(py::pickle( // __getstate__ [](const c10::intrusive_ptr& self) diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp index 51c626b880..25b23a04c2 100644 --- a/torchtext/csrc/register_torchbindings.cpp +++ b/torchtext/csrc/register_torchbindings.cpp @@ -141,6 +141,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { .def("encode", &GPT2BPEEncoder::Encode) .def("decode", &GPT2BPEEncoder::Decode) .def("tokenize", &GPT2BPEEncoder::Tokenize) + .def("add_special_tokens", &GPT2BPEEncoder::AddSpecialTokens) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) diff --git a/torchtext/transforms.py b/torchtext/transforms.py index b917c67ce9..2e71fa594e 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from functools import lru_cache -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import torch import torchtext # noqa: F401 @@ -288,6 +288,16 @@ class GPT2BPETokenizer(Module): :type return_input: bool """ + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] __jit_unused_properties__ = ["is_jitable"] _seperator: torch.jit.Final[str] @@ -349,6 +359,26 @@ def _tokenize(self, text: str) -> List[str]: """ return self.bpe.tokenize(text) + def add_special_tokens(self, special_tokens_dict: Mapping[str, Union[str, Sequence[str]]]) -> int: + """Add a dictionary of special tokens (eos, pad, cls…) to the encoder + + :param special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: + [bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens]. + Tokens are only added if they are not already in the vocabulary. + :type special_tokens_dict: Dict[str, Union[str, List[str]]] + :return: Number of tokens added to the vocabulary. + :rtype: int + """ + for key in special_tokens_dict.keys(): + assert ( + key in self.SPECIAL_TOKENS_ATTRIBUTES + ), f"Key '{key}' is not in the special token list: {self.SPECIAL_TOKENS_ATTRIBUTES}" + + return self.bpe.add_special_tokens( + {k: v for k, v in special_tokens_dict.items() if k != "additional_special_tokens"}, + special_tokens_dict.get("additional_special_tokens", []), + ) + def forward(self, input: Any) -> Any: """ :param input: Input sentence or list of sentences on which to apply tokenizer.