
Commit 9e66291

move SPECIAL_TOKENS_ATTRIBUTES to utils

1 parent 42a14a0

File tree

2 files changed (+14, -13 lines)

torchtext/transforms.py
torchtext/utils.py

torchtext/transforms.py

Lines changed: 3 additions & 13 deletions
@@ -14,7 +14,7 @@
 )
 from torchtext._torchtext import RegexTokenizer as RegexTokenizerPybind
 from torchtext.data.functional import load_sp_model
-from torchtext.utils import get_asset_local_path
+from torchtext.utils import get_asset_local_path, SPECIAL_TOKENS_ATTRIBUTES
 from torchtext.vocab import Vocab
 
 from . import functional as F
@@ -294,16 +294,6 @@ class GPT2BPETokenizer(Module):
     def __init__(self, encoder_json_path: str, vocab_bpe_path: str, return_tokens: bool = False) -> None:
         super().__init__()
         self._seperator = "\u0001"
-        self.SPECIAL_TOKENS_ATTRIBUTES = [
-            "bos_token",
-            "eos_token",
-            "unk_token",
-            "sep_token",
-            "pad_token",
-            "cls_token",
-            "mask_token",
-            "additional_special_tokens",
-        ]
         # load bpe encoder and bpe decoder
         with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
             bpe_encoder = json.load(f)
@@ -371,8 +361,8 @@ def add_special_tokens(self, special_tokens_dict: Mapping[str, Union[str, Sequen
         """
         for key in special_tokens_dict.keys():
             assert (
-                key in self.SPECIAL_TOKENS_ATTRIBUTES
-            ), f"Key '{key}' is not in the special token list: {self.SPECIAL_TOKENS_ATTRIBUTES}"
+                key in SPECIAL_TOKENS_ATTRIBUTES
+            ), f"Key '{key}' is not in the special token list: {SPECIAL_TOKENS_ATTRIBUTES}"
 
         return self.bpe.add_special_tokens(
             {k: v for k, v in special_tokens_dict.items() if k != "additional_special_tokens"},
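
For reference, the assertion above now validates keys against the module-level SPECIAL_TOKENS_ATTRIBUTES imported from torchtext.utils rather than a per-instance copy. A minimal standalone sketch of that check follows; validate_special_tokens is a hypothetical helper written only for illustration and is not part of torchtext:

# Sketch only: mirrors the key validation in GPT2BPETokenizer.add_special_tokens.
# SPECIAL_TOKENS_ATTRIBUTES below copies the list added to torchtext/utils.py in this commit.
SPECIAL_TOKENS_ATTRIBUTES = [
    "bos_token",
    "eos_token",
    "unk_token",
    "sep_token",
    "pad_token",
    "cls_token",
    "mask_token",
    "additional_special_tokens",
]


def validate_special_tokens(special_tokens_dict):
    # Hypothetical helper: reject any key that is not a recognized special-token attribute.
    for key in special_tokens_dict.keys():
        assert (
            key in SPECIAL_TOKENS_ATTRIBUTES
        ), f"Key '{key}' is not in the special token list: {SPECIAL_TOKENS_ATTRIBUTES}"


validate_special_tokens({"unk_token": "<unk>", "pad_token": "<pad>"})  # passes silently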

torchtext/utils.py

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,17 @@
 
 logger = logging.getLogger(__name__)
 
+SPECIAL_TOKENS_ATTRIBUTES = [
+    "bos_token",
+    "eos_token",
+    "unk_token",
+    "sep_token",
+    "pad_token",
+    "cls_token",
+    "mask_token",
+    "additional_special_tokens",
+]
+
 
 def reporthook(t):
     """

0 commit comments