@@ -147,6 +147,81 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.head_dim,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get the precomputed frequencies for the given input position and sequence length.
+
+        Args:
+            input_pos (Optional[torch.Tensor]): The input position tensor; required when the KV cache is enabled.
+            seq_len (int): The sequence length.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed cos and sin frequencies for the given input position and sequence length.
+        """
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seq_len is most likely 1. We want to slice from input_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, calling .item() would introduce symints
+                # by reading data out of the tensor. Indexing directly avoids that for the
+                # MPS backend, although the MPS backend may be able to support dynamic shapes.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        return freqs_cos, freqs_sin
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
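
The new Rope module owns both the rotary-table precompute and the per-step slicing that get_freqs performs during KV-cache decoding. As a minimal standalone sketch of that idea in plain PyTorch (independent of this repo's ModelArgs and precompute_freqs_cis; all names below are illustrative, not part of the patch):

import torch

def precompute_rope_tables(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # Standard RoPE: one angle per (position, frequency) pair over half the head dim.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

# Decode-time slicing, analogous to Rope.get_freqs with use_kv_cache=True:
cos, sin = precompute_rope_tables(head_dim=64, max_seq_len=128)
input_pos = torch.tensor([5])  # position of the single new token this step
step_cos = cos.narrow(0, int(input_pos[-1].item()), 1)
step_sin = sin.narrow(0, int(input_pos[-1].item()), 1)
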
@@ -266,7 +341,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
@@ -336,7 +409,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -459,33 +532,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.head_dim,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
         self,
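
The net effect of the constructor changes above: Transformer creates a single Rope instance and injects it into every TransformerBlock (and from there into Attention), so the cos/sin buffers are registered once rather than per layer. A sketch of the wiring, assuming `params` is an already-configured ModelArgs instance:

# Hedged sketch; `params` is assumed to be a fully configured ModelArgs.
rope = Rope(params)
blocks = torch.nn.ModuleList(
    [TransformerBlock(layer_id, params, rope) for layer_id in range(params.n_layers)]
)
# Each block's Attention now calls the shared rope, e.g. rope.forward(q, k, cos, sin).
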
@@ -502,33 +559,7 @@ def forward(
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
         seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
 
         for layer in self.layers:
             h = layer(
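
For reference, what Rope.forward delegates to (via RotaryEmbedding or hf_apply_rotary_emb) is the standard rotary update of q and k. A self-contained, simplified version of that rotation (interleaved real/imaginary pairs; the repo's implementations differ in layout details) could look like:

import torch

def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: (bsz, seqlen, n_heads, head_dim); cos, sin: (seqlen, head_dim // 2).
    def rotate(x: torch.Tensor) -> torch.Tensor:
        x_r, x_i = x[..., 0::2], x[..., 1::2]  # even/odd interleaved pairs
        c = cos.view(1, cos.shape[0], 1, -1)   # broadcast over batch and heads
        s = sin.view(1, sin.shape[0], 1, -1)
        out_r = x_r * c - x_i * s
        out_i = x_r * s + x_i * c
        return torch.stack((out_r, out_i), dim=-1).flatten(-2)

    return rotate(q), rotate(k)

# Usage mirrors the new call sites in this patch:
#   freqs_cos, freqs_sin = rope.get_freqs(input_pos, seqlen)
#   q, k = rope.forward(q, k, freqs_cos, freqs_sin)
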