
Commit 170b74a

Modify CLIPTokenizer to either infer number of merges from encoder json or take it in constructor (#1622) (#1626)
1 parent 73941c6 commit 170b74a

File tree

2 files changed: +122 -32 lines changed

  test/test_transforms.py
  torchtext/transforms.py
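In short, CLIPTokenizer can now be built either from a merges file plus the encoder json (in which case num_merges is inferred from the encoder size) or from a merges file plus an explicit num_merges (in which case the encoder vocab is reconstructed from the merges). A minimal sketch of both construction paths, using the asset URLs from the updated docstring in torchtext/transforms.py below; the explicit num_merges value follows the open_clip reference cited in the test:

    >>> from torchtext.transforms import CLIPTokenizer
    >>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
    >>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
    >>> # Path 1: num_merges is inferred from the encoder json
    >>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
    >>> # Path 2: num_merges is given explicitly; the encoder is rebuilt from the merges file
    >>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, num_merges=49152 - 256 - 2)
    >>> tokenizer("the quick brown fox jumped over the lazy dog")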

test/test_transforms.py

Lines changed: 80 additions & 14 deletions
@@ -294,26 +294,90 @@ def test_gpt2_bpe_tokenizer_save_load_torchscript(self):
 
 
 class TestCLIPTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting):
+    def _load_tokenizer(self, init_using_merge_only: bool, test_scripting: bool):
         encoder_json = "clip_encoder.json"
         bpe_vocab = "clip_vocab.bpe"
-        tokenizer = transforms.CLIPTokenizer(
-            encoder_json_path=get_asset_path(encoder_json),
-            vocab_bpe_path=get_asset_path(bpe_vocab),
-        )
+        num_merges = (
+            49152 - 256 - 2
+        )  # https://github.com/mlfoundations/open_clip/blob/57b3e8ea6ad6bfc2974203945f8fd577e0659468/src/clip/tokenizer.py#L67
+        if init_using_merge_only:
+            tokenizer = transforms.CLIPTokenizer(
+                merges_path=get_asset_path(bpe_vocab),
+                num_merges=num_merges,
+            )
+        else:
+            tokenizer = transforms.CLIPTokenizer(
+                encoder_json_path=get_asset_path(encoder_json),
+                merges_path=get_asset_path(bpe_vocab),
+            )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer
 
     def _clip_tokenizer(self, tokenizer):
         sample_texts = [
             "Hello World!, how are you?",
-            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>"
+            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>",
+            "Awaiting their due award... Photo by Frederick (FN) Noronha. Copyleft. Creative Commons 3.0. Non-commercial. Attribution. May be copied for non-commercial purposes. For other purposes, contact fn at goa-india.org",
         ]
 
         expected_token_ids = [
-            ['3306', '1002', '29325', '829', '631', '592', '286'],
-            ['49406', '518', '3712', '2866', '3240', '16901', '962', '518', '10753', '1929', '49407'],
+            ["3306", "1002", "29325", "829", "631", "592", "286"],
+            ["49406", "518", "3712", "2866", "3240", "16901", "962", "518", "10753", "1929", "49407"],
+            [
+                "14872",
+                "911",
+                "2887",
+                "2047",
+                "678",
+                "1125",
+                "638",
+                "18570",
+                "263",
+                "21763",
+                "264",
+                "1062",
+                "521",
+                "1429",
+                "269",
+                "11376",
+                "1823",
+                "269",
+                "4450",
+                "16653",
+                "274",
+                "269",
+                "271",
+                "269",
+                "3353",
+                "268",
+                "6287",
+                "269",
+                "24624",
+                "740",
+                "269",
+                "1270",
+                "655",
+                "36770",
+                "556",
+                "3353",
+                "268",
+                "6287",
+                "22020",
+                "269",
+                "556",
+                "1010",
+                "22020",
+                "267",
+                "3523",
+                "21763",
+                "536",
+                "14399",
+                "268",
+                "1762",
+                "269",
+                "5593",
+            ],
         ]
 
         # test batch of sentences
@@ -325,22 +389,24 @@ def _clip_tokenizer(self, tokenizer):
 
     def test_clip_tokenizer(self):
         """test tokenization on single sentence input as well as batch on sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=False))
 
     def test_clip_tokenizer_jit(self):
         """test tokenization with scripting on single sentence input as well as batch on sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=True))
 
     def test_clip_tokenizer_save_load_pybind(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_pybind.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt")
         torch.save(tokenizer, tokenizer_path)
         loaded_tokenizer = torch.load(tokenizer_path)
         self._clip_tokenizer((loaded_tokenizer))
 
     def test_clip_tokenizer_save_load_torchscript(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_torchscript.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_torchscript.pt")
         # Call the __prepare_scriptable__() func and convert the building block to the torbhind version
         # Not expect users to use the torchbind version on eager mode but still need a CI test here.
         torch.save(tokenizer.__prepare_scriptable__(), tokenizer_path)
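For reference, the check that the _clip_tokenizer helper performs on each tokenizer boils down to roughly the following sketch (tokenizer built as in _load_tokenizer above; the expected ids are the first fixture entry):

    sample = "Hello World!, how are you?"
    expected = ["3306", "1002", "29325", "829", "631", "592", "286"]

    # batch input -> a list of token-id lists
    assert tokenizer([sample]) == [expected]
    # single-sentence input -> one token-id list
    assert tokenizer(sample) == expected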

torchtext/transforms.py

Lines changed: 42 additions & 18 deletions
@@ -321,31 +321,55 @@ class CLIPTokenizer(Module):
     (a bit like sentencepiece) so a word will be encoded differently whether it
     is at the beginning of the sentence (without space) or not.
 
-    :param encoder_json_path: Path to BPE encoder json file.
+    The below code snippet shows how to use the CLIP tokenizer with encoder and merges file
+    taken from the original paper implementation.
+
+    Example
+        >>> from torchtext.transforms import CLIPTokenizer
+        >>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
+        >>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
+        >>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
+        >>> tokenizer("the quick brown fox jumped over the lazy dog")
+
+    :param merges_path: Path to bpe merges file.
+    :type merges_path: str
+    :param encoder_json_path: Optional, path to BPE encoder json file. When specified, this is used
+        to infer num_merges.
     :type encoder_json_path: str
-    :param vocab_bpe_path: Path to bpe vocab file.
-    :type vocab_bpe_path: str
+    :param num_merges: Optional, number of merges to read from the bpe merges file.
+    :type num_merges: int
     """
 
     _seperator: torch.jit.Final[str]
 
-    def __init__(
-        self,
-        encoder_json_path: str,
-        vocab_bpe_path: str,
-    ):
+    def __init__(self, merges_path: str, encoder_json_path: Optional[str] = None, num_merges: Optional[int] = None):
         super().__init__()
         self._seperator = "\u0001"
-        # load bpe encoder
-        with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
-            bpe_encoder = json.load(f)
-        # load bpe vocab
-        with open(get_asset_local_path(vocab_bpe_path), "r", encoding="utf-8") as f:
-            bpe_vocab = f.read()
-        bpe_merge_ranks = {
-            self._seperator.join(merge_pair.split()): i
-            for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
-        }
+        # load bpe merges
+        with open(get_asset_local_path(merges_path), "r", encoding="utf-8") as f:
+            bpe_merges = f.read().split("\n")[1:]
+
+        if encoder_json_path:
+            # load bpe encoder
+            with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
+                bpe_encoder = json.load(f)
+            # 256 * 2 for each byte. For each byte we have ['a', 'a</w>']
+            # Additional 2 tokens for bos and eos
+            num_merges = len(bpe_encoder) - (256 * 2 + 2)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+        else:
+            num_merges = num_merges or len(bpe_merges)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+            bpe_vocab = list(bytes_to_unicode().values())
+            bpe_vocab = bpe_vocab + [v + "</w>" for v in bpe_vocab]
+            bpe_vocab.extend(["".join(merge_pair.split()) for merge_pair in bpe_merges[:num_merges]])
+            bpe_vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+            bpe_encoder = {v: i for i, v in enumerate(bpe_vocab)}
+
         # Caching is enabled in Eager mode
         self.bpe = CLIPEncoderPyBind(bpe_encoder, bpe_merge_ranks,
                                      self._seperator, bytes_to_unicode(), True)
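For illustration, the merge-only branch above is equivalent to rebuilding the token-to-id table with a small standalone helper. This is a sketch, not part of the commit: bytes_to_unicode is reproduced here from the standard GPT-2 byte-level BPE setup because the module-level helper is not shown in this diff, and build_clip_encoder is a hypothetical name.

    def bytes_to_unicode() -> dict:
        # Standard GPT-2 byte -> printable-unicode mapping used by byte-level BPE.
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, (chr(c) for c in cs)))


    def build_clip_encoder(merges_path: str, num_merges: int) -> dict:
        # Rebuild the token -> id table the way the merge-only branch does.
        with open(merges_path, "r", encoding="utf-8") as f:
            bpe_merges = f.read().split("\n")[1:]  # skip the header line of the merges file
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]  # 256 * 2 byte tokens: 'a' and 'a</w>'
        vocab.extend("".join(pair.split()) for pair in bpe_merges[:num_merges])  # one token per merge
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])  # bos / eos
        return {token: i for i, token in enumerate(vocab)}

With the CLIP assets, num_merges = 49152 - 256 - 2 = 48894, so the resulting encoder has 256 * 2 + 48894 + 2 = 49408 entries, which is consistent with the num_merges inference in the encoder-json branch (len(bpe_encoder) - (256 * 2 + 2)).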
