From ccc3f79be531834a81ca2cd0de3f44c1b1444778 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Tue, 22 Feb 2022 12:17:25 -0800
Subject: [PATCH] Modify CLIPTokenizer to either infer number of merges from
 encoder json or take it in constructor (#1622)

---
 test/test_transforms.py | 94 +++++++++++++++++++++++++++++++++++------
 torchtext/transforms.py | 60 ++++++++++++++++++--------
 2 files changed, 122 insertions(+), 32 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 3324b5efe9..45c7e90c97 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -294,13 +294,22 @@ def test_gpt2_bpe_tokenizer_save_load_torchscript(self):
 
 
 class TestCLIPTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting):
+    def _load_tokenizer(self, init_using_merge_only: bool, test_scripting: bool):
         encoder_json = "clip_encoder.json"
         bpe_vocab = "clip_vocab.bpe"
-        tokenizer = transforms.CLIPTokenizer(
-            encoder_json_path=get_asset_path(encoder_json),
-            vocab_bpe_path=get_asset_path(bpe_vocab),
-        )
+        num_merges = (
+            49152 - 256 - 2
+        )  # https://github.com/mlfoundations/open_clip/blob/57b3e8ea6ad6bfc2974203945f8fd577e0659468/src/clip/tokenizer.py#L67
+        if init_using_merge_only:
+            tokenizer = transforms.CLIPTokenizer(
+                merges_path=get_asset_path(bpe_vocab),
+                num_merges=num_merges,
+            )
+        else:
+            tokenizer = transforms.CLIPTokenizer(
+                encoder_json_path=get_asset_path(encoder_json),
+                merges_path=get_asset_path(bpe_vocab),
+            )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer
@@ -308,12 +317,67 @@ def _load_tokenizer(self, test_scripting):
     def _clip_tokenizer(self, tokenizer):
         sample_texts = [
             "Hello World!, how are you?",
-            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>"
+            "<|startoftext|> the quick brown fox jumped over the lazy dog <|endoftext|>",
+            "Awaiting their due award... Photo by Frederick (FN) Noronha. Copyleft. Creative Commons 3.0. Non-commercial. Attribution. May be copied for non-commercial purposes. For other purposes, contact fn at goa-india.org",
         ]
         expected_token_ids = [
-            ['3306', '1002', '29325', '829', '631', '592', '286'],
-            ['49406', '518', '3712', '2866', '3240', '16901', '962', '518', '10753', '1929', '49407'],
+            ["3306", "1002", "29325", "829", "631", "592", "286"],
+            ["49406", "518", "3712", "2866", "3240", "16901", "962", "518", "10753", "1929", "49407"],
+            [
+                "14872",
+                "911",
+                "2887",
+                "2047",
+                "678",
+                "1125",
+                "638",
+                "18570",
+                "263",
+                "21763",
+                "264",
+                "1062",
+                "521",
+                "1429",
+                "269",
+                "11376",
+                "1823",
+                "269",
+                "4450",
+                "16653",
+                "274",
+                "269",
+                "271",
+                "269",
+                "3353",
+                "268",
+                "6287",
+                "269",
+                "24624",
+                "740",
+                "269",
+                "1270",
+                "655",
+                "36770",
+                "556",
+                "3353",
+                "268",
+                "6287",
+                "22020",
+                "269",
+                "556",
+                "1010",
+                "22020",
+                "267",
+                "3523",
+                "21763",
+                "536",
+                "14399",
+                "268",
+                "1762",
+                "269",
+                "5593",
+            ],
         ]
 
         # test batch of sentences
@@ -325,22 +389,24 @@ def _clip_tokenizer(self, tokenizer):
 
     def test_clip_tokenizer(self):
         """test tokenization on single sentence input as well as batch on sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=False))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=False))
 
     def test_clip_tokenizer_jit(self):
         """test tokenization with scripting on single sentence input as well as batch on sentences"""
-        self._clip_tokenizer(self._load_tokenizer(test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=True, test_scripting=True))
+        self._clip_tokenizer(self._load_tokenizer(init_using_merge_only=False, test_scripting=True))
 
     def test_clip_tokenizer_save_load_pybind(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_pybind.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_pybind.pt")
         torch.save(tokenizer, tokenizer_path)
         loaded_tokenizer = torch.load(tokenizer_path)
         self._clip_tokenizer(loaded_tokenizer)
 
     def test_clip_tokenizer_save_load_torchscript(self):
-        tokenizer = self._load_tokenizer(test_scripting=False)
-        tokenizer_path = os.path.join(self.test_dir, 'gpt2_tokenizer_torchscript.pt')
+        tokenizer = self._load_tokenizer(init_using_merge_only=True, test_scripting=False)
+        tokenizer_path = os.path.join(self.test_dir, "gpt2_tokenizer_torchscript.pt")
         # Call the __prepare_scriptable__() func and convert the building block to the torchbind version
         # We do not expect users to use the torchbind version in eager mode, but we still need a CI test here.
         torch.save(tokenizer.__prepare_scriptable__(), tokenizer_path)
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 3a91af3963..1f9365874d 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -321,31 +321,55 @@ class CLIPTokenizer(Module):
     (a bit like sentencepiece) so a word will be encoded differently whether it is at the beginning
     of the sentence (without space) or not.
 
-    :param encoder_json_path: Path to BPE encoder json file.
+    The code snippet below shows how to use the CLIP tokenizer with the encoder and merges files
+    taken from the original paper implementation.
+
+    Example
+        >>> from torchtext.transforms import CLIPTokenizer
+        >>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
+        >>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
+        >>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
+        >>> tokenizer("the quick brown fox jumped over the lazy dog")
+
+    :param merges_path: Path to the BPE merges file.
+    :type merges_path: str
+    :param encoder_json_path: Optional, path to the BPE encoder json file. When specified, this is used
+        to infer num_merges.
     :type encoder_json_path: str
-    :param vocab_bpe_path: Path to bpe vocab file.
-    :type vocab_bpe_path: str
+    :param num_merges: Optional, number of merges to read from the BPE merges file.
+    :type num_merges: int
     """
 
     _seperator: torch.jit.Final[str]
 
-    def __init__(
-        self,
-        encoder_json_path: str,
-        vocab_bpe_path: str,
-    ):
+    def __init__(self, merges_path: str, encoder_json_path: Optional[str] = None, num_merges: Optional[int] = None):
         super().__init__()
         self._seperator = "\u0001"
-        # load bpe encoder
-        with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
-            bpe_encoder = json.load(f)
-        # load bpe vocab
-        with open(get_asset_local_path(vocab_bpe_path), "r", encoding="utf-8") as f:
-            bpe_vocab = f.read()
-        bpe_merge_ranks = {
-            self._seperator.join(merge_pair.split()): i
-            for i, merge_pair in enumerate(bpe_vocab.split("\n")[1:-1])
-        }
+        # load bpe merges
+        with open(get_asset_local_path(merges_path), "r", encoding="utf-8") as f:
+            bpe_merges = f.read().split("\n")[1:]
+
+        if encoder_json_path:
+            # load bpe encoder
+            with open(get_asset_local_path(encoder_json_path), "r", encoding="utf-8") as f:
+                bpe_encoder = json.load(f)
+            # 256 * 2 for each byte: for each byte we have ['a', 'a</w>']
+            # Additional 2 tokens for bos and eos
+            num_merges = len(bpe_encoder) - (256 * 2 + 2)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+        else:
+            num_merges = num_merges or len(bpe_merges)
+            bpe_merge_ranks = {
+                self._seperator.join(merge_pair.split()): i for i, merge_pair in enumerate(bpe_merges[:num_merges])
+            }
+            bpe_vocab = list(bytes_to_unicode().values())
+            bpe_vocab = bpe_vocab + [v + "</w>" for v in bpe_vocab]
+            bpe_vocab.extend(["".join(merge_pair.split()) for merge_pair in bpe_merges[:num_merges]])
+            bpe_vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+            bpe_encoder = {v: i for i, v in enumerate(bpe_vocab)}
+
         # Caching is enabled in Eager mode
         self.bpe = CLIPEncoderPyBind(bpe_encoder, bpe_merge_ranks, self._seperator, bytes_to_unicode(), True)
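
--
A minimal usage sketch of the two construction paths this change enables,
based on the docstring example above. It assumes the two asset URLs from the
docstring are reachable; the explicit num_merges value mirrors the open_clip
constant referenced in the test.

    from torchtext.transforms import CLIPTokenizer

    MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
    ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"

    # Path 1: pass the encoder json; num_merges is inferred as
    # len(encoder) - (256 * 2 + 2): the vocab minus the 256 byte tokens,
    # their 256 "</w>" variants, and the two special tokens.
    tok_inferred = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)

    # Path 2: pass num_merges explicitly; the encoder dict is rebuilt from
    # the merges file alone (bytes, byte + "</w>", merges, special tokens).
    tok_explicit = CLIPTokenizer(merges_path=MERGES_FILE, num_merges=49152 - 256 - 2)

    # Per the updated test, both paths yield the same token ids (as strings).
    assert tok_inferred("the quick brown fox") == tok_explicit("the quick brown fox")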