22 changes: 22 additions & 0 deletions test/torchtext_unittest/test_transforms.py
@@ -364,11 +364,33 @@ def _gpt2_bpe_tokenizer(self, tokenizer):
             else:
                 self.assertEqual(tokenizer(txt), expected_token_ids[idx])
 
+    def _gpt2_bpe_decoder(self, tokenizer):
+        sample_ids = [
+            ["15496", "2159", "28265", "703", "389", "345", "30"],
+            ["39", "2634", "297", "10205", "220", "22173", "129", "243", "75", "41585", "232", "126", "123"],
+            ["4965", "11377", "64", "2208", "72", "29625"],
+            ["7355", "67", "34655", "569", "81", "32790", "1228", "1990", "72", "38325", "6184", "106", "77"],
+        ]
+
+        expected_texts = [
+            "Hello World!, how are you?",
+            "Hélló WoŕlḊ¿",
+            "Respublica superiorem",
+            "Avdija Vršajević în",
+        ]
+
+        for idx, ids in enumerate(sample_ids):
+            self.assertEqual(tokenizer.decode(ids), expected_texts[idx])
+
     @nested_params([True, False], [True, False])
     def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
         """test tokenization on single sentence input as well as batch on sentences"""
         self._gpt2_bpe_tokenizer(self._load_tokenizer(test_scripting=test_scripting, return_tokens=return_tokens))
 
+    def test_gpt2_bpe_decoder(self):
+        """test string output returned by decoder given the token ids"""
+        self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False))
+
     def test_gpt2_bpe_tokenizer_save_load_pybind(self) -> None:
         tokenizer = self._load_tokenizer(test_scripting=False, return_tokens=False)
         tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt")
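Note that each sample_ids fixture above is exactly what the tokenizer emits for the corresponding expected_texts entry, so the test is effectively a full encode/decode round trip. A minimal sketch of how such a fixture could be regenerated (the asset paths are placeholders, not what the suite uses; the test obtains its tokenizer via self._load_tokenizer()):

from torchtext.transforms import GPT2BPETokenizer

# Placeholder paths to the GPT-2 BPE encoder/vocab assets.
tokenizer = GPT2BPETokenizer(
    encoder_json_path="gpt2_bpe_encoder.json",
    vocab_bpe_path="gpt2_bpe_vocab.bpe",
)
ids = tokenizer("Hello World!, how are you?")
print(ids)  # ["15496", "2159", "28265", "703", "389", "345", "30"]
assert tokenizer.decode(ids) == "Hello World!, how are you?"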
29 changes: 28 additions & 1 deletion torchtext/csrc/gpt2_bpe_tokenizer.cpp
@@ -2,6 +2,8 @@
 #include <torchtext/csrc/regex.h> // @manual
 
 #include <algorithm>
+#include <codecvt>
+#include <locale>
 #include <sstream>
 #include <unordered_set>
 #include <utility>
@@ -147,7 +149,13 @@ GPT2BPEEncoder::GPT2BPEEncoder(
       bpe_merge_ranks_(std::move(bpe_merge_ranks)),
       byte_encoder_(std::move(byte_encoder)),
       seperator_(std::move(seperator)),
-      caching_enabled_(caching_enabled) {}
+      caching_enabled_(caching_enabled) {
+  for (auto const& x : bpe_encoder_)
+    bpe_decoder_.insert(x.value(), x.key());
+
+  for (auto const& x : byte_encoder_)
+    byte_decoder_.insert(x.value(), x.key());
+}
 
 GPT2BPEEncoder::GPT2BPEEncoder(
     const std::unordered_map<std::string, int64_t>& bpe_encoder,
@@ -279,6 +287,25 @@ std::vector<int64_t> GPT2BPEEncoder::Encode(const std::string& text) {
   return bpe_token_ids;
 }
 
+std::string GPT2BPEEncoder::Decode(const std::vector<int64_t>& tokens) {
+  std::string text;
+  // setup converter for converting wide chars to/from chars
+  using convert_type = std::codecvt_utf8<wchar_t>;
+  std::wstring_convert<convert_type, wchar_t> converter;
+
+  for (const auto token : tokens) {
+    // get unicode string for given integer key
+    const std::string str = bpe_decoder_.at(token);
+    const std::wstring ws = converter.from_bytes(str);
+    for (wchar_t wchr : ws) {
+      // get output character from byte decoder for each wide character
+      unsigned char uchr = byte_decoder_.at(converter.to_bytes(wchr));
+      text.push_back(uchr);
+    }
+  }
+  return text;
+}
+
 std::vector<std::string> GPT2BPEEncoder::Tokenize(const std::string& text) {
   std::vector<std::string> bpe_tokens;
   for (const auto& token : PreTokenize_(text)) {
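The two c10::Dict inversions in the constructor give Decode its two-stage lookup: bpe_decoder_ maps a token id back to its BPE token string, and byte_decoder_ maps each unicode character of that string back to the raw byte it stands for, undoing GPT-2's byte-to-unicode remapping. The wstring_convert round trip exists only so the inner loop walks characters rather than raw UTF-8 bytes. A rough pure-Python sketch of the same logic, assuming bpe_decoder and byte_decoder dicts built by inverting the encoder tables (all names here are illustrative):

def decode(token_ids, bpe_decoder, byte_decoder):
    # bpe_decoder: token id (int) -> BPE token string
    # byte_decoder: one unicode character -> original byte value (0..255)
    raw = bytearray()
    for token_id in token_ids:
        for ch in bpe_decoder[token_id]:  # iterate characters, not bytes
            raw.append(byte_decoder[ch])
    return raw.decode("utf-8")  # the recovered bytes form the UTF-8 text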
3 changes: 3 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.h
@@ -69,8 +69,10 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
 
  public:
   const c10::Dict<std::string, int64_t> bpe_encoder_;
+  const c10::Dict<int64_t, std::string> bpe_decoder_;
   const c10::Dict<std::string, int64_t> bpe_merge_ranks_;
   const c10::Dict<int64_t, std::string> byte_encoder_;
+  const c10::Dict<std::string, int64_t> byte_decoder_;
   const std::string seperator_;
   const bool caching_enabled_;
   explicit GPT2BPEEncoder(
@@ -99,6 +101,7 @@ struct GPT2BPEEncoder : torch::CustomClassHolder {
   // --> result --> [707, 5927, 11, 707, 68]
   //
   TORCHTEXT_API std::vector<int64_t> Encode(const std::string& text);
+  TORCHTEXT_API std::string Decode(const std::vector<int64_t>& tokens);
   TORCHTEXT_API std::vector<std::string> Tokenize(const std::string& text);
 
   TORCHTEXT_API std::unordered_map<std::string, int64_t> GetBPEEncoder() const;
1 change: 1 addition & 0 deletions torchtext/csrc/register_pybindings.cpp
@@ -179,6 +179,7 @@ PYBIND11_MODULE(_torchtext, m) {
       .def_property_readonly("byte_encoder_", &GPT2BPEEncoder::GetByteEncoder)
       .def("encode", &GPT2BPEEncoder::Encode)
       .def("tokenize", &GPT2BPEEncoder::Tokenize)
+      .def("decode", &GPT2BPEEncoder::Decode)
       .def(py::pickle(
           // __getstate__
           [](const c10::intrusive_ptr<GPT2BPEEncoder>& self)
1 change: 1 addition & 0 deletions torchtext/csrc/register_torchbindings.cpp
@@ -139,6 +139,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
               c10::Dict<int64_t, std::string>,
               bool>())
       .def("encode", &GPT2BPEEncoder::Encode)
+      .def("decode", &GPT2BPEEncoder::Decode)
       .def("tokenize", &GPT2BPEEncoder::Tokenize)
       .def_pickle(
           // __getstate__
10 changes: 10 additions & 0 deletions torchtext/transforms.py
@@ -383,6 +383,16 @@ def __prepare_scriptable__(self):
             return tokenizer_copy
         return self
 
+    def decode(self, tokens: List[str]) -> str:
+        """Return a decoded string given a list of string token ids.
+
+        :param tokens: A list of strings, where each string is a token id.
+        :type tokens: List[str]
+        :return: decoded text
+        :rtype: str
+        """
+        return self.bpe.decode([int(token) for token in tokens])
+
 
 class CLIPTokenizer(Module):
     """
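From Python, decode accepts the string token ids the tokenizer itself produces, converts them to ints, and defers to the bound C++ Decode. A self-contained usage sketch with placeholder asset paths, mirroring the non-ASCII case from the unit test:

from torchtext.transforms import GPT2BPETokenizer

# Placeholder paths to the GPT-2 BPE encoder/vocab assets.
tokenizer = GPT2BPETokenizer("gpt2_bpe_encoder.json", "gpt2_bpe_vocab.bpe")
ids = tokenizer("Hélló WoŕlḊ¿")  # token ids come back as strings
print(tokenizer.decode(ids))     # -> Hélló WoŕlḊ¿

Note that decode is exercised in eager mode only; the new unit test likewise loads its tokenizer with test_scripting=False.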