135 changes: 135 additions & 0 deletions test/torchtext_unittest/test_transforms.py
@@ -560,6 +560,140 @@ def _gpt2_bpe_decoder(self, tokenizer):
for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

def _gpt2_bpe_decoder_with_special_tokens(self, tokenizer):
sample_ids = [
[
"27",
"91",
"437",
"1659",
"5239",
"91",
"29",
"290",
"1279",
"91",
"437",
"1659",
"5239",
"91",
"29",
"389",
"2041",
"1279",
"91",
"437",
"1659",
"1370",
"91",
"29",
"318",
"407",
"0",
],
[
"9288",
"15859",
"8905",
"51",
"1279",
"615",
"603",
"62",
"4658",
"29",
"351",
"27196",
"24027",
"1279",
"91",
"437",
"1659",
"5239",
"91",
"29",
"290",
"8005",
"62",
"44710",
],
["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"],
[
"40",
"423",
"281",
"16882",
"1359",
"428",
"318",
"257",
"1332",
"1279",
"91",
"437",
"1659",
"5239",
"91",
"29",
],
]

expected_texts = [
"<|endoftext|> and <|endoftext|> are special <|endofline|> is not!",
"test ACCEPT <avail_actions> with DECLINE <|endoftext|> and NO_ACTION",
"Avdija Vršajević în",
"I have an inkling this is a test <|endoftext|>",
]

for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

newly_added = tokenizer.add_special_tokens(
special_tokens_dict={
"unk_token": "<|endoftext|>",
"sep_token": "<avail_actions>",
"additional_special_tokens": [
"ACCEPT",
"DECLINE",
"inkling",
],
}
)
self.assertEqual(newly_added, 4)

sample_ids = [
[
"50256",
"392",
"50256",
"533",
"2041",
"1279",
"91",
"437",
"1659",
"1370",
"91",
"29",
"318",
"407",
"0",
],
["9288", "50258", "50257", "4480", "50259", "50256", "392", "8005", "62", "44710"],
["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"],
["40", "423", "281", "50260", "5661", "318", "257", "1332", "50256"],
]

expected_texts = [
"<|endoftext|> and <|endoftext|> are special <|endofline|> is not!",
"test ACCEPT <avail_actions> with DECLINE <|endoftext|> and NO_ACTION",
"Avdija Vršajević în",
"I have an inkling this is a test <|endoftext|>",
]

for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

@nested_params([True, False], [True, False])
def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
"""test tokenization on single sentence input as well as batch on sentences"""
@@ -568,6 +702,7 @@ def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
def test_gpt2_bpe_decoder(self):
"""test string output returned by decoder given the token ids"""
self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False))
self._gpt2_bpe_decoder_with_special_tokens(self._load_tokenizer(test_scripting=False, return_tokens=False))

@nested_params([True, False])
def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens):
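A minimal usage sketch of what the new test exercises, assuming a `GPT2BPETokenizer` built from local `encoder.json`/`vocab.bpe` files (the paths below are hypothetical stand-ins for whatever `_load_tokenizer` uses) and that the Python transform forwards `add_special_tokens` and `decode` to the C++ encoder, as the test does:

```python
# Hedged sketch, not part of the PR: illustrates the expected behavior of
# add_special_tokens + decode as asserted in the test above.
from torchtext.transforms import GPT2BPETokenizer

tokenizer = GPT2BPETokenizer(
    encoder_json_path="encoder.json",  # hypothetical local path
    vocab_bpe_path="vocab.bpe",  # hypothetical local path
)

# Register special tokens. "<|endoftext|>" is already in the base GPT-2
# vocab (id 50256), so only the other four count as newly added.
newly_added = tokenizer.add_special_tokens(
    special_tokens_dict={
        "unk_token": "<|endoftext|>",
        "sep_token": "<avail_actions>",
        "additional_special_tokens": ["ACCEPT", "DECLINE", "inkling"],
    }
)
assert newly_added == 4

# Ids taken from the test's second sample; special tokens now decode via
# the extended vocab, with spaces re-inserted around them.
ids = ["9288", "50258", "50257", "4480", "50259", "50256", "392", "8005", "62", "44710"]
assert tokenizer.decode(ids) == (
    "test ACCEPT <avail_actions> with DECLINE <|endoftext|> and NO_ACTION"
)
```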
88 changes: 70 additions & 18 deletions torchtext/csrc/gpt2_bpe_tokenizer.cpp
@@ -384,8 +384,8 @@ std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
std::vector<int64_t> bpe_token_ids;
for (const auto& token : PreTokenize_(text)) {
- if (added_tokens_encoder.contains(token)) {
- bpe_token_ids.push_back(added_tokens_encoder.at(token));
+ if (added_tokens_encoder_.contains(token)) {
+ bpe_token_ids.push_back(added_tokens_encoder_.at(token));
continue;
}
bool is_never_split_token =
@@ -400,19 +400,67 @@ std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {

std::string GPT2BPEEncoder::Decode(const std::vector<int64_t>& tokens) {
std::string text;
+ bool is_prev_special = false;
+ bool is_current_special = false;
// setup converter for converting wide chars to/from chars
using convert_type = std::codecvt_utf8<wchar_t>;
std::wstring_convert<convert_type, wchar_t> converter;

- for (const auto token : tokens) {
- // get unicode string for given integer key
- const std::string str = bpe_decoder_.at(token);
- const std::wstring ws = converter.from_bytes(str);
- for (wchar_t wchr : ws) {
- // get output character from byte decoder for each wide character
- unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
- text.push_back(uchr);
+ for (int tok_idx = 0; tok_idx < tokens.size(); tok_idx++) {
+ const auto token = tokens[tok_idx];
+ std::string decoded_token;
+
+ if (added_tokens_decoder_.contains(token)) {
+ // string is a special token from extended vocab
+ decoded_token = added_tokens_decoder_.at(token);
+ is_current_special = true;
+ } else {
+ const std::string str = bpe_decoder_.at(token);
+ if (bpe_never_split_set_.find(str) != bpe_never_split_set_.end()) {
+ // string is a special token from known vocab
+ decoded_token = str;
+ is_current_special = true;
+ } else {
+ // string is a regular token from known vocab
+ is_current_special = false;
+ const std::wstring ws = converter.from_bytes(str);
+ for (wchar_t wchr : ws) {
+ // get output character from byte decoder for each wide character
+ unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
+ decoded_token.push_back(uchr);
+ }
+ }
+ }

+ /* Fixing leading/trailing space(s)
+
+ We need to ensure spaces before and after special tokens are removed
+ appropriately. Assuming <|endoftext|> and HELLO are special tokens:
+ string input: "<|endoftext|> <|endoftext|> and HELLO world !"
+ is to be tokenized as:
+ ['<|endoftext|>', '<|endoftext|>', 'and', 'HELLO', 'world', 'Ġ!']
+ whereas an input like:
+ "<|endoftext|> and anything else!", gets tokenized as:
+ ['<|endoftext|>', 'and', 'Ġanything', 'Ġelse', '!']
+
+ Hence, while decoding the string tokens back into the original text, we
+ have to insert those spaces again:
+ - Add a space before a special token if it is not at the beginning of
+ the sentence and does not follow another special token.
+ - Add a space after a special token if it is not at the end of the
+ sentence.
+ */

+ // fix left space(s) for special tokens
+ if (is_current_special && (tok_idx > 0 && !is_prev_special)) {
+ text.push_back(' ');
+ }
+ text.append(decoded_token);
+ // fix right space(s) for special tokens
+ if (is_current_special && tok_idx != tokens.size() - 1) {
+ text.push_back(' ');
+ }
+ is_prev_special = is_current_special;
}
return text;
}
@@ -436,31 +484,35 @@ int64_t GPT2BPEEncoder::AddSpecialTokens(
int64_t newly_added = 0;

/* All special tokens get added to `bpe_never_split_set_` set to avoid being
- * split during tokenization. Tokens are added to `added_tokens_encoder` only
- * if they are not already known (i.e. present in `bpe_encoder_`).
+ * split during tokenization. Tokens are added to `added_tokens_encoder_` only
+ * if they are not already known (i.e. not already present in `bpe_encoder_`).
*/

// Loop for standard tokens such as "bos_token", "eos_token", etc.
for (auto const& token : standard_special_tokens_dict) {
- if (added_tokens_encoder.contains(token.value())) {
+ if (added_tokens_encoder_.contains(token.value())) {
continue;
}
bpe_never_split_set_.insert(token.value());
if (!bpe_encoder_.contains(token.value())) {
- added_tokens_encoder.insert(
- token.value(), bpe_encoder_.size() + added_tokens_encoder.size());
+ added_tokens_encoder_.insert(
+ token.value(), bpe_encoder_.size() + added_tokens_encoder_.size());
+ added_tokens_decoder_.insert(
+ bpe_decoder_.size() + added_tokens_decoder_.size(), token.value());
newly_added++;
}
}

// Loop for any additional tokens
for (auto const& token : additional_special_tokens) {
- if (added_tokens_encoder.contains(token))
+ if (added_tokens_encoder_.contains(token))
continue;
bpe_never_split_set_.insert(token);
if (!bpe_encoder_.contains(token)) {
- added_tokens_encoder.insert(
- token, bpe_encoder_.size() + added_tokens_encoder.size());
+ added_tokens_encoder_.insert(
+ token, bpe_encoder_.size() + added_tokens_encoder_.size());
+ added_tokens_decoder_.insert(
+ bpe_decoder_.size() + added_tokens_decoder_.size(), token);
newly_added++;
}
}
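For reference, the space re-insertion rule that `Decode` now applies can be stated compactly. Below is a minimal Python sketch of just that joining step; the function name and its inputs are illustrative stand-ins for the C++ loop's local state, not part of any API:

```python
def join_decoded(decoded_tokens, is_special):
    """Rejoin decoded token strings, re-inserting the single spaces that
    pre-tokenization stripped around special tokens (illustrative mirror
    of the loop tail in GPT2BPEEncoder::Decode)."""
    text = ""
    prev_special = False
    for i, (tok, special) in enumerate(zip(decoded_tokens, is_special)):
        # Space before a special token, unless it starts the text or
        # follows another special token (which already appended a space).
        if special and i > 0 and not prev_special:
            text += " "
        text += tok
        # Space after a special token, unless it ends the text.
        if special and i != len(decoded_tokens) - 1:
            text += " "
        prev_special = special
    return text

# Regular tokens keep their own leading-space byte (Ġ -> " "), so only
# special tokens need spaces re-inserted:
tokens = ["<|endoftext|>", "<|endoftext|>", "and", "HELLO", "world", " !"]
flags = [True, True, False, True, False, False]
assert join_decoded(tokens, flags) == "<|endoftext|> <|endoftext|> and HELLO world !"
```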
3 changes: 2 additions & 1 deletion torchtext/csrc/gpt2_bpe_tokenizer.h
@@ -62,7 +62,8 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
std::string token,
bool is_never_split_token);
int64_t GetBPEMergeRank_(std::string pair);
- c10::Dict<std::string, int64_t> added_tokens_encoder;
+ c10::Dict<std::string, int64_t> added_tokens_encoder_;
+ c10::Dict<int64_t, std::string> added_tokens_decoder_;

protected:
c10::Dict<std::string, std::vector<std::string>> cache_;
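The new `added_tokens_decoder_` is the inverse of `added_tokens_encoder_`, letting `Decode` resolve ids that fall outside the base vocab. A hypothetical Python mirror of the id-assignment rule in `AddSpecialTokens` (the constant and dict names are stand-ins for the C++ members):

```python
# Hypothetical mirror of AddSpecialTokens: tokens already in the base BPE
# vocab (e.g. "<|endoftext|>", id 50256) are only marked never-split;
# genuinely new tokens get consecutive ids starting at the vocab size.
BASE_VOCAB_SIZE = 50257  # stand-in for bpe_encoder_.size()
base_vocab = {"<|endoftext|>"}  # stand-in for bpe_encoder_ membership

added_tokens_encoder = {}  # token -> id  (added_tokens_encoder_)
added_tokens_decoder = {}  # id -> token  (added_tokens_decoder_)

for token in ["<|endoftext|>", "<avail_actions>", "ACCEPT", "DECLINE", "inkling"]:
    if token in added_tokens_encoder or token in base_vocab:
        continue  # already known; nothing newly added
    new_id = BASE_VOCAB_SIZE + len(added_tokens_encoder)
    added_tokens_encoder[token] = new_id
    added_tokens_decoder[new_id] = token

# Matches the ids the updated test decodes:
assert added_tokens_decoder == {
    50257: "<avail_actions>",
    50258: "ACCEPT",
    50259: "DECLINE",
    50260: "inkling",
}
```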