 import torch._inductor.config
 import torch.distributed as dist
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-from torchchat.distributed.logging_utils import SingletonLogger
-
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    GPUMemoryMonitor,
     init_distributed,
+    GPUMemoryMonitor,
 )
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -45,6 +36,15 @@
 from torchchat.utils.quantize import quantize_model
 
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -189,19 +189,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            "math": torch.nn.attention.SDPBackend.MATH,
-            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (
-            args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"
-        ):
-            print(
-                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
-            )
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -250,29 +246,13 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
-    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
-        # special handling for llama-3.2-mm
-        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
-            try:
-                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
-
-                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
-                self.is_llama_3_2_mm = True
-                self.is_tiktoken = False
-                self.is_sentencepiece = False
-                self.is_hf_tokenizer = False
-                return
-            except:
-                pass
-
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -284,7 +264,6 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -296,15 +275,13 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
-        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -319,32 +296,20 @@ def validate_model(
         if model is None:
             return
 
-        if (
-            sum(
-                [
-                    self.is_tiktoken,
-                    self.is_hf_tokenizer,
-                    self.is_sentencepiece,
-                    self.is_llama_3_2_mm,
-                ]
-            )
-            != 1
-        ):
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
-        is_llama_3_2_mm = self.is_llama_3_2_mm
-
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
+
         if (
-            (is_tiktoken and not use_tiktoken)
-            or (is_hf_tokenizer and not use_hf_tokenizer)
-            or (is_sentencepiece and not use_other_tokenizer)
-            or (is_llama_3_2_mm and not use_other_tokenizer)
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -542,7 +507,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
-
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -600,7 +564,6 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -638,7 +601,6 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -713,9 +675,7 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(
-        f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
-    )
+    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
 
     # Model-level config
     if builder_args.params_table:
@@ -726,16 +686,20 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    # TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import load_model_weights
+    #TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import (
+        load_model_weights,
+    )
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
-    )
+        "cuda",
+        (pp_degree, tp_degree),
+        mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -764,13 +728,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(
-            model,
-            builder_args.distribution_path,
-            device,
-            config,
-            builder_args.chpt_from,
-        )
+        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -784,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024
+    seqlen_prefill = 1024 
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
@@ -836,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"