88import os
99import sys
1010from dataclasses import dataclass
11+ from enum import Enum
1112from pathlib import Path
1213from typing import Any , Dict , Optional , Tuple , Union
1314
@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
237238 speculative_builder_args .pte_path = None
238239 return speculative_builder_args
239240
class TokenizerType(Enum):
    """Which tokenizer backend (if any) was successfully loaded.

    NONE means no backend could open the tokenizer file; the remaining
    members mirror the three loaders tried by ``TokenizerArgs.__post_init__``.
    """

    NONE = 0
    TIKTOKEN = 1
    SENTENCEPIECE = 2
    HF_TOKENIZER = 3
240246
241247@dataclass
242248class TokenizerArgs :
243249 tokenizer_path : Optional [Union [Path , str ]] = None
244- is_sentencepiece : bool = False
245- is_tiktoken : bool = False
246- is_hf_tokenizer : bool = False
250+ tokenizer_type : TokenizerType = TokenizerType .NONE
247251 t : Optional [Any ] = None
248252
249253 def __post_init__ (self ):
250254 try :
251255 from tokenizer .tiktoken import Tokenizer as TiktokenTokenizer
252256
253257 self .t = TiktokenTokenizer (model_path = str (self .tokenizer_path ))
254- self .is_tiktoken = True
255- self .is_sentencepiece = False
256- self .is_hf_tokenizer = False
258+ self .tokenizer_type = TokenizerType .TIKTOKEN
257259 return
258260 except :
259261 pass
@@ -262,9 +264,7 @@ def __post_init__(self):
262264 from sentencepiece import SentencePieceProcessor
263265
264266 self .t = SentencePieceProcessor (model_file = str (self .tokenizer_path ))
265- self .is_tiktoken = False
266- self .is_sentencepiece = True
267- self .is_hf_tokenizer = False
267+ self .tokenizer_type = TokenizerType .SENTENCEPIECE
268268 return
269269 except :
270270 pass
@@ -273,18 +273,19 @@ def __post_init__(self):
273273 from tokenizer .hf_tokenizer import HFTokenizer
274274
275275 self .t = HFTokenizer (str (self .tokenizer_path ))
276- self .is_tiktoken = False
277- self .is_sentencepiece = False
278- self .is_hf_tokenizer = True
276+ self .tokenizer_type = TokenizerType .HF_TOKENIZER
279277 return
280278 except :
281279 pass
282280
283- self .is_tiktoken = False
284- self .is_sentencepiece = False
285- self .is_hf_tokenizer = False
286- self .t = None
287- return
281+ def is_tiktoken (self ) -> bool :
282+ return self .tokenizer_type == TokenizerType .TIKTOKEN
283+
284+ def is_sentencepiece (self ) -> bool :
285+ return self .tokenizer_type == TokenizerType .SENTENCEPIECE
286+
287+ def is_hf_tokenizer (self ) -> bool :
288+ return self .tokenizer_type == TokenizerType .HF_TOKENIZER
288289
289290 def validate_model (
290291 self ,
@@ -294,12 +295,13 @@ def validate_model(
294295 if model is None :
295296 return
296297
297- if sum ([ self .is_tiktoken , self . is_hf_tokenizer , self . is_sentencepiece ]) != 1 :
298+ if self .tokenizer_type == TokenizerType . NONE :
298299 raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
299300
300- is_tiktoken = self .is_tiktoken
301- is_sentencepiece = self .is_sentencepiece
302- is_hf_tokenizer = self .is_hf_tokenizer
301+ is_tiktoken = self .is_tiktoken ()
302+ is_sentencepiece = self .is_sentencepiece ()
303+ is_hf_tokenizer = self .is_hf_tokenizer ()
304+
303305 use_tiktoken = model .config .use_tiktoken
304306 use_hf_tokenizer = model .config .use_hf_tokenizer
305307 use_sentencepiece = not (use_tiktoken or use_hf_tokenizer )
@@ -651,13 +653,13 @@ def do_nothing(max_batch_size, max_seq_length):
651653 model = torch .load (builder_args .snapshot_path , weights_only = False )
652654 except Exception :
653655 raise RuntimeError (f"Failed to load torchchat snapshot { builder_args .snapshot_path } " )
654- # _active_backend() does not allow DSO & AOTI to be true.
656+ # _active_backend() does not allow DSO & AOTI to be true.
655657 # Choose either.
656658 from torchchat .utils .build_utils import set_backend
657659 set_backend (dso = True , pte = False , aoti_package = False )
658660 if (model .config != config ):
659661 raise RuntimeError ("loaded model architecture mismatch" )
660- ##
662+ ##
661663 ## import all libraries with custom kernels ans custom operators
662664 ## that quantize may be pulling in
663665 ##
def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
    """Map the tokenizer flag pair to a display name.

    Precedence follows the probe order: tiktoken wins over the HF
    "Tokenizers" flag, and SentencePiece is the fallback when neither
    flag is set.
    """
    if tiktoken:
        return "TikToken"
    return "Tokenizers" if tokenizers else "SentencePiece"
0 commit comments