@@ -711,6 +711,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
711
711
if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15" :
712
712
# ref: https://huggingface.co/trillionlabs/Trillion-7B-preview
713
713
res = "trillion"
714
+ if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224" :
715
+ # ref: https://huggingface.co/inclusionAI/Ling-lite
716
+ res = "bailingmoe"
714
717
715
718
if res is None :
716
719
logger .warning ("\n " )
@@ -5133,6 +5136,108 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
5133
5136
return super ().modify_tensors (data_torch , name , bid )
5134
5137
5135
5138
5139
+ @Model .register ("BailingMoeForCausalLM" )
5140
+ class BailingMoeModel (Model ):
5141
+ model_arch = gguf .MODEL_ARCH .BAILINGMOE
5142
+
5143
+ def set_vocab (self ):
5144
+ self ._set_vocab_gpt2 ()
5145
+
5146
+ def set_gguf_parameters (self ):
5147
+ super ().set_gguf_parameters ()
5148
+ hparams = self .hparams
5149
+ if "head_dim" in hparams :
5150
+ rope_dim = hparams ["head_dim" ]
5151
+ else :
5152
+ rope_dim = hparams ["hidden_size" ] // hparams ["num_attention_heads" ]
5153
+
5154
+ self .gguf_writer .add_rope_dimension_count (rope_dim )
5155
+ self .gguf_writer .add_rope_scaling_type (gguf .RopeScalingType .NONE )
5156
+ self .gguf_writer .add_leading_dense_block_count (hparams ["first_k_dense_replace" ])
5157
+ self .gguf_writer .add_vocab_size (hparams ["vocab_size" ])
5158
+ self .gguf_writer .add_expert_feed_forward_length (hparams ["moe_intermediate_size" ])
5159
+ self .gguf_writer .add_expert_weights_scale (1.0 )
5160
+ self .gguf_writer .add_expert_count (hparams ["num_experts" ])
5161
+ self .gguf_writer .add_expert_shared_count (hparams ["num_shared_experts" ])
5162
+ self .gguf_writer .add_expert_weights_norm (hparams ["norm_topk_prob" ])
5163
+
5164
+ _experts : list [dict [str , Tensor ]] | None = None
5165
+
5166
+ @staticmethod
5167
+ def permute (weights : Tensor , n_head : int , n_head_kv : int | None ):
5168
+ if n_head_kv is not None and n_head != n_head_kv :
5169
+ n_head = n_head_kv
5170
+ return (weights .reshape (n_head , 2 , weights .shape [0 ] // n_head // 2 , * weights .shape [1 :])
5171
+ .swapaxes (1 , 2 )
5172
+ .reshape (weights .shape ))
5173
+
5174
+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
5175
+ n_head = self .hparams ["num_attention_heads" ]
5176
+ n_kv_head = self .hparams .get ("num_key_value_heads" )
5177
+ n_embd = self .hparams ["hidden_size" ]
5178
+ head_dim = self .hparams .get ("head_dim" , n_embd // n_head )
5179
+
5180
+ output_name = self .format_tensor_name (gguf .MODEL_TENSOR .OUTPUT )
5181
+
5182
+ if name .endswith ("attention.dense.weight" ):
5183
+ return [(self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_OUT , bid ), data_torch )]
5184
+ elif name .endswith ("query_key_value.weight" ):
5185
+ q , k , v = data_torch .split ([n_head * head_dim , n_kv_head * head_dim , n_kv_head * head_dim ], dim = - 2 )
5186
+
5187
+ return [
5188
+ (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_Q , bid ), BailingMoeModel .permute (q , n_head , n_head )),
5189
+ (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_K , bid ), BailingMoeModel .permute (k , n_head , n_kv_head )),
5190
+ (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_V , bid ), v )
5191
+ ]
5192
+ elif name .find ("mlp.experts" ) != - 1 :
5193
+ n_experts = self .hparams ["num_experts" ]
5194
+ assert bid is not None
5195
+
5196
+ tensors : list [tuple [str , Tensor ]] = []
5197
+
5198
+ if self ._experts is None :
5199
+ self ._experts = [{} for _ in range (self .block_count )]
5200
+
5201
+ self ._experts [bid ][name ] = data_torch
5202
+
5203
+ if len (self ._experts [bid ]) >= n_experts * 3 :
5204
+ # merge the experts into a single 3d tensor
5205
+ for w_name in ["down_proj" , "gate_proj" , "up_proj" ]:
5206
+ datas : list [Tensor ] = []
5207
+
5208
+ for xid in range (n_experts ):
5209
+ ename = f"model.layers.{ bid } .mlp.experts.{ xid } .{ w_name } .weight"
5210
+ datas .append (self ._experts [bid ][ename ])
5211
+ del self ._experts [bid ][ename ]
5212
+
5213
+ data_torch = torch .stack (datas , dim = 0 )
5214
+
5215
+ merged_name = f"model.layers.{ bid } .mlp.experts.{ w_name } .weight"
5216
+
5217
+ new_name = self .map_tensor_name (merged_name )
5218
+
5219
+ tensors .append ((new_name , data_torch ))
5220
+
5221
+ return tensors
5222
+
5223
+ new_name = self .map_tensor_name (name )
5224
+
5225
+ if new_name == output_name and self .hparams .get ("norm_head" ):
5226
+ data_torch = data_torch .float ()
5227
+ data_torch /= torch .norm (data_torch , p = 2 , dim = 0 , keepdim = True ) + 1e-7
5228
+
5229
+ return [(new_name , data_torch )]
5230
+
5231
+ def prepare_tensors (self ):
5232
+ super ().prepare_tensors ()
5233
+
5234
+ if self ._experts is not None :
5235
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
5236
+ experts = [k for d in self ._experts for k in d .keys ()]
5237
+ if len (experts ) > 0 :
5238
+ raise ValueError (f"Unprocessed experts: { experts } " )
5239
+
5240
+
5136
5241
@Model .register ("ChameleonForConditionalGeneration" )
5137
5242
@Model .register ("ChameleonForCausalLM" ) # obsolete
5138
5243
class ChameleonModel (Model ):
0 commit comments