@@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
         calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
         if calc_ff == n_ff:
             return n_mult
-    return 1
+    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

 @dataclass
 class Params:
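For context, find_n_mult looks for an n_mult such that rounding (8 * n_embd) // 3 up to a multiple of n_mult reproduces n_ff, and the change above turns a failed search into an error instead of silently returning 1. A minimal sanity check of the formula, assuming the usual LLaMA-7B sizes (n_embd=4096, n_ff=11008) and n_mult=256; these values are for illustration and are not taken from this diff:

n_embd, n_ff, n_mult = 4096, 11008, 256                            # assumed LLaMA-7B sizes, illustration only
calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult    # rounds 10922 up to a multiple of 256
assert calc_ff == n_ff                                             # 43 * 256 == 11008, so this n_mult passes the check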
@@ -321,6 +321,10 @@ def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
     def permute(self, n_head: int) -> 'Tensor': ...
     @abstractmethod
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
+    @abstractmethod
+    def part(self, n_part: int) -> 'UnquantizedTensor': ...
+    @abstractmethod
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...


@@ -345,6 +349,14 @@ def astype(self, data_type: DataType) -> Tensor:
     def to_ggml(self) -> 'UnquantizedTensor':
         return self

+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))
+
+    def part(self, n_part: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
+
     def permute(self, n_head: int) -> 'UnquantizedTensor':
         return UnquantizedTensor(permute(self.ndarray, n_head))

@@ -642,6 +654,19 @@ def load() -> Tensor:
         return lazy_tensor.load().permute(n_head)
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)

+def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().permute_part(n_part, n_head)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+
+def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().part(n_part)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)

 def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
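Both wrappers keep the conversion lazy: the slice only happens when load() runs, while shape[0] is divided by three up front so any metadata derived from the LazyTensor already reflects the per-part size. A toy sketch of that pattern, using a hypothetical stand-in container rather than the script's real LazyTensor:

from dataclasses import dataclass
from typing import Callable, List

import numpy as np


@dataclass
class ToyLazy:
    # Hypothetical stand-in for a lazy tensor: a thunk plus the shape it will have once loaded.
    load: Callable[[], np.ndarray]
    shape: List[int]


def toy_part_lazy(lazy: ToyLazy, n_part: int) -> ToyLazy:
    # Shape metadata is fixed up eagerly; the actual slice only happens when load() is called.
    s = lazy.shape.copy()
    s[0] = s[0] // 3
    return ToyLazy(lambda: np.split(lazy.load(), 3, axis=0)[n_part], s)


fused = ToyLazy(lambda: np.zeros((12, 4)), [12, 4])
q_part = toy_part_lazy(fused, 0)
assert q_part.shape == [4, 4]
assert q_part.load().shape == (4, 4)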
@@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out["output.weight"] = model["lm_head.weight"]

     for i in itertools.count():
-        if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
+        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
+        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
+            out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+        else:
             break
-        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
-        out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
+
         out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]

         out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
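The new W_pack branch handles checkpoints whose attention projections are fused into a single self_attn.W_pack matrix (such as Baichuan-style models): parts 0, 1 and 2 of the fused tensor map to wq, wk and wv respectively, with wq and wk still going through the usual head permutation and wv used as-is. A rough NumPy illustration of that layout assumption; the stacking order is an assumption about the checkpoint, not something this diff verifies:

import numpy as np

n_embd = 6                                             # toy size, illustration only
q_proj = np.random.rand(n_embd, n_embd)
k_proj = np.random.rand(n_embd, n_embd)
v_proj = np.random.rand(n_embd, n_embd)
w_pack = np.concatenate([q_proj, k_proj, v_proj], axis=0)   # assumed fused layout

r = w_pack.shape[0] // 3                               # same split as permute_part()/part()
assert np.array_equal(w_pack[0 * r : 1 * r], q_proj)   # part 0 -> wq (then permuted)
assert np.array_equal(w_pack[1 * r : 2 * r], k_proj)   # part 1 -> wk (then permuted)
assert np.array_equal(w_pack[2 * r : 3 * r], v_proj)   # part 2 -> wv (used as-is)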