@@ -66,8 +66,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +76,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
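For illustration (not part of the patch, names hypothetical): block_count and tensor_map are now bare annotations on ModelBase, so no class attribute exists until a subclass assigns one. A minimal sketch of that pattern:

class Base:
    # annotation only; Base itself gets no attribute, so reading it
    # before a subclass assigns it raises AttributeError
    block_count: int

class Child(Base):
    def __init__(self):
        self.block_count = 4  # "subclasses should initialize this!"

print(Child().block_count)   # 4
# print(Base().block_count)  # AttributeError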
@@ -113,8 +115,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -418,14 +418,7 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
     @staticmethod
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-            architectures = hparams.get("architectures")
-            if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
-            if architectures is not None:
-                # preserve "architectures" from root level config
-                hparams["architectures"] = architectures
-            return hparams
+            return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -454,6 +447,16 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
     @classmethod
     def __init_subclass__(cls):
         # can't use an abstract property, because overriding it without type errors
@@ -1078,8 +1081,12 @@ def __init__(self, *args, **kwargs):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

         # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        self.block_count = 512  # vision models are small, this "ought to be enough for anybody"
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         if "vision_config" not in self.hparams:
@@ -1726,20 +1733,20 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams)
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+        if arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
         # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+        if arch == "LlavaForConditionalGeneration" \
             and self.hparams.get("model_type") == "mistral":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
@@ -5805,6 +5812,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n


+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
@@ -5857,16 +5877,15 @@ def main() -> None:

     logger.info(f"Loading model: {dir_model.name}")

-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: