
Commit ac53d38

update decoding logic to handle special tokens

1 parent 3f9c349

File tree: 3 files changed (+186, -19 lines)


test/torchtext_unittest/test_transforms.py

Lines changed: 135 additions & 0 deletions
@@ -560,6 +560,140 @@ def _gpt2_bpe_decoder(self, tokenizer):
         for idx, ids in enumerate(sample_ids):
             self.assertEqual(tokenizer.decode(ids), expected_texts[idx])
 
+    def _gpt2_bpe_decoder_with_special_tokens(self, tokenizer):
+        sample_ids = [
+            [
+                "27",
+                "91",
+                "437",
+                "1659",
+                "5239",
+                "91",
+                "29",
+                "290",
+                "1279",
+                "91",
+                "437",
+                "1659",
+                "5239",
+                "91",
+                "29",
+                "389",
+                "2041",
+                "1279",
+                "91",
+                "437",
+                "1659",
+                "1370",
+                "91",
+                "29",
+                "318",
+                "407",
+                "0",
+            ],
+            [
+                "9288",
+                "15859",
+                "8905",
+                "51",
+                "1279",
+                "615",
+                "603",
+                "62",
+                "4658",
+                "29",
+                "351",
+                "27196",
+                "24027",
+                "1279",
+                "91",
+                "437",
+                "1659",
+                "5239",
+                "91",
+                "29",
+                "290",
+                "8005",
+                "62",
+                "44710",
+            ],
+            ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"],
+            [
+                "40",
+                "423",
+                "281",
+                "16882",
+                "1359",
+                "428",
+                "318",
+                "257",
+                "1332",
+                "1279",
+                "91",
+                "437",
+                "1659",
+                "5239",
+                "91",
+                "29",
+            ],
+        ]
+
+        expected_texts = [
+            "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!",
+            "test ACCEPT <avail_actions> with DECLINE <|endoftext|> and NO_ACTION",
+            "Avdija Vršajević în",
+            "I have an inkling this is a test <|endoftext|>",
+        ]
+
+        for idx, ids in enumerate(sample_ids):
+            self.assertEqual(tokenizer.decode(ids), expected_texts[idx])
+
+        newly_added = tokenizer.add_special_tokens(
+            special_tokens_dict={
+                "unk_token": "<|endoftext|>",
+                "sep_token": "<avail_actions>",
+                "additional_special_tokens": [
+                    "ACCEPT",
+                    "DECLINE",
+                    "inkling",
+                ],
+            }
+        )
+        self.assertEqual(newly_added, 4)
+
+        sample_ids = [
+            [
+                "50256",
+                "392",
+                "50256",
+                "533",
+                "2041",
+                "1279",
+                "91",
+                "437",
+                "1659",
+                "1370",
+                "91",
+                "29",
+                "318",
+                "407",
+                "0",
+            ],
+            ["9288", "50258", "50257", "4480", "50259", "50256", "392", "8005", "62", "44710"],
+            ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"],
+            ["40", "423", "281", "50260", "5661", "318", "257", "1332", "50256"],
+        ]
+
+        expected_texts = [
+            "<|endoftext|> and <|endoftext|> are special <|endofline|> is not!",
+            "test ACCEPT <avail_actions> with DECLINE <|endoftext|> and NO_ACTION",
+            "Avdija Vršajević în",
+            "I have an inkling this is a test <|endoftext|>",
+        ]
+
+        for idx, ids in enumerate(sample_ids):
+            self.assertEqual(tokenizer.decode(ids), expected_texts[idx])
+
     @nested_params([True, False], [True, False])
     def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
         """test tokenization on single sentence input as well as batch on sentences"""
@@ -568,6 +702,7 @@ def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
     def test_gpt2_bpe_decoder(self):
         """test string output returned by decoder given the token ids"""
         self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False))
+        self._gpt2_bpe_decoder_with_special_tokens(self._load_tokenizer(test_scripting=False, return_tokens=False))
 
     @nested_params([True, False])
     def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens):
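Taken together, the test pins down both halves of the new behavior: decoding special-token ids that were in the base vocab all along, and decoding ids minted by `add_special_tokens`. A minimal usage sketch follows; the asset paths are placeholders and the `GPT2BPETokenizer` constructor arguments are assumed to match `torchtext.transforms`, while the `add_special_tokens` and `decode` calls mirror the test above.

# Minimal sketch (not from this commit) exercising the new decode path.
from torchtext.transforms import GPT2BPETokenizer

tokenizer = GPT2BPETokenizer(
    encoder_json_path="gpt2_bpe_encoder.json",  # placeholder path
    vocab_bpe_path="gpt2_bpe_vocab.bpe",  # placeholder path
)

# Five tokens are requested, but "<|endoftext|>" is already in the GPT-2
# vocab (id 50256), so only four are newly added (ids 50257..50260).
newly_added = tokenizer.add_special_tokens(
    special_tokens_dict={
        "unk_token": "<|endoftext|>",
        "sep_token": "<avail_actions>",
        "additional_special_tokens": ["ACCEPT", "DECLINE", "inkling"],
    }
)
assert newly_added == 4

# 50260 is the newly added "inkling"; 50256 is the known "<|endoftext|>".
print(tokenizer.decode(["40", "423", "281", "50260", "5661", "318", "257", "1332", "50256"]))
# "I have an inkling this is a test <|endoftext|>"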

torchtext/csrc/gpt2_bpe_tokenizer.cpp

Lines changed: 49 additions & 18 deletions
@@ -381,8 +381,8 @@ std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
 std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
   std::vector<int64_t> bpe_token_ids;
   for (const auto& token : PreTokenize_(text)) {
-    if (added_tokens_encoder.contains(token)) {
-      bpe_token_ids.push_back(added_tokens_encoder.at(token));
+    if (added_tokens_encoder_.contains(token)) {
+      bpe_token_ids.push_back(added_tokens_encoder_.at(token));
       continue;
     }
     bool is_never_split_token =
@@ -397,18 +397,45 @@ std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
 
 std::string GPT2BPEEncoder::Decode(const std::vector<int64_t>& tokens) {
   std::string text;
+  std::vector<bool> special_token_flags(tokens.size());
   // setup converter for converting wide chars to/from chars
   using convert_type = std::codecvt_utf8<wchar_t>;
   std::wstring_convert<convert_type, wchar_t> converter;
 
-  for (const auto token : tokens) {
-    // get unicode string for given integer key
-    const std::string str = bpe_decoder_.at(token);
-    const std::wstring ws = converter.from_bytes(str);
-    for (wchar_t wchr : ws) {
-      // get output character from byte decoder for each wide character
-      unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
-      text.push_back(uchr);
+  for (int tok_idx = 0; tok_idx < tokens.size(); tok_idx++) {
+    const auto token = tokens[tok_idx];
+    std::string decoded_token;
+
+    if (added_tokens_decoder_.contains(token)) {
+      // string is a special token from extended vocab
+      decoded_token = added_tokens_decoder_.at(token);
+      special_token_flags[tok_idx] = true;
+    } else {
+      const std::string str = bpe_decoder_.at(token);
+      if (bpe_never_split_set_.find(str) != bpe_never_split_set_.end()) {
+        // string is a special token from known vocab
+        decoded_token = str;
+        special_token_flags[tok_idx] = true;
+      } else {
+        // string is a regular token from known vocab
+        const std::wstring ws = converter.from_bytes(str);
+        for (wchar_t wchr : ws) {
+          // get output character from byte decoder for each wide character
+          unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
+          decoded_token.push_back(uchr);
+        }
+      }
+    }
+
+    // fix left space(s) for special tokens
+    if (special_token_flags[tok_idx] == true &&
+        (tok_idx > 0 && special_token_flags[tok_idx - 1] == false)) {
+      text.push_back(' ');
+    }
+    text.append(decoded_token);
+    // fix right space(s) for special tokens
+    if (special_token_flags[tok_idx] == true && tok_idx != tokens.size() - 1) {
+      text.push_back(' ');
     }
   }
   return text;
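The spacing rules in the rewritten loop are the subtle part: a special token receives a leading space only when the previous piece was a regular token, and a trailing space unless it is the last token, so runs of special tokens end up separated by exactly one space. A hedged pure-Python restatement of just that rule (the `join_decoded` helper is illustrative only, not part of the codebase):

# pieces mirrors (decoded_token, special_token_flags[tok_idx]) pairs
# from the C++ loop above.
def join_decoded(pieces):
    text = []
    for i, (piece, is_special) in enumerate(pieces):
        # left space: only when a special token follows a regular one
        if is_special and i > 0 and not pieces[i - 1][1]:
            text.append(" ")
        text.append(piece)
        # right space: every special token except the last one
        if is_special and i != len(pieces) - 1:
            text.append(" ")
    return "".join(text)

# A leading special token gets no left space, only a right one:
print(join_decoded([("<|endoftext|>", True), ("and", False)]))
# "<|endoftext|> and"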
@@ -433,30 +460,34 @@ int64_t GPT2BPEEncoder::AddSpecialTokens(
   int64_t newly_added = 0;
 
   /* All special tokens get added to `bpe_never_split_set_` set to avoid being
-   * split during tokenization. Tokens are added to `added_tokens_encoder` only
-   * if they are not already known (i.e. present in `bpe_encoder_`).
+   * split during tokenization. Tokens are added to `added_tokens_encoder_` only
+   * if they are not already known (i.e. not already present in `bpe_encoder_`).
    */
 
   // Loop for standard tokens such as "bos_token", "eos_token", etc.
   for (auto const& token : standard_special_tokens_dict) {
-    if (added_tokens_encoder.contains(token.value()))
+    if (added_tokens_encoder_.contains(token.value()))
       continue;
     bpe_never_split_set_.insert(token.value());
     if (!bpe_encoder_.contains(token.value())) {
-      added_tokens_encoder.insert(
-          token.value(), bpe_encoder_.size() + added_tokens_encoder.size());
+      added_tokens_encoder_.insert(
+          token.value(), bpe_encoder_.size() + added_tokens_encoder_.size());
+      added_tokens_decoder_.insert(
+          bpe_decoder_.size() + added_tokens_decoder_.size(), token.value());
       newly_added++;
     }
   }
 
   // Loop for any additional tokens
   for (auto const& token : additional_special_tokens) {
-    if (added_tokens_encoder.contains(token))
+    if (added_tokens_encoder_.contains(token))
       continue;
     bpe_never_split_set_.insert(token);
    if (!bpe_encoder_.contains(token)) {
-      added_tokens_encoder.insert(
-          token, bpe_encoder_.size() + added_tokens_encoder.size());
+      added_tokens_encoder_.insert(
+          token, bpe_encoder_.size() + added_tokens_encoder_.size());
+      added_tokens_decoder_.insert(
+          bpe_decoder_.size() + added_tokens_decoder_.size(), token);
       newly_added++;
     }
   }
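For a concrete sense of the id bookkeeping: GPT-2's base vocabulary has 50257 entries (ids 0 through 50256, the last being `<|endoftext|>`), so `bpe_encoder_.size() + added_tokens_encoder_.size()` hands out ids from 50257 upward, while the new `added_tokens_decoder_` records the inverse mapping that `Decode` consults. The snippet below is illustrative arithmetic reproducing the test's expectations, not library code:

# Illustrative arithmetic for the test's add_special_tokens call.
BASE_VOCAB_SIZE = 50257  # GPT-2 ids 0..50256; 50256 is "<|endoftext|>"

requested = ["<|endoftext|>", "<avail_actions>", "ACCEPT", "DECLINE", "inkling"]
added_tokens_encoder = {}  # token -> id, mirrors added_tokens_encoder_
added_tokens_decoder = {}  # id -> token, mirrors added_tokens_decoder_

for token in requested:
    if token == "<|endoftext|>":
        continue  # already in bpe_encoder_: never-split only, no new id
    new_id = BASE_VOCAB_SIZE + len(added_tokens_encoder)
    added_tokens_encoder[token] = new_id
    added_tokens_decoder[new_id] = token

print(len(added_tokens_encoder))   # 4, matching newly_added in the test
print(added_tokens_decoder[50260]) # "inkling"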

torchtext/csrc/gpt2_bpe_tokenizer.h

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
       std::string token,
       bool is_never_split_token);
   int64_t GetBPEMergeRank_(std::string pair);
-  c10::Dict<std::string, int64_t> added_tokens_encoder;
+  c10::Dict<std::string, int64_t> added_tokens_encoder_;
+  c10::Dict<int64_t, std::string> added_tokens_decoder_;
 
  protected:
   c10::Dict<std::string, std::vector<std::string>> cache_;
