94 changes: 80 additions & 14 deletions test/test_transforms.py

@@ -294,26 +294,90 @@ def test_gpt2_bpe_tokenizer_save_load_torchscript(self):


 class TestCLIPTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting):
+    def _load_tokenizer(self, init_using_merge_only: bool, test_scripting: bool):
         encoder_json = "clip_encoder.json"
         bpe_vocab = "clip_vocab.bpe"
-        tokenizer = transforms.CLIPTokenizer(
-            encoder_json_path=get_asset_path(encoder_json),
-            vocab_bpe_path=get_asset_path(bpe_vocab),
-        )
+        num_merges = (
+            49152 - 256 - 2
+        )  # https://github.com/mlfoundations/open_clip/blob/57b3e8ea6ad6bfc2974203945f8fd577e0659468/src/clip/tokenizer.py#L67
+        if init_using_merge_only:
+            tokenizer = transforms.CLIPTokenizer(
+                merges_path=get_asset_path(bpe_vocab),
+                num_merges=num_merges,
+            )
+        else:
+            tokenizer = transforms.CLIPTokenizer(
+                encoder_json_path=get_asset_path(encoder_json),
+                merges_path=get_asset_path(bpe_vocab),
+            )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer

     def _clip_tokenizer(self, tokenizer):
         sample_texts = [
             "Hello World!, how are you?",
-            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>"
+            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>",
+            "Awaiting their due award... Photo by Frederick (FN) Noronha. Copyleft. Creative Commons 3.0. Non-commercial. Attribution. May be copied for non-commercial purposes. For other purposes, contact fn at goa-india.org",
         ]

         expected_token_ids = [
-            ['3306', '1002', '29325', '829', '631', '592', '286'],
-            ['49406', '518', '3712', '2866', '3240', '16901', '962', '518', '10753', '1929', '49407'],
+            ["3306", "1002", "29325", "829", "631", "592", "286"],
+            ["49406", "518", "3712", "2866", "3240", "16901", "962", "518", "10753", "1929", "49407"],
+            [
+                "14872",
+                "911",
+                "2887",
+                "2047",
+                "678",
+                "1125",
+                "638",
+                "18570",
+                "263",
+                "21763",
+                "264",
+                "1062",
+                "521",
+                "1429",
+                "269",
+                "11376",
+                "1823",
+                "269",
+                "4450",
+                "16653",
+                "274",
+                "269",
+                "271",
+                "269",
+                "3353",
+                "268",
+                "6287",
+                "269",
+                "24624",
+                "740",
+                "269",
+                "1270",
+                "655",
+                "36770",
+                "556",
+                "3353",
+                "268",
+                "6287",
+                "22020",
+                "269",
+                "556",
+                "1010",
+                "22020",
+                "267",
+                "3523",
+                "21763",
+                "536",
+                "14399",
+                "268",
+                "1762",
+                "269",
+                "5593",
+            ],
         ]

         # test batch of sentences

@@ -325,22 +325,24 @@ def _clip_tokenizer(self, tokenizer):

     def test_clip_tokenizer(self):
         """test tokenization on single sentence input as well as batch of sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=False))

     def test_clip_tokenizer_jit(self):
         """test tokenization with scripting on single sentence input as well as batch of sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=True))

     def test_clip_tokenizer_save_load_pybind(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_pybind.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt")
         torch.save(tokenizer, tokenizer_path)
         loaded_tokenizer = torch.load(tokenizer_path)
         self._clip_tokenizer(loaded_tokenizer)

     def test_clip_tokenizer_save_load_torchscript(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_torchscript.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_torchscript.pt")
         # Call the __prepare_scriptable__() func and convert the building block to the torchbind version
         # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here.
         torch.save(tokenizer.__prepare_scriptable__(), tokenizer_path)
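Taken together, the updated test exercises both construction paths that the new CLIPTokenizer signature (see the transforms.py diff below) allows. Here is a minimal usage sketch, assuming the hosted asset files quoted in the docstring below are still reachable; per the test above, both paths should produce identical token ids:

    from torchtext.transforms import CLIPTokenizer

    # Asset URLs quoted in the CLIPTokenizer docstring below.
    MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
    ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"

    # Path 1: encoder json + merges file; num_merges is inferred from the encoder size.
    tok_full = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)

    # Path 2: merges file only; the encoder is rebuilt from byte tokens + merges,
    # so num_merges must match the target checkpoint (48894 for the original CLIP vocab).
    tok_merges_only = CLIPTokenizer(merges_path=MERGES_FILE, num_merges=49152 - 256 - 2)

    sentence = "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>"
    assert tok_full(sentence) == tok_merges_only(sentence)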
60 changes: 42 additions & 18 deletions torchtext/transforms.py

@@ -321,31 +321,55 @@ class CLIPTokenizer(Module):
     (a bit like sentencepiece) so a word will be encoded differently whether it
     is at the beginning of the sentence (without space) or not.

+    The code snippet below shows how to use the CLIP tokenizer with an encoder and merges file
+    taken from the original paper implementation.
+
+    Example
+        >>> from torchtext.transforms import CLIPTokenizer
+        >>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
+        >>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
+        >>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
+        >>> tokenizer("the quick brown fox jumped over the lazy dog")
+
+    :param merges_path: Path to bpe merges file.
+    :type merges_path: str
-    :param encoder_json_path: Path to BPE encoder json file.
+    :param encoder_json_path: Optional, path to BPE encoder json file. When specified, this is used
+        to infer num_merges.
     :type encoder_json_path: str
-    :param vocab_bpe_path: Path to bpe vocab file.
-    :type vocab_bpe_path: str
+    :param num_merges: Optional, number of merges to read from the bpe merges file.
+    :type num_merges: int
     """

     _seperator: torch.jit.Final[str]

-    def __init__(
-        self,
-        encoder_json_path: str,
-        vocab_bpe_path: str,
-    ):
+    def __init__(self, merges_path: str, encoder_json_path: Optional[str] = None, num_merges: Optional[int] = None):
         super().__init__()
         self._seperator = "\u0001"
-        # load bpe encoder
-        with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
-            bpe_encoder = json.load(f)
-        # load bpe vocab
-        with open(get_asset_local_path(vocab_bpe_path), "r", encoding="utf-8") as f:
-            bpe_vocab = f.read()
-        bpe_merge_ranks = {
-            self._seperator.join(merge_pair.split()): i
-            for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
-        }
+        # load bpe merges
+        with open(get_asset_local_path(merges_path), "r", encoding="utf-8") as f:
+            bpe_merges = f.read().split("\n")[1:]
+
+        if encoder_json_path:
+            # load bpe encoder
+            with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
+                bpe_encoder = json.load(f)
+            # 256 * 2 for each byte. For each byte we have ['a', 'a</w>']
+            # Additional 2 tokens for bos and eos
+            num_merges = len(bpe_encoder) - (256 * 2 + 2)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+        else:
+            num_merges = num_merges or len(bpe_merges)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+            bpe_vocab = list(bytes_to_unicode().values())
+            bpe_vocab = bpe_vocab + [v + "</w>" for v in bpe_vocab]
+            bpe_vocab.extend(["".join(merge_pair.split()) for merge_pair in bpe_merges[:num_merges]])
+            bpe_vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+            bpe_encoder = {v: i for i, v in enumerate(bpe_vocab)}

         # Caching is enabled in Eager mode
         self.bpe = CLIPEncoderPyBind(bpe_encoder, bpe_merge_ranks,
                                      self._seperator, bytes_to_unicode(), True)
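A note on the two num_merges derivations above: the constructor infers len(bpe_encoder) - (256 * 2 + 2) when the encoder json is given, while the test (and the linked open_clip source) hard-codes 49152 - 256 - 2. A quick sanity check, assuming the published clip_encoder.json has 49408 entries, shows both yield 48894 and that the merges-only vocabulary rebuilt in the else branch lands on the same size:

    n_byte_tokens = 256 * 2   # each byte appears twice, as 'a' and 'a</w>'
    n_special = 2             # <|startoftext|> and <|endoftext|>
    encoder_size = 49408      # assumed size of the published clip_encoder.json

    num_merges_from_encoder = encoder_size - (n_byte_tokens + n_special)  # 48894
    num_merges_from_open_clip = 49152 - 256 - 2                           # 48894
    assert num_merges_from_encoder == num_merges_from_open_clip == 48894

    # The merges-only path rebuilds the encoder in this order, giving the same size:
    # [512 byte tokens] + [48894 merge tokens] + [2 special tokens] = 49408
    assert n_byte_tokens + num_merges_from_encoder + n_special == encoder_size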