From 70230b9e0279a3689cce189d41a0b4cae8639967 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Thu, 29 Sep 2022 16:06:32 -0700 Subject: [PATCH 1/5] add_special_tokens and never split features added --- test/torchtext_unittest/test_transforms.py | 184 +++++++++++++++++++++ torchtext/csrc/gpt2_bpe_tokenizer.cpp | 161 +++++++++++++++--- torchtext/csrc/gpt2_bpe_tokenizer.h | 15 +- torchtext/csrc/register_pybindings.cpp | 10 ++ torchtext/csrc/register_torchbindings.cpp | 1 + torchtext/transforms.py | 32 +++- 6 files changed, 377 insertions(+), 26 deletions(-) diff --git a/test/torchtext_unittest/test_transforms.py b/test/torchtext_unittest/test_transforms.py index d514cc9701..f3f2f50326 100644 --- a/test/torchtext_unittest/test_transforms.py +++ b/test/torchtext_unittest/test_transforms.py @@ -336,6 +336,7 @@ def _gpt2_bpe_tokenizer(self, tokenizer): "Hélló WoŕlḊ¿", "Respublica superiorem", "Avdija Vršajević în", + "multi space", ] expected_tokens = [ @@ -343,12 +344,189 @@ def _gpt2_bpe_tokenizer(self, tokenizer): ["H", "é", "ll", "ó", "Ġ", "ĠWo", "Å", "ķ", "l", "á¸", "Ĭ", "Â", "¿"], ["Res", "public", "a", "Ġsuper", "i", "orem"], ["Av", "d", "ija", "ĠV", "r", "Å¡", "aj", "ev", "i", "Äĩ", "ĠÃ", "®", "n"], + ["multi", "Ġ", "Ġ", "Ġ", "Ġ", "Ġ", "Ġspace"], ] expected_token_ids = [ ["15496", "2159", "28265", "703", "389", "345", "30"], ["39", "2634", "297", "10205", "220", "22173", "129", "243", "75", "41585", "232", "126", "123"], ["4965", "11377", "64", "2208", "72", "29625"], ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + ["41684", "220", "220", "220", "220", "220", "2272"], + ] + + # test batch of sentences + if tokenizer._return_tokens: + self.assertEqual(tokenizer(sample_texts), expected_tokens) + else: + self.assertEqual(tokenizer(sample_texts), expected_token_ids) + + # test individual sentences + for idx, txt in enumerate(sample_texts): + if tokenizer._return_tokens: + self.assertEqual(tokenizer(txt), expected_tokens[idx]) + else: + self.assertEqual(tokenizer(txt), expected_token_ids[idx]) + + def _gpt2_bpe_tokenizer_with_added_vocab(self, tokenizer): + sample_texts = [ + "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!", + "test ACCEPT with DECLINE <|endoftext|> and NO_ACTION", + "none in vocab: <|endofline|> WALK_60M WALK_10M ", + "Respublica Vršajević în", + "some in vocab: <|endofline|> WALK_60M WALK_10M ", + "<|endoftext|> WALK_60M WALK_10M ", + ] + + newly_added = tokenizer.add_special_tokens( + special_tokens_dict={ + "unk_token": "<|endoftext|>", + "additional_special_tokens": [ + "ACCEPT", + "DECLINE", + "NO_ACTION", + "WALK_10M", + "WALK_60M", + "", + ], + } + ) + self.assertEqual(newly_added, 6) + + newly_added = tokenizer.add_special_tokens( + special_tokens_dict={ + "unk_token": "<|endoftext|>", + "sep_token": "", + "additional_special_tokens": [ + "ACCEPT", + "DECLINE", + "NO_ACTION", + "WALK_10M", + "WALK_60M", + "", + ], + } + ) + self.assertEqual(newly_added, 1) + + expected_tokens = [ + [ + "<|endoftext|>", + "and", + "<|endoftext|>", + "are", + "Ġspecial", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "Ġis", + "Ġnot", + "!", + ], + ["test", "ACCEPT", "", "with", "DECLINE", "<|endoftext|>", "and", "NO_ACTION"], + [ + "none", + "Ġin", + "Ġvoc", + "ab", + ":", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "WALK_60M", + "WALK_10M", + "<", + "state", + ">", + ], + ["Res", "public", "a", "ĠV", "r", "Å¡", "aj", "ev", "i", "Äĩ", "ĠÃ", "®", "n"], + [ + "some", + "Ġin", + "Ġvoc", + "ab", 
+ ":", + "Ġ<", + "|", + "end", + "of", + "line", + "|", + ">", + "WALK_60M", + "WALK_10M", + "<", + "state", + ">", + ], + ["<|endoftext|>", "WALK_60M", "WALK_10M", "", "<", "state", ">"], + ] + expected_token_ids = [ + [ + "50256", + "392", + "50256", + "533", + "2041", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "318", + "407", + "0", + ], + ["9288", "50257", "50263", "4480", "50258", "50256", "392", "50259"], + [ + "23108", + "287", + "12776", + "397", + "25", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "50261", + "50260", + "27", + "5219", + "29", + ], + ["4965", "11377", "64", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"], + [ + "11246", + "287", + "12776", + "397", + "25", + "1279", + "91", + "437", + "1659", + "1370", + "91", + "29", + "50261", + "50260", + "27", + "5219", + "29", + ], + ["50256", "50261", "50260", "50262", "27", "5219", "29"], ] # test batch of sentences @@ -391,6 +569,12 @@ def test_gpt2_bpe_decoder(self): """test string output returned by decoder given the token ids""" self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False)) + @nested_params([True, False]) + def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens): + self._gpt2_bpe_tokenizer_with_added_vocab( + self._load_tokenizer(test_scripting=False, return_tokens=return_tokens) + ) + def test_gpt2_bpe_tokenizer_save_load_pybind(self) -> None: tokenizer = self._load_tokenizer(test_scripting=False, return_tokens=False) tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt") diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 7a722d8aef..7d98ae0254 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include #include @@ -63,28 +65,94 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { // - ELSE, add token to return list std::string token; std::vector tokens; - re2::StringPiece inp(input); bool prepend_space = false; - while (kGPT2Regex.FindAndConsume(&inp, &token)) { - if (is_whitespace(token)) { - prepend_space = false; - if (inp.empty()) { // token is last token - tokens.push_back(token); - } else { - if (token.length() > 1) { - tokens.push_back(token.substr(0, token.length() - 1)); + std::vector index_matches; + + if (bpe_never_split_set_.size() > 0) { + std::string pattern = ""; + // escape regex characters for matching special tokens + for (std::string token : bpe_never_split_set_) { + std::string::size_type pos = 0; + while ((pos = token.find_first_of("|[]", pos)) != std::string::npos) { + switch (token[pos]) { + case '|': + token.replace(pos, 1, "\\|"); + pos += 2; + break; + case '[': + token.replace(pos, 1, "\\["); + pos += 2; + break; + case ']': + token.replace(pos, 1, "\\]"); + pos += 2; + break; } - if (token[token.length() - 1] == ' ') { // last char is space - prepend_space = true; - } else { // push last whitespace char as a token if it is not a space - tokens.push_back(token.substr(token.length() - 1)); + } + if (pattern.length() != 0) + pattern += "|"; + pattern += token; + } + + // break input into non-special and special parts + std::regex rx(pattern); + int64_t last_idx = 0; + for (auto it = std::sregex_iterator(input.begin(), input.end(), rx); + it != std::sregex_iterator(); + ++it) { + if (it->position() > last_idx) { + if (isspace(input[it->position() - 1])) { + // lstrip + index_matches.push_back( + 
input.substr(last_idx, it->position() - last_idx - 1));
+        } else {
+          index_matches.push_back(
+              input.substr(last_idx, it->position() - last_idx));
+        }
       }
-    } else if (prepend_space) {
-      tokens.push_back(" " + token);
-      prepend_space = false;
-    } else {
-      tokens.push_back(token);
+      index_matches.push_back(input.substr(it->position(), it->length()));
+      last_idx = it->position() + it->length() + 1;
+      if (isspace(input[last_idx])) {
+        // rstrip
+        last_idx++;
+      }
+    }
+    if (last_idx < input.length() - 1)
+      index_matches.push_back(
+          input.substr(last_idx, input.length() - last_idx));
+  } else {
+    index_matches.push_back(input);
+  }
+
+  for (std::string index_token : index_matches) {
+    bool is_never_split_token =
+        bpe_never_split_set_.find(index_token) != bpe_never_split_set_.end();
+    if (is_never_split_token) {
+      tokens.push_back(index_token);
+      continue;
+    }
+    re2::StringPiece inp(index_token);
+    while (kGPT2Regex.FindAndConsume(&inp, &token)) {
+      if (is_whitespace(token)) {
+        prepend_space = false;
+        if (inp.empty()) { // token is last token
+          tokens.push_back(token);
+        } else {
+          if (token.length() > 1) {
+            tokens.push_back(token.substr(0, token.length() - 1));
+          }
+          if (token[token.length() - 1] == ' ') { // last char is space
+            prepend_space = true;
+          } else { // push last whitespace char as a token if it is not a space
+            tokens.push_back(token.substr(token.length() - 1));
+          }
+        }
+      } else if (prepend_space) {
+        tokens.push_back(" " + token);
+        prepend_space = false;
+      } else {
+        tokens.push_back(token);
+      }
     }
   }
   return tokens;
@@ -155,6 +223,8 @@ GPT2BPEEncoder::GPT2BPEEncoder(
 
   for (auto const& x : byte_encoder_)
     byte_decoder_.insert(x.value(), x.key());
+
+  added_to_vocab_tokens_count = 0;
 }
 
 GPT2BPEEncoder::GPT2BPEEncoder(
@@ -170,11 +240,17 @@ GPT2BPEEncoder::GPT2BPEEncoder(
       _map_to_c10_dict(byte_encoder),
       caching_enabled) {}
 
-std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
+std::vector<std::string> GPT2BPEEncoder::ByteEncode_(
+    std::string token,
+    bool is_never_split_token) {
   // Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8')
   std::vector<std::string> encoded;
-  for (auto& ch : token) {
-    encoded.push_back(byte_encoder_.at((unsigned char)ch));
+  if (is_never_split_token) {
+    encoded.push_back(token);
+  } else {
+    for (auto& ch : token) {
+      encoded.push_back(byte_encoder_.at((unsigned char)ch));
+    }
   }
   return encoded;
 }
@@ -279,7 +355,13 @@ std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
 std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
   std::vector<int64_t> bpe_token_ids;
   for (const auto& token : PreTokenize_(text)) {
-    auto byte_encoded_token = ByteEncode_(token);
+    if (added_tokens_encoder.contains(token)) {
+      bpe_token_ids.push_back(added_tokens_encoder.at(token));
+      continue;
+    }
+    bool is_never_split_token =
+        bpe_never_split_set_.find(token) != bpe_never_split_set_.end();
+    auto byte_encoded_token = ByteEncode_(token, is_never_split_token);
     for (const auto& bpe_token : BPE_(byte_encoded_token)) {
       bpe_token_ids.push_back(bpe_encoder_.at(bpe_token));
     }
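Aside for reviewers: in Python terms, the new Encode() control flow amounts to
the sketch below. This is a minimal reading aid, not a real torchtext API;
`added_tokens_encoder`, `never_split`, `byte_encode` and `bpe` are stand-ins
for the corresponding C++ members and helpers.

    def encode(pre_tokens, bpe_encoder, added_tokens_encoder, never_split, byte_encode, bpe):
        token_ids = []
        for token in pre_tokens:
            # a user-added special token bypasses BPE entirely and maps straight to its id
            if token in added_tokens_encoder:
                token_ids.append(added_tokens_encoder[token])
                continue
            # a special token already in the base vocab skips byte-encoding
            # (ByteEncode_ returns it whole) but is still looked up in bpe_encoder
            is_special = token in never_split
            for bpe_token in bpe(byte_encode(token, is_special)):
                token_ids.append(bpe_encoder[bpe_token])
        return token_ids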
@@ -309,7 +391,9 @@ std::string GPT2BPEEncoder::Decode(const std::vector<int64_t>& tokens) {
 std::vector<std::string> GPT2BPEEncoder::Tokenize(const std::string& text) {
   std::vector<std::string> bpe_tokens;
   for (const auto& token : PreTokenize_(text)) {
-    auto byte_encoded_token = ByteEncode_(token);
+    bool is_never_split_token =
+        bpe_never_split_set_.find(token) != bpe_never_split_set_.end();
+    auto byte_encoded_token = ByteEncode_(token, is_never_split_token);
     for (const auto& bpe_token : BPE_(byte_encoded_token)) {
       bpe_tokens.push_back(bpe_token);
     }
   }
@@ -317,6 +401,37 @@ std::vector<std::string> GPT2BPEEncoder::Tokenize(const std::string& text) {
   return bpe_tokens;
 }
 
+int64_t GPT2BPEEncoder::AddSpecialTokens(
+    const c10::Dict<std::string, std::string>& standard_special_tokens_dict,
+    const std::vector<std::string> additional_special_tokens) {
+  int64_t newly_added = 0;
+
+  for (auto const& token : standard_special_tokens_dict) {
+    if (added_tokens_encoder.contains(token.value()))
+      continue;
+    bpe_never_split_set_.insert(token.value());
+    if (!bpe_encoder_.contains(token.value())) {
+      added_tokens_encoder.insert(
+          token.value(), bpe_encoder_.size() + added_tokens_encoder.size());
+      newly_added++;
+    }
+  }
+
+  for (auto const& token : additional_special_tokens) {
+    if (added_tokens_encoder.contains(token))
+      continue;
+    bpe_never_split_set_.insert(token);
+    if (!bpe_encoder_.contains(token)) {
+      added_tokens_encoder.insert(
+          token, bpe_encoder_.size() + added_tokens_encoder.size());
+      newly_added++;
+    }
+  }
+
+  added_to_vocab_tokens_count += newly_added;
+  return newly_added;
+}
+
 std::unordered_map<std::string, int64_t> GPT2BPEEncoder::GetBPEEncoder() const {
   return _c10_dict_to_map(bpe_encoder_);
 }
diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h
index 2d6e5dfbc9..f39d431920 100644
--- a/torchtext/csrc/gpt2_bpe_tokenizer.h
+++ b/torchtext/csrc/gpt2_bpe_tokenizer.h
@@ -1,10 +1,10 @@
 #ifndef GPT2_BPE_TOKENIZER_H_
 #define GPT2_BPE_TOKENIZER_H_
-
 #include 
 #include 
 #include 
+#include <set>
 #include 
 #include 
 #include 
@@ -12,6 +12,9 @@
 
 namespace torchtext {
 
+// set to store tokens that are not to be split
+static std::set<std::string> bpe_never_split_set_;
+
 typedef std::tuple<
     std::unordered_map<std::string, int64_t>,
     std::unordered_map<std::string, int64_t>,
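Aside for reviewers: the id-assignment rule in AddSpecialTokens() is easiest
to see in a small Python sketch. It assumes the stock GPT-2 vocabulary of
50257 entries (ids 0-50256); the dict below is illustrative, not the
binding's actual surface.

    BPE_VOCAB_SIZE = 50257  # assumed size of the stock GPT-2 encoder
    added_tokens_encoder = {}
    for tok in ["ACCEPT", "DECLINE", "NO_ACTION", "WALK_10M", "WALK_60M"]:
        if tok not in added_tokens_encoder:
            # new ids continue where the base vocab ends
            added_tokens_encoder[tok] = BPE_VOCAB_SIZE + len(added_tokens_encoder)
    # ACCEPT=50257, DECLINE=50258, NO_ACTION=50259, WALK_10M=50260, WALK_60M=50261,
    # matching the expected ids in the unit test above.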
@@ -55,8 +58,13 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
 
  private:
   const int64_t inf_;
   // Encode byte into an unicode character.
-  std::vector<std::string> ByteEncode_(std::string token);
+  std::vector<std::string> ByteEncode_(
+      std::string token,
+      bool is_never_split_token);
   int64_t GetBPEMergeRank_(std::string pair);
+  int64_t added_to_vocab_tokens_count;
+  // std::set<std::string> bpe_never_split_set_;
+  c10::Dict<std::string, int64_t> added_tokens_encoder;
 
  protected:
   c10::Dict<std::string, std::vector<std::string>> cache_;
@@ -103,6 +111,9 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
   TORCHTEXT_API std::vector<int64_t> Encode(const std::string& text);
   TORCHTEXT_API std::string Decode(const std::vector<int64_t>& tokens);
   TORCHTEXT_API std::vector<std::string> Tokenize(const std::string& text);
+  TORCHTEXT_API int64_t AddSpecialTokens(
+      const c10::Dict<std::string, std::string>& standard_special_tokens_dict,
+      const std::vector<std::string> additional_special_tokens);
   TORCHTEXT_API std::unordered_map<std::string, int64_t> GetBPEEncoder() const;
   TORCHTEXT_API std::unordered_map<std::string, int64_t> GetBPEMergeRanks()
diff --git a/torchtext/csrc/register_pybindings.cpp b/torchtext/csrc/register_pybindings.cpp
index 9da63a2311..cbca9c92a6 100644
--- a/torchtext/csrc/register_pybindings.cpp
+++ b/torchtext/csrc/register_pybindings.cpp
@@ -180,6 +180,16 @@ PYBIND11_MODULE(_torchtext, m) {
           .def("encode", &GPT2BPEEncoder::Encode)
           .def("tokenize", &GPT2BPEEncoder::Tokenize)
           .def("decode", &GPT2BPEEncoder::Decode)
+          .def(
+              "add_special_tokens",
+              [](const c10::intrusive_ptr<GPT2BPEEncoder>& self,
+                 const std::unordered_map<std::string, std::string>& items,
+                 const std::vector<std::string>& additional) {
+                c10::Dict<std::string, std::string> d;
+                for (const auto& item : items)
+                  d.insert(item.first, item.second);
+                return (self->AddSpecialTokens(d, additional));
+              })
           .def(py::pickle(
               // __getstate__
               [](const c10::intrusive_ptr<GPT2BPEEncoder>& self)
diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp
index 51c626b880..25b23a04c2 100644
--- a/torchtext/csrc/register_torchbindings.cpp
+++ b/torchtext/csrc/register_torchbindings.cpp
@@ -141,6 +141,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def("encode", &GPT2BPEEncoder::Encode)
       .def("decode", &GPT2BPEEncoder::Decode)
       .def("tokenize", &GPT2BPEEncoder::Tokenize)
+      .def("add_special_tokens", &GPT2BPEEncoder::AddSpecialTokens)
       .def_pickle(
           // __getstate__
           [](const c10::intrusive_ptr<GPT2BPEEncoder>& self)
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index b917c67ce9..14c60f285c 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -1,7 +1,7 @@
 import json
 from copy import deepcopy
 from functools import lru_cache
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torchtext  # noqa: F401
@@ -294,6 +294,16 @@ class GPT2BPETokenizer(Module):
     def __init__(self, encoder_json_path: str, vocab_bpe_path: str, return_tokens: bool = False) -> None:
         super().__init__()
         self._seperator = "\u0001"
+        self.SPECIAL_TOKENS_ATTRIBUTES = [
+            "bos_token",
+            "eos_token",
+            "unk_token",
+            "sep_token",
+            "pad_token",
+            "cls_token",
+            "mask_token",
+            "additional_special_tokens",
+        ]
         # load bpe encoder and bpe decoder
         with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
             bpe_encoder = json.load(f)
@@ -349,6 +359,26 @@ def _tokenize(self, text: str) -> List[str]:
         """
         return self.bpe.tokenize(text)
 
+    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, List[str]]]) -> int:
+        """Add a dictionary of special tokens (eos, pad, cls…) to the encoder
+
+        :param special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
+            [bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens].
+ Tokens are only added if they are not already in the vocabulary. + :type special_tokens_dict: Dict[str, Union[str, List[str]]] + :return: Number of tokens added to the vocabulary. + :rtype: int + """ + for key in special_tokens_dict.keys(): + assert ( + key in self.SPECIAL_TOKENS_ATTRIBUTES + ), f"Key '{key}' is not in the special token list: {self.SPECIAL_TOKENS_ATTRIBUTES}" + + return self.bpe.add_special_tokens( + {k: v for k, v in special_tokens_dict.items() if k != "additional_special_tokens"}, + special_tokens_dict.get("additional_special_tokens", []), + ) + def forward(self, input: Any) -> Any: """ :param input: Input sentence or list of sentences on which to apply tokenizer. From 4ceb641d155eab21b59a5364d94dee0d63c1bbf6 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Thu, 29 Sep 2022 18:49:27 -0700 Subject: [PATCH 2/5] removed a comment and updated a type hint --- torchtext/csrc/gpt2_bpe_tokenizer.h | 1 - torchtext/transforms.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h index f39d431920..e20cf2d06c 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.h +++ b/torchtext/csrc/gpt2_bpe_tokenizer.h @@ -63,7 +63,6 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { bool is_never_split_token); int64_t GetBPEMergeRank_(std::string pair); int64_t added_to_vocab_tokens_count; - // std::set bpe_never_split_set_; c10::Dict added_tokens_encoder; protected: diff --git a/torchtext/transforms.py b/torchtext/transforms.py index 14c60f285c..3e9161fce8 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import torch import torchtext # noqa: F401 @@ -359,7 +359,7 @@ def _tokenize(self, text: str) -> List[str]: """ return self.bpe.tokenize(text) - def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, List[str]]]) -> int: + def add_special_tokens(self, special_tokens_dict: Mapping[str, Union[str, Sequence[str]]]) -> int: """Add a dictionary of special tokens (eos, pad, cls…) to the encoder :param special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: From 42a14a031648ccf025844a34ff5c19811f6fe7b0 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Thu, 29 Sep 2022 21:13:33 -0700 Subject: [PATCH 3/5] added explanation and example for how this change works --- torchtext/csrc/gpt2_bpe_tokenizer.cpp | 42 +++++++++++++++++++++++---- torchtext/csrc/gpt2_bpe_tokenizer.h | 1 - 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index 7d98ae0254..c02769702d 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -68,9 +68,35 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { bool prepend_space = false; std::vector index_matches; + /* Notes on handling Special Tokens: + We use regex pattern to first identify the special tokens in the input text. + Other non-special tokens go through pre-tokenization as usual, but special + tokens skip those steps. + + Steps: + * Loop over the set containing user-supplied strings that are to be treated as + special tokens. This set gets created through the calls to + `add_special_tokens` API. 
+ - form a regex pattern that helps in extracting special tokens from the + input text. + * Crate a vector that contains chunks of input text, such that each chunk is + either a sequence of non-special token or a single special token. For example, + assuming <|special_tok|> and [SEP] are special tokens, the following text + "This is an example with <|special_tok|> and [SEP] and [SPAM]." + will get converted to a vector of strings: + ["This is an example with", "<|special_tok|>", "and", "[SEP]", "and + [SPAM]."] + - if the input does not contain any special tokens, the vector will just + contain a single token that is the whole original input text. + * For all of the tokens in the above vector, we proceed with BPE tokenization + as usual while skipping over certain steps as appropriate for special tokens. + */ + if (bpe_never_split_set_.size() > 0) { std::string pattern = ""; // escape regex characters for matching special tokens + // this is done to ensure character like '|' in special like + // <|endoftext|> don't get special regex meaning for (std::string token : bpe_never_split_set_) { std::string::size_type pos = 0; while ((pos = token.find_first_of("|[]", pos)) != std::string::npos) { @@ -102,7 +128,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { ++it) { if (it->position() > last_idx) { if (isspace(input[it->position() - 1])) { - // lstrip + // strip space on the left of the special token index_matches.push_back( input.substr(last_idx, it->position() - last_idx - 1)); } else { @@ -113,7 +139,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { index_matches.push_back(input.substr(it->position(), it->length())); last_idx = it->position() + it->length() + 1; if (isspace(input[last_idx])) { - // rstrip + // strip space on the right of the special token last_idx++; } } @@ -121,6 +147,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { index_matches.push_back( input.substr(last_idx, input.length() - last_idx)); } else { + // input does not have any special tokens index_matches.push_back(input); } @@ -128,6 +155,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { bool is_never_split_token = bpe_never_split_set_.find(index_token) != bpe_never_split_set_.end(); if (is_never_split_token) { + // skip the rest of pre-tokenization work for special tokens tokens.push_back(index_token); continue; } @@ -223,8 +251,6 @@ GPT2BPEEncoder::GPT2BPEEncoder( for (auto const& x : byte_encoder_) byte_decoder_.insert(x.value(), x.key()); - - added_to_vocab_tokens_count = 0; } GPT2BPEEncoder::GPT2BPEEncoder( @@ -406,6 +432,12 @@ int64_t GPT2BPEEncoder::AddSpecialTokens( const std::vector additional_special_tokens) { int64_t newly_added = 0; + /* All special tokens get added to `bpe_never_split_set_` set to avoid being + * split during tokenization. Tokens are added to `added_tokens_encoder` only + * if they are not already known (i.e. present in `bpe_encoder_`). + */ + + // Loop for standard tokens such as "bos_token", "eos_token", etc. 
for (auto const& token : standard_special_tokens_dict) { if (added_tokens_encoder.contains(token.value())) continue; @@ -417,6 +449,7 @@ int64_t GPT2BPEEncoder::AddSpecialTokens( } } + // Loop for any additional tokens for (auto const& token : additional_special_tokens) { if (added_tokens_encoder.contains(token)) continue; @@ -428,7 +461,6 @@ int64_t GPT2BPEEncoder::AddSpecialTokens( } } - added_to_vocab_tokens_count += newly_added; return newly_added; } diff --git a/torchtext/csrc/gpt2_bpe_tokenizer.h b/torchtext/csrc/gpt2_bpe_tokenizer.h index e20cf2d06c..5fc0197b74 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.h +++ b/torchtext/csrc/gpt2_bpe_tokenizer.h @@ -62,7 +62,6 @@ struct GPT2BPEEncoder : torch::CustomClassHolder { std::string token, bool is_never_split_token); int64_t GetBPEMergeRank_(std::string pair); - int64_t added_to_vocab_tokens_count; c10::Dict added_tokens_encoder; protected: From 9e662917c20aa26730186942b8bca42aa0f122ce Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Fri, 30 Sep 2022 10:41:39 -0700 Subject: [PATCH 4/5] move SPECIAL_TOKENS_ATTRIBUTES to utils --- torchtext/transforms.py | 16 +++------------- torchtext/utils.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/torchtext/transforms.py b/torchtext/transforms.py index 3e9161fce8..c74f37c4f8 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -14,7 +14,7 @@ ) from torchtext._torchtext import RegexTokenizer as RegexTokenizerPybind from torchtext.data.functional import load_sp_model -from torchtext.utils import get_asset_local_path +from torchtext.utils import get_asset_local_path, SPECIAL_TOKENS_ATTRIBUTES from torchtext.vocab import Vocab from . import functional as F @@ -294,16 +294,6 @@ class GPT2BPETokenizer(Module): def __init__(self, encoder_json_path: str, vocab_bpe_path: str, return_tokens: bool = False) -> None: super().__init__() self._seperator = "\u0001" - self.SPECIAL_TOKENS_ATTRIBUTES = [ - "bos_token", - "eos_token", - "unk_token", - "sep_token", - "pad_token", - "cls_token", - "mask_token", - "additional_special_tokens", - ] # load bpe encoder and bpe decoder with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f: bpe_encoder = json.load(f) @@ -371,8 +361,8 @@ def add_special_tokens(self, special_tokens_dict: Mapping[str, Union[str, Sequen """ for key in special_tokens_dict.keys(): assert ( - key in self.SPECIAL_TOKENS_ATTRIBUTES - ), f"Key '{key}' is not in the special token list: {self.SPECIAL_TOKENS_ATTRIBUTES}" + key in SPECIAL_TOKENS_ATTRIBUTES + ), f"Key '{key}' is not in the special token list: {SPECIAL_TOKENS_ATTRIBUTES}" return self.bpe.add_special_tokens( {k: v for k, v in special_tokens_dict.items() if k != "additional_special_tokens"}, diff --git a/torchtext/utils.py b/torchtext/utils.py index a7910b222f..b73b6419c7 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -13,6 +13,17 @@ logger = logging.getLogger(__name__) +SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", +] + def reporthook(t): """ From ad249fc8be9aafaa088b559658cf3ca8f05f99b0 Mon Sep 17 00:00:00 2001 From: Sumit Kumar Date: Mon, 3 Oct 2022 23:11:28 -0700 Subject: [PATCH 5/5] rebase and address latest nit comments --- torchtext/csrc/gpt2_bpe_tokenizer.cpp | 10 +++++----- torchtext/transforms.py | 16 +++++++++++++--- torchtext/utils.py | 11 ----------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git 
a/torchtext/csrc/gpt2_bpe_tokenizer.cpp b/torchtext/csrc/gpt2_bpe_tokenizer.cpp index c02769702d..c54eca516e 100644 --- a/torchtext/csrc/gpt2_bpe_tokenizer.cpp +++ b/torchtext/csrc/gpt2_bpe_tokenizer.cpp @@ -70,7 +70,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { /* Notes on handling Special Tokens: We use regex pattern to first identify the special tokens in the input text. - Other non-special tokens go through pre-tokenization as usual, but special + Other 'non-special' tokens go through pre-tokenization as usual, but special tokens skip those steps. Steps: @@ -79,7 +79,7 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { `add_special_tokens` API. - form a regex pattern that helps in extracting special tokens from the input text. - * Crate a vector that contains chunks of input text, such that each chunk is + * Create a vector that contains chunks of input text, such that each chunk is either a sequence of non-special token or a single special token. For example, assuming <|special_tok|> and [SEP] are special tokens, the following text "This is an example with <|special_tok|> and [SEP] and [SPAM]." @@ -94,9 +94,9 @@ std::vector gpt2_bpe_pre_tokenizer(std::string input) { if (bpe_never_split_set_.size() > 0) { std::string pattern = ""; - // escape regex characters for matching special tokens - // this is done to ensure character like '|' in special like - // <|endoftext|> don't get special regex meaning + // Escape regex characters for matching special tokens. This is done to + // ensure that characters like '|' in certain special tokens such as + // <|endoftext|> don't get special regex treatment. for (std::string token : bpe_never_split_set_) { std::string::size_type pos = 0; while ((pos = token.find_first_of("|[]", pos)) != std::string::npos) { diff --git a/torchtext/transforms.py b/torchtext/transforms.py index c74f37c4f8..2e71fa594e 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -14,7 +14,7 @@ ) from torchtext._torchtext import RegexTokenizer as RegexTokenizerPybind from torchtext.data.functional import load_sp_model -from torchtext.utils import get_asset_local_path, SPECIAL_TOKENS_ATTRIBUTES +from torchtext.utils import get_asset_local_path from torchtext.vocab import Vocab from . import functional as F @@ -288,6 +288,16 @@ class GPT2BPETokenizer(Module): :type return_input: bool """ + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] __jit_unused_properties__ = ["is_jitable"] _seperator: torch.jit.Final[str] @@ -361,8 +371,8 @@ def add_special_tokens(self, special_tokens_dict: Mapping[str, Union[str, Sequen """ for key in special_tokens_dict.keys(): assert ( - key in SPECIAL_TOKENS_ATTRIBUTES - ), f"Key '{key}' is not in the special token list: {SPECIAL_TOKENS_ATTRIBUTES}" + key in self.SPECIAL_TOKENS_ATTRIBUTES + ), f"Key '{key}' is not in the special token list: {self.SPECIAL_TOKENS_ATTRIBUTES}" return self.bpe.add_special_tokens( {k: v for k, v in special_tokens_dict.items() if k != "additional_special_tokens"}, diff --git a/torchtext/utils.py b/torchtext/utils.py index b73b6419c7..a7910b222f 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -13,17 +13,6 @@ logger = logging.getLogger(__name__) -SPECIAL_TOKENS_ATTRIBUTES = [ - "bos_token", - "eos_token", - "unk_token", - "sep_token", - "pad_token", - "cls_token", - "mask_token", - "additional_special_tokens", -] - def reporthook(t): """
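Aside for reviewers: putting the whole series together, the intended usage
from Python looks like the sketch below. The asset paths are placeholders
(substitute whatever encoder.json / vocab.bpe you normally load); the rest
follows the docstring and the unit tests above, which exercise this in eager
mode only (test_scripting=False).

    from torchtext.transforms import GPT2BPETokenizer

    tokenizer = GPT2BPETokenizer(
        encoder_json_path="path/to/gpt2_bpe_encoder.json",  # placeholder path
        vocab_bpe_path="path/to/gpt2_bpe_vocab.bpe",  # placeholder path
        return_tokens=True,
    )

    # "<|endoftext|>" is already in the GPT-2 vocab, so only the two
    # additional tokens are newly added and the call returns 2.
    num_added = tokenizer.add_special_tokens(
        {
            "unk_token": "<|endoftext|>",
            "additional_special_tokens": ["ACCEPT", "DECLINE"],
        }
    )

    # Special tokens now survive pre-tokenization unsplit; given the
    # space-stripping around specials shown in the tests, this should print
    # ['send', 'ACCEPT', 'or', '<|endoftext|>'].
    print(tokenizer("send ACCEPT or <|endoftext|>"))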