From d8443a74c8d4ad34556ac0b74f97d5fc510f1520 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 27 Sep 2024 12:56:57 -0600
Subject: [PATCH 1/6] feat(tokenizer): Add an abstract base class for additional tokenizer support

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 tokenizer/base.py     | 32 ++++++++++++++++++++++++++++++++
 tokenizer/tiktoken.py |  4 +++-
 2 files changed, 35 insertions(+), 1 deletion(-)
 create mode 100644 tokenizer/base.py

diff --git a/tokenizer/base.py b/tokenizer/base.py
new file mode 100644
index 000000000..75998b32b
--- /dev/null
+++ b/tokenizer/base.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Abstract base class for all tokenizer classes in Python, matching the C++ interface.
+"""
+
+# Standard
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class TokenizerBase(ABC):
+    __doc__ = __doc__
+
+    @abstractmethod
+    def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
+        """Encode the given string and optionally include bos/eos tokens"""
+
+    @abstractmethod
+    def decode(self, ids: List[int]) -> str:
+        """Decode the given token ids into a string"""
+
+    @abstractmethod
+    def bos_id(self) -> int:
+        """The id of the beginning-of-string token"""
+
+    @abstractmethod
+    def eos_id(self) -> int:
+        """The id of the end-of-string token"""
diff --git a/tokenizer/tiktoken.py b/tokenizer/tiktoken.py
index 9e9fe2264..30eb98624 100644
--- a/tokenizer/tiktoken.py
+++ b/tokenizer/tiktoken.py
@@ -23,6 +23,8 @@
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 
+from .base import TokenizerBase
+
 logger = getLogger(__name__)
 
 
@@ -38,7 +40,7 @@ class Message(TypedDict):
 Dialog = Sequence[Message]
 
 
-class Tokenizer:
+class Tokenizer(TokenizerBase):
     """
     tokenizing and encoding/decoding text using the Tiktoken tokenizer.
     """

From 2483486336b32770276c6b16bb5e02ff1959c6c4 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 27 Sep 2024 12:58:38 -0600
Subject: [PATCH 2/6] feat(tokenizers): Add a Python impl of the Tokenizer interface using tokenizers

This allows all HF tokenizers to be supported in the Python layer. Offering
similar compatibility at the C++ layer will need significant work.

Signed-off-by: Gabe Goodhart
---
 tokenizer/tokenizers.py | 64 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)
 create mode 100644 tokenizer/tokenizers.py

diff --git a/tokenizer/tokenizers.py b/tokenizer/tokenizers.py
new file mode 100644
index 000000000..c42c9987a
--- /dev/null
+++ b/tokenizer/tokenizers.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Standard
+from typing import List
+import json
+
+# Third Party
+from tokenizers import Tokenizer
+
+# Local
+from .base import TokenizerBase
+
+
+class TokenizersTokenizer(TokenizerBase):
+    """
+    Wrapper around the `tokenizers` library for API compatibility
+    """
+
+    def __init__(self, file_path: str):
+        self._tokenizer = Tokenizer.from_file(file_path)
+        # The BOS and EOS tokens are not easily visible from the tokenizer
+        # object itself, so we extract them at construction with a sample call
+        self._bos_token = self._tokenizer.encode("Test", add_special_tokens=True).ids[0]
+        # There is no explicit EOS token in many tokenizers, so we look for a
+        # single special token that most resembles the EOS token.
+        self._eos_token = None
+        tok_content = json.loads(self._tokenizer.to_str())
+        end_toks = [
+            tok for tok in tok_content['added_tokens']
+            if tok["special"] and "end" in tok["content"]
+        ]
+        assert end_toks, "Unable to find an EOS token in the added tokens"
+        if len(end_toks) > 1:
+            end_text_toks = [
+                tok for tok in end_toks if "text" in tok["content"]
+            ]
+            if len(end_text_toks) == 1:
+                self._eos_token = end_text_toks[0]["id"]
+        assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
+
+    def encode(
+        self,
+        s: str,
+        *,
+        bos: bool = False,
+        eos: bool = False,
+    ) -> List[int]:
+        res = self._tokenizer.encode(s, add_special_tokens=bos).ids
+        if eos and (not res or res[-1] != self._eos_token):
+            res.append(self._eos_token)
+        return res
+
+    def decode(self, ids: List[int]) -> str:
+        return self._tokenizer.decode(ids)
+
+    def bos_id(self) -> int:
+        return self._bos_token
+
+    def eos_id(self) -> int:
+        return self._eos_token

From 5c4101582dd864e39a43e4464fa02fd9735141d7 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 27 Sep 2024 16:17:19 -0600
Subject: [PATCH 3/6] feat(builder): Add support for using the TokenizersTokenizer in builder

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 torchchat/cli/builder.py | 40 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 511cf1f35..5d1083771 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -204,6 +204,7 @@ class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
+    is_tokenizers: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
@@ -213,6 +214,7 @@ def __post_init__(self):
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
             self.is_tiktoken = True
             self.is_sentencepiece = False
+            self.is_tokenizers = False
             return
         except:
             pass
@@ -223,12 +225,25 @@ def __post_init__(self):
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
             self.is_tiktoken = False
             self.is_sentencepiece = True
+            self.is_tokenizers = False
+            return
+        except:
+            pass
+
+        try:
+            from tokenizer.tokenizers import TokenizersTokenizer
+
+            self.t = TokenizersTokenizer(str(self.tokenizer_path))
+            self.is_tiktoken = False
+            self.is_sentencepiece = False
+            self.is_tokenizers = True
             return
         except:
             pass
 
         self.is_tiktoken = False
         self.is_sentencepiece = False
+        self.is_tokenizers = False
         self.t = None
         return
 
@@ -240,16 +255,27 @@ def validate_model(
         if model is None:
             return
 
-        if self.is_tiktoken == self.is_sentencepiece:
+        if len(list(filter(lambda x: x, [self.is_tiktoken, self.is_tokenizers, self.is_sentencepiece]))) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
+        is_tokenizers = self.is_tokenizers
         use_tiktoken = model.config.use_tiktoken
+        use_tokenizers = model.config.use_tokenizers
+        use_sentencepiece = not (use_tiktoken or use_tokenizers)
 
-        if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
+        if (
+            (is_tiktoken and not use_tiktoken) or
+            (is_tokenizers and not use_tokenizers) or
+            (is_sentencepiece and not use_sentencepiece)
+        ):
             raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
+                "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
+                    tokenizer_setting_to_name(use_tiktoken, use_tokenizers),
+                    tokenizer_setting_to_name(is_tiktoken, is_tokenizers),
+                    model_description,
+                )
             )
 
         return
@@ -605,5 +631,9 @@ def _initialize_model(
     return model
 
 
-def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
-    return "TikToken" if tiktoken else "SentencePiece"
+def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
+    if tiktoken:
+        return "TikToken"
+    if tokenizers:
+        return "Tokenizers"
+    return "SentencePiece"
\ No newline at end of file

From 27d27087be5cf28fddea2f241b67f828d87f6d39 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Tue, 1 Oct 2024 16:44:31 -0600
Subject: [PATCH 4/6] feat(tokenizers): Add and plumb the option to use the "tokenizers" tokenizer

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 torchchat/model.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torchchat/model.py b/torchchat/model.py
index 7868b6593..b6d8232a2 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -270,7 +270,9 @@ class TransformerArgs:
     norm_eps: float = 1e-5
     multiple_of: int = 256
     ffn_dim_multiplier: Optional[int] = None
+    # Select the desired tokenizer. Defaults to SentencePiece.
     use_tiktoken: bool = False
+    use_tokenizers: bool = False
     max_seq_length: int = 8192
     rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
     model_type: ModelType
     transformer_args: Dict[str, Dict[str, Any]]
    use_tiktoken: bool
+    use_tokenizers: bool
 
     def __init__(
         self,
         transformer_args: Dict[str, Dict[str, Any]],
         model_type: ModelType = ModelType.TextOnly,
         use_tiktoken: bool = False,
+        use_tokenizers: bool = False,
     ) -> None:
         self._sanity_check(transformer_args, model_type)
 
@@ -341,6 +345,7 @@ def __init__(
 
         # Model-level attributes
         self.use_tiktoken = use_tiktoken
+        self.use_tokenizers = use_tokenizers
 
     def _sanity_check(
         self,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
             }
 
         use_tiktoken = loaded_params.get("use_tiktoken", False)
-        return cls(transformer_args, model_type, use_tiktoken)
+        use_tokenizers = loaded_params.get("use_tokenizers", False)
+        return cls(transformer_args, model_type, use_tiktoken, use_tokenizers)
 
     @classmethod
     def from_table(cls, name: str):

From 9d9a4a7fbebfc3784fc137a5c4e3f72a3d00a0e3 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 27 Sep 2024 16:37:30 -0600
Subject: [PATCH 5/6] fix(tokenizers): Fix how bos/eos tokens are parsed from tokenizers (lib)

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 tokenizer/tokenizers.py | 76 ++++++++++++++++++++++++++++++-----------
 1 file changed, 52 insertions(+), 24 deletions(-)

diff --git a/tokenizer/tokenizers.py b/tokenizer/tokenizers.py
index c42c9987a..1eb300e60 100644
--- a/tokenizer/tokenizers.py
+++ b/tokenizer/tokenizers.py
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List
+from typing import List, Optional
 import json
+import os
 
 # Third Party
 from tokenizers import Tokenizer
@@ -21,26 +22,53 @@ class TokenizersTokenizer(TokenizerBase):
     """
 
     def __init__(self, file_path: str):
-        self._tokenizer = Tokenizer.from_file(file_path)
-        # The BOS and EOS tokens are not easily visible from the tokenizer
-        # object itself, so we extract them at construction with a sample call
-        self._bos_token = self._tokenizer.encode("Test", add_special_tokens=True).ids[0]
-        # There is no explicit EOS token in many tokenizers, so we look for a
-        # single special token that most resembles the EOS token.
-        self._eos_token = None
-        tok_content = json.loads(self._tokenizer.to_str())
-        end_toks = [
-            tok for tok in tok_content['added_tokens']
-            if tok["special"] and "end" in tok["content"]
-        ]
-        assert end_toks, "Unable to find an EOS token in the added tokens"
-        if len(end_toks) > 1:
-            end_text_toks = [
-                tok for tok in end_toks if "text" in tok["content"]
+        # If the path is a directory, look for "tokenizer.json", which is
+        # standard for transformers checkpoints, and also look for the
+        # "tokenizer_config.json" file to parse eos/bos tokens
+        if os.path.isdir(file_path):
+            tokenizer_path = os.path.join(file_path, "tokenizer.json")
+            tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
+        else:
+            tokenizer_path = file_path
+            tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
+        if not os.path.isfile(tokenizer_config_path):
+            tokenizer_config_path = None
+
+        # Load the tokenizer itself
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+
+        # If available, parse bos/eos tokens from the tokenizer config
+        self._bos_id, self._eos_id = None, None
+        if tokenizer_config_path is not None:
+            with open(tokenizer_config_path, "r") as handle:
+                tok_config = json.load(handle)
+            bos_token = tok_config.get("bos_token")
+            eos_token = tok_config.get("eos_token")
+            if bos_token is not None:
+                self._bos_id = self._tokenizer.token_to_id(bos_token)
+            if eos_token is not None:
+                self._eos_id = self._tokenizer.token_to_id(eos_token)
+
+        # If no eos/bos tokens found, go looking for them!
+        if None in [self._bos_id, self._eos_id]:
+            tok_content = json.loads(self._tokenizer.to_str())
+            if self._bos_id is None:
+                self._bos_id = self._look_for_special_token(tok_content["added_tokens"], ["begin", "text"])
+            if self._eos_id is None:
+                self._eos_id = self._look_for_special_token(tok_content["added_tokens"], ["end", "text"])
+
+        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"
+
+    @staticmethod
+    def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
+        candidate_toks = added_tokens
+        for search_str in search_strs:
+            candidate_toks = [
+                tok for tok in candidate_toks
+                if tok["special"] and search_str in tok["content"]
             ]
-            if len(end_text_toks) == 1:
-                self._eos_token = end_text_toks[0]["id"]
-        assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
+        if len(candidate_toks) == 1:
+            return candidate_toks[0]["id"]
 
     def encode(
         self,
@@ -47,18 +75,18 @@ def encode(
         s: str,
         *,
         bos: bool = False,
         eos: bool = False,
     ) -> List[int]:
         res = self._tokenizer.encode(s, add_special_tokens=bos).ids
-        if eos and (not res or res[-1] != self._eos_token):
-            res.append(self._eos_token)
+        if eos and (not res or res[-1] != self._eos_id):
+            res.append(self._eos_id)
         return res
 
     def decode(self, ids: List[int]) -> str:
         return self._tokenizer.decode(ids)
 
     def bos_id(self) -> int:
-        return self._bos_token
+        return self._bos_id
 
     def eos_id(self) -> int:
-        return self._eos_token
+        return self._eos_id

From 4a20f69976d820139b44655c99a4bdac8ec674ba Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Thu, 24 Oct 2024 16:28:32 -0600
Subject: [PATCH 6/6] fix(hf_tokenizer): Rename to HFTokenizer and corresponding flags

https://github.com/pytorch/torchchat/issues/1251

Branch: TokenizersTokenizer-1251

Co-Authored-By: jackkhuu@fb.com

Signed-off-by: Gabe Goodhart
---
 tokenizer/{tokenizers.py => hf_tokenizer.py} |  4 ++--
 torchchat/cli/builder.py                     | 28 ++++++++++----------
 torchchat/model.py                           | 12 ++++-----
 3 files changed, 22 insertions(+), 22 deletions(-)
 rename tokenizer/{tokenizers.py => hf_tokenizer.py} (96%)

diff --git a/tokenizer/tokenizers.py b/tokenizer/hf_tokenizer.py
similarity index 96%
rename from tokenizer/tokenizers.py
rename to tokenizer/hf_tokenizer.py
index 1eb300e60..7ad5807d1 100644
--- a/tokenizer/tokenizers.py
+++ b/tokenizer/hf_tokenizer.py
@@ -16,9 +16,9 @@
 from .base import TokenizerBase
 
 
-class TokenizersTokenizer(TokenizerBase):
+class HFTokenizer(TokenizerBase):
     """
-    Wrapper around the `tokenizers` library for API compatibility
+    Wrapper around the Hugging Face `tokenizers` library for API compatibility
     """
 
     def __init__(self, file_path: str):
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 5d1083771..0fd9c58b9 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -204,7 +204,7 @@ class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
-    is_tokenizers: bool = False
+    is_hf_tokenizer: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
@@ -214,7 +214,7 @@ def __post_init__(self):
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
             self.is_tiktoken = True
             self.is_sentencepiece = False
-            self.is_tokenizers = False
+            self.is_hf_tokenizer = False
             return
         except:
             pass
@@ -225,25 +225,25 @@ def __post_init__(self):
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
             self.is_tiktoken = False
             self.is_sentencepiece = True
-            self.is_tokenizers = False
+            self.is_hf_tokenizer = False
             return
         except:
             pass
 
         try:
-            from tokenizer.tokenizers import TokenizersTokenizer
+            from tokenizer.hf_tokenizer import HFTokenizer
 
-            self.t = TokenizersTokenizer(str(self.tokenizer_path))
+            self.t = HFTokenizer(str(self.tokenizer_path))
             self.is_tiktoken = False
             self.is_sentencepiece = False
-            self.is_tokenizers = True
+            self.is_hf_tokenizer = True
             return
         except:
             pass
 
         self.is_tiktoken = False
         self.is_sentencepiece = False
-        self.is_tokenizers = False
+        self.is_hf_tokenizer = False
         self.t = None
         return
 
@@ -255,25 +255,25 @@ def validate_model(
         if model is None:
             return
 
-        if len(list(filter(lambda x: x, [self.is_tiktoken, self.is_tokenizers, self.is_sentencepiece]))) != 1:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
-        is_tokenizers = self.is_tokenizers
+        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
-        use_tokenizers = model.config.use_tokenizers
-        use_sentencepiece = not (use_tiktoken or use_tokenizers)
+        use_hf_tokenizer = model.config.use_hf_tokenizer
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
             (is_tiktoken and not use_tiktoken) or
-            (is_tokenizers and not use_tokenizers) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
             (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
-                    tokenizer_setting_to_name(use_tiktoken, use_tokenizers),
-                    tokenizer_setting_to_name(is_tiktoken, is_tokenizers),
+                    tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
+                    tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
                     model_description,
                 )
             )
diff --git a/torchchat/model.py b/torchchat/model.py
index b6d8232a2..11f3dc167 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -272,7 +272,7 @@ class TransformerArgs:
     ffn_dim_multiplier: Optional[int] = None
     # Select the desired tokenizer. Defaults to SentencePiece.
     use_tiktoken: bool = False
-    use_tokenizers: bool = False
+    use_hf_tokenizer: bool = False
     max_seq_length: int = 8192
     rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
@@ -329,14 +329,14 @@ class ModelArgs:
     model_type: ModelType
     transformer_args: Dict[str, Dict[str, Any]]
     use_tiktoken: bool
-    use_tokenizers: bool
+    use_hf_tokenizer: bool
 
     def __init__(
         self,
         transformer_args: Dict[str, Dict[str, Any]],
         model_type: ModelType = ModelType.TextOnly,
         use_tiktoken: bool = False,
-        use_tokenizers: bool = False,
+        use_hf_tokenizer: bool = False,
     ) -> None:
         self._sanity_check(transformer_args, model_type)
 
@@ -345,7 +345,7 @@ def __init__(
 
         # Model-level attributes
         self.use_tiktoken = use_tiktoken
-        self.use_tokenizers = use_tokenizers
+        self.use_hf_tokenizer = use_hf_tokenizer
 
     def _sanity_check(
         self,
@@ -372,8 +372,8 @@ def from_params(cls, params_path):
             }
 
         use_tiktoken = loaded_params.get("use_tiktoken", False)
-        use_tokenizers = loaded_params.get("use_tokenizers", False)
-        return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
+        use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
+        return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
 
     @classmethod
    def from_table(cls, name: str):
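
Usage sketch: the end state of this series is that HFTokenizer satisfies the
same TokenizerBase interface as the TikToken and SentencePiece wrappers, so
callers can drive any of the three without knowing which backend was
selected. A minimal, illustrative example follows; the checkpoint path is
hypothetical, and it assumes the `tokenizer` package is importable and that
the directory contains a Hugging Face `tokenizer.json` (and, optionally, a
`tokenizer_config.json` with explicit bos/eos token strings):

    # Local
    from tokenizer.hf_tokenizer import HFTokenizer

    # Construction resolves tokenizer.json, parses bos/eos ids from
    # tokenizer_config.json when present, and otherwise falls back to
    # searching the added special tokens
    tok = HFTokenizer("/path/to/checkpoint")

    # encode() appends the eos id when eos=True and the tokenizer's own
    # post-processor did not already emit it
    ids = tok.encode("Hello, world", bos=True, eos=True)
    assert ids[-1] == tok.eos_id()

    # decode() round-trips the ids back to text
    print(tok.decode(ids))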