
Commit b0df58b

Authored by reachsumit and Sumit Kumar
Add never_split feature to BERTTokenizer (#1898)
* Add never_split feature to BERTTokenizer
* fix logical operator
* move set creation to BERTEncoder constructor

Co-authored-by: Sumit Kumar <[email protected]>
1 parent 67d2692 commit b0df58b
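For orientation, here is a minimal usage sketch of the new option, mirroring the Python-level tests added in this commit. The vocabulary path is a placeholder for the test asset, and the printed outputs are the expectations the tests assert for the uncased vocab with return_tokens=True.

from torchtext.transforms import BERTTokenizer

VOCAB = "bert_base_uncased_vocab.txt"  # placeholder path to an uncased BERT vocab asset

# Default behaviour: special tokens are broken apart by punctuation splitting.
default_tok = BERTTokenizer(vocab_path=VOCAB, do_lower_case=True, return_tokens=True)
# New behaviour: tokens listed in never_split pass through untouched.
protected_tok = BERTTokenizer(
    vocab_path=VOCAB,
    do_lower_case=True,
    return_tokens=True,
    never_split=["[UNK]", "[CLS]"],
)

print(default_tok("hi world [UNK] [CLS]"))    # ['hi', 'world', '[', 'un', '##k', ']', '[', 'cl', '##s', ']']
print(protected_tok("hi world [UNK] [CLS]"))  # ['hi', 'world', '[UNK]', '[CLS]']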

6 files changed: +222 lines added, -89 lines removed

test/torchtext_unittest/test_transforms.py

Lines changed: 108 additions & 30 deletions
@@ -1,5 +1,6 @@
 import os
 from collections import OrderedDict
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
@@ -586,7 +587,9 @@ def test_clip_tokenizer_save_load_torchscript(self) -> None:
 
 
 class TestBERTTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_tokens: bool):
+    def _load_tokenizer(
+        self, test_scripting: bool, do_lower_case: bool, return_tokens: bool, never_split: Optional[List[str]] = None
+    ):
         if do_lower_case:
             vocab_file = "bert_base_uncased_vocab.txt"
         else:
@@ -596,46 +599,117 @@ def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_toke
             vocab_path=get_asset_path(vocab_file),
             do_lower_case=do_lower_case,
             return_tokens=return_tokens,
+            never_split=never_split,
         )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer
 
-    def _bert_tokenizer(self, tokenizer, do_lower_case):
+    def _bert_tokenizer(self, tokenizer, do_lower_case, never_split: Optional[List[str]] = None):
         sample_texts = [
             "Hello World!, how are you?",
             "Hélló WoŕlḊ¿",
             "Respublica superiorem",
             "Avdija Vršajević în",
+            " \tHeLLo!how \n Are yoU? [UNK]",
+            "hi world [UNK] [CLS]",
+            "testing, [UNK] words! [SEP]",
         ]
 
-        if do_lower_case:
-            expected_tokens = [
-                ["hello", "world", "!", ",", "how", "are", "you", "?"],
-                ["hello", "world", "¿"],
-                ["res", "##pu", "##bl", "##ica", "superior", "##em"],
-                ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
-            ]
-            expected_token_ids = [
-                ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
-                ["7592", "2088", "1094"],
-                ["24501", "14289", "16558", "5555", "6020", "6633"],
-                ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
-            ]
+        if not never_split:
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[", "un", "##k", "]"],
+                    ["hi", "world", "[", "un", "##k", "]", "[", "cl", "##s", "]"],
+                    ["testing", ",", "[", "un", "##k", "]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "1031", "4895", "2243", "1033"],
+                    ["7632", "2088", "1031", "4895", "2243", "1033", "1031", "18856", "2015", "1033"],
+                    ["5604", "1010", "1031", "4895", "2243", "1033", "2616", "999", "1031", "19802", "1033"],
+                ]
 
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[", "UN", "##K", "]"],
+                    ["hi", "world", "[", "UN", "##K", "]", "[", "C", "##LS", "]"],
+                    ["testing", ",", "[", "UN", "##K", "]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    [
+                        "1124",
+                        "23955",
+                        "1186",
+                        "106",
+                        "1293",
+                        "2372",
+                        "26063",
+                        "2591",
+                        "136",
+                        "164",
+                        "7414",
+                        "2428",
+                        "166",
+                    ],
+                    ["20844", "1362", "164", "7414", "2428", "166", "164", "140", "15928", "166"],
+                    ["5193", "117", "164", "7414", "2428", "166", "1734", "106", "164", "12342", "2101", "166"],
+                ]
         else:
-            expected_tokens = [
-                ["Hello", "World", "!", ",", "how", "are", "you", "?"],
-                ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
-                ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
-                ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
-            ]
-            expected_token_ids = [
-                ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
-                ["145", "2744", "2339", "7774", "100", "225"],
-                ["11336", "20080", "10354", "9538", "7298", "5521"],
-                ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
-            ]
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "100"],
+                    ["7632", "2088", "100", "101"],
+                    ["5604", "1010", "100", "2616", "999", "1031", "19802", "1033"],
+                ]
+
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    ["1124", "23955", "1186", "106", "1293", "2372", "26063", "2591", "136", "100"],
+                    ["20844", "1362", "100", "101"],
+                    ["5193", "117", "100", "1734", "106", "164", "12342", "2101", "166"],
+                ]
 
         # test batch of sentences
         if tokenizer._return_tokens:
@@ -650,14 +724,18 @@ def _bert_tokenizer(self, tokenizer, do_lower_case):
             else:
                 self.assertEqual(tokenizer(txt), expected_token_ids[idx])
 
-    @nested_params([True, False], [True, False], [True, False])
-    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens):
+    @nested_params([True, False], [True, False], [True, False], [[], None, ["[UNK]", "[CLS]"]])
+    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens, never_split):
         """test tokenization on single sentence input as well as batch on sentences"""
         self._bert_tokenizer(
             self._load_tokenizer(
-                test_scripting=test_scripting, do_lower_case=do_lower_case, return_tokens=return_tokens
+                test_scripting=test_scripting,
+                do_lower_case=do_lower_case,
+                return_tokens=return_tokens,
+                never_split=never_split,
             ),
             do_lower_case=do_lower_case,
+            never_split=never_split,
         )
 
     @nested_params([True, False], [True, False], [True, False])
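One detail worth noting from the expectations above: only the strings actually passed in never_split are protected; everything else still goes through the usual splitting. A small sketch of the uncased case (placeholder vocab path), matching the last expectation in the never_split branch:

from torchtext.transforms import BERTTokenizer

VOCAB = "bert_base_uncased_vocab.txt"  # placeholder asset path, as used by the tests

tok = BERTTokenizer(vocab_path=VOCAB, do_lower_case=True, return_tokens=True, never_split=["[UNK]", "[CLS]"])
# "[UNK]" is protected, but "[SEP]" is not listed, so it is still split:
print(tok("testing, [UNK] words! [SEP]"))
# -> ['testing', ',', '[UNK]', 'words', '!', '[', 'sep', ']']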

torchtext/csrc/bert_tokenizer.cpp

Lines changed: 73 additions & 48 deletions
@@ -132,38 +132,52 @@ static std::string _convert_from_unicode(const UString& text) {
   return ret;
 }
 
-static void to_lower(UString& text) {
-  for (size_t i = 0; i < text.size(); i++) {
-    text[i] = utf8proc_tolower(text[i]);
+static void to_lower(UString& token) {
+  for (size_t i = 0; i < token.size(); i++) {
+    token[i] = utf8proc_tolower(token[i]);
   }
 }
 
 BERTEncoder::BERTEncoder(
     const std::string& vocab_file,
     bool do_lower_case,
-    c10::optional<bool> strip_accents)
+    c10::optional<bool> strip_accents,
+    std::vector<std::string> never_split)
     : vocab_{_read_vocab(vocab_file)},
       do_lower_case_{do_lower_case},
-      strip_accents_{strip_accents} {}
+      strip_accents_{strip_accents},
+      never_split_{never_split} {
+  never_split_set_.insert(never_split_.begin(), never_split_.end());
+}
 
 BERTEncoder::BERTEncoder(
     Vocab vocab,
     bool do_lower_case,
-    c10::optional<bool> strip_accents)
+    c10::optional<bool> strip_accents,
+    std::vector<std::string> never_split)
     : vocab_{vocab},
       do_lower_case_{do_lower_case},
-      strip_accents_{strip_accents} {}
+      strip_accents_{strip_accents},
+      never_split_{never_split} {
+  never_split_set_.insert(never_split_.begin(), never_split_.end());
+}
 
-UString BERTEncoder::_clean(const UString& text, bool strip_accents) {
+UString BERTEncoder::_clean(
+    const UString& token,
+    bool strip_accents,
+    bool is_never_split_token) {
   /* This function combines:
    * cleaning
    * strip accents
    */
-  size_t len = text.size();
+  size_t len = token.size();
   UString ret;
   for (size_t i = 0; i < len; i++) {
-    uint32_t c = text[i];
-    if (c == 0 || c == 0xFFFD || _is_control(c) ||
+    uint32_t c = token[i];
+    if (c == 0 || c == 0xFFFD || _is_control(c)) {
+      continue;
+    }
+    if ((!is_never_split_token) &&
         (utf8proc_category(c) == UTF8PROC_CATEGORY_MN && strip_accents)) {
       continue;
     }
@@ -221,18 +235,20 @@ void BERTEncoder::_max_seg(
   }
 }
 
-UString BERTEncoder::_basic_tokenize(const UString& text) {
+UString BERTEncoder::_basic_tokenize(
+    const UString& token,
+    bool is_never_split_token) {
   /*
   This function enables white space based tokenization for following:
   * chinese character
   * punctuation
   */
 
   UString ret;
-  size_t len = text.size();
+  size_t len = token.size();
   for (size_t i = 0; i < len; i++) {
-    uint32_t c = text[i];
-    if (_is_chinese_char(c) || _is_punct_char(c)) {
+    uint32_t c = token[i];
+    if (_is_chinese_char(c) || (_is_punct_char(c) && !is_never_split_token)) {
       if (!ret.empty() && ret.back() != ' ') {
         ret.append(1, ' ');
       }
@@ -254,51 +270,56 @@ UString BERTEncoder::_basic_tokenize(const UString& text) {
 
 std::vector<std::string> BERTEncoder::Tokenize(std::string text) {
   std::vector<std::string> results;
+  std::vector<std::string> interim_results;
+  std::vector<std::string> tokens;
 
-  // normalize
+  // split based on whitespace
+  split_(text, tokens);
 
-  bool strip_accents = do_lower_case_;
+  for (auto& token : tokens) {
+    bool is_never_split_token =
+        never_split_set_.find(token) != never_split_set_.end();
 
-  if (strip_accents_.has_value()) {
-    strip_accents = strip_accents_.has_value();
-  }
+    // normalize
 
-  if (strip_accents) {
-    char* nfkcstr = reinterpret_cast<char*>(
-        utf8proc_NFD(reinterpret_cast<const unsigned char*>(text.c_str())));
-    if (nfkcstr == nullptr) {
-      return {};
-    }
+    bool strip_accents = do_lower_case_;
 
-    text.assign(nfkcstr, strlen(nfkcstr));
+    if (strip_accents_.has_value()) {
+      strip_accents = strip_accents_.has_value();
+    }
 
-    free(nfkcstr);
-  }
+    if (strip_accents) {
+      char* nfkcstr = reinterpret_cast<char*>(
+          utf8proc_NFD(reinterpret_cast<const unsigned char*>(token.c_str())));
+      if (nfkcstr == nullptr) {
+        return {};
+      }
 
-  // convert to unicode codepoints
-  UString unicodes = _convert_to_unicode(text);
+      token.assign(nfkcstr, strlen(nfkcstr));
 
-  // clean -> invalid character removal, whitespce cleanup, strip accents
-  unicodes = _clean(unicodes, strip_accents);
+      free(nfkcstr);
+    }
 
-  // Add whitespace in front/back of tokens to enable splitting based on
-  // white-space Enables tokenization on chinese characters, Punctuations
-  unicodes = _basic_tokenize(unicodes);
+    // convert to unicode codepoints
+    UString unicodes = _convert_to_unicode(token);
 
-  // Convert text to lower-case
-  if (do_lower_case_)
-    to_lower(unicodes);
+    // clean -> invalid character removal, whitespce cleanup, strip accents
+    unicodes = _clean(unicodes, strip_accents, is_never_split_token);
 
-  // Convert back to string from code-points
-  std::string newtext = _convert_from_unicode(unicodes);
+    // Add whitespace in front/back of tokens to enable splitting based on
+    // white-space Enables tokenization on chinese characters, Punctuations
+    unicodes = _basic_tokenize(unicodes, is_never_split_token);
 
-  std::vector<std::string> tokens;
+    // Convert token to lower-case
+    if (do_lower_case_ && !is_never_split_token)
+      to_lower(unicodes);
 
-  // split based on whitespace
-  split_(newtext, tokens);
+    // Convert back to string from code-points
+    split_(_convert_from_unicode(unicodes), interim_results);
+  }
 
   // Perform WORDPIECE tokenization
-  for (auto s : tokens) {
+  for (auto s : interim_results) {
     if (s.size() > kMaxCharsPerWords) {
       results.push_back(kUnkToken);
     } else {
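The restructured Tokenize loop above now pre-splits on whitespace and checks each whitespace token against never_split_set_ before cleaning, punctuation splitting, and lower-casing. A simplified, self-contained Python sketch of that control flow (it deliberately skips Unicode cleaning, accent stripping, and wordpiece, which the C++ code still performs):

import string


def basic_tokenize_sketch(text, never_split=(), do_lower_case=True):
    """Illustrative approximation of the per-token flow in BERTEncoder::Tokenize above."""
    never_split_set = set(never_split)
    out = []
    for token in text.split():  # split based on whitespace first
        if token in never_split_set:
            out.append(token)  # protected: no lower-casing, no punctuation splitting
            continue
        if do_lower_case:
            token = token.lower()
        word = ""
        for ch in token:  # break punctuation out into standalone tokens
            if ch in string.punctuation:
                if word:
                    out.append(word)
                    word = ""
                out.append(ch)
            else:
                word += ch
        if word:
            out.append(word)
    return out


print(basic_tokenize_sketch("hi world [UNK] [CLS]"))
# -> ['hi', 'world', '[', 'unk', ']', '[', 'cls', ']']
print(basic_tokenize_sketch("hi world [UNK] [CLS]", never_split=["[UNK]", "[CLS]"]))
# -> ['hi', 'world', '[UNK]', '[CLS]']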
@@ -338,16 +359,20 @@ std::vector<std::vector<int64_t>> BERTEncoder::BatchEncode(
 BERTEncoderStates _serialize_bert_encoder(
     const c10::intrusive_ptr<BERTEncoder>& self) {
   return std::make_tuple(
-      self->do_lower_case_, self->strip_accents_, self->vocab_.itos_);
+      self->do_lower_case_,
+      self->strip_accents_,
+      self->never_split_,
+      self->vocab_.itos_);
 }
 
 c10::intrusive_ptr<BERTEncoder> _deserialize_bert_encoder(
     BERTEncoderStates states) {
   auto do_lower_case = std::get<0>(states);
   auto strip_accents = std::get<1>(states);
-  auto strings = std::get<2>(states);
+  auto never_split = std::get<2>(states);
+  auto strings = std::get<3>(states);
   return c10::make_intrusive<BERTEncoder>(
-      Vocab(std::move(strings)), do_lower_case, strip_accents);
+      Vocab(std::move(strings)), do_lower_case, strip_accents, never_split);
 }
 
 } // namespace torchtext
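Since never_split is now part of the serialized encoder state (see _serialize_bert_encoder / _deserialize_bert_encoder above), a scripted tokenizer should keep its protected tokens across save and load. A hedged sketch, assuming the scripted transform can be saved and loaded like other torchtext transforms; the vocab path and file name are placeholders:

import torch
from torchtext.transforms import BERTTokenizer

VOCAB = "bert_base_uncased_vocab.txt"  # placeholder vocab asset path

tok = BERTTokenizer(vocab_path=VOCAB, do_lower_case=True, return_tokens=True, never_split=["[UNK]", "[CLS]"])
scripted = torch.jit.script(tok)       # the new tests script the tokenizer the same way
scripted.save("bert_tokenizer.pt")     # placeholder file name
loaded = torch.jit.load("bert_tokenizer.pt")
print(loaded("hi world [UNK] [CLS]"))  # expected: ['hi', 'world', '[UNK]', '[CLS]']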
