83 changes: 83 additions & 0 deletions snapkv/monkeypatch/llama_hijack_4_37.py
@@ -135,6 +135,89 @@ def llama_flash_attn2_forward(

    return attn_output, attn_weights, past_key_value

def llama_sdpa_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # [SnapKV] register kv_cluster
    init_snapkv(self)
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    # [SnapKV] the past length is added in the block below (mirroring the flash-attn
    # hijack); adding past_key_value.get_usable_length here as well would double-count it.
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        if hasattr(self, "kv_seq_len"):  # [SnapKV] add kv_seq_len
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [SnapKV] moved ahead of the cache update so update_kv sees the repeated heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        if key_states.shape[-2] == kv_seq_len:  # [SnapKV] prefill: compress the KV before seeding the cache
            self.kv_seq_len = kv_seq_len  # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
        else:  # [SnapKV] decode: append the new token's KV as usual
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged when given
    # non-contiguous inputs together with a custom attn_mask.
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        # q_len > 1 is necessary to match AttentionMaskConverter.to_causal_4d, which does not
        # create a causal mask when q_len == 1.
        is_causal=self.is_causal and attention_mask is None and q_len > 1,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value

def prepare_inputs_for_generation_llama(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
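For orientation, the prefill branch above delegates the actual compression to self.kv_cluster.update_kv. Below is a minimal, self-contained sketch of SnapKV-style selection, not the repo's exact update_kv: the name snapkv_select, the window/capacity defaults, and the pooling and masking details are illustrative assumptions, and it presumes the prompt is longer than max_capacity.

import torch
import torch.nn.functional as F

def snapkv_select(key_states, query_states, value_states, window_size=32, max_capacity=1024, kernel_size=5):
    # Score each prefix key by the attention it receives from the last
    # `window_size` ("observation window") queries, smooth the per-token
    # scores, and keep the top (max_capacity - window_size) keys plus the
    # window itself. Shapes: (bsz, num_heads, seq_len, head_dim).
    head_dim = query_states.shape[-1]
    obs_q = query_states[..., -window_size:, :]                        # observation-window queries
    attn = torch.matmul(obs_q, key_states.transpose(2, 3)) / head_dim ** 0.5
    scores = attn[..., :-window_size].softmax(dim=-1).sum(dim=-2)      # (bsz, heads, prefix_len)
    scores = F.avg_pool1d(scores, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    topk = scores.topk(max_capacity - window_size, dim=-1).indices.sort(dim=-1).values
    idx = topk.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    k_keep = torch.cat([key_states[..., :-window_size, :].gather(2, idx), key_states[..., -window_size:, :]], dim=2)
    v_keep = torch.cat([value_states[..., :-window_size, :].gather(2, idx), value_states[..., -window_size:, :]], dim=2)
    return k_keep, v_keep

The real update_kv additionally handles causal masking within the window and the grouped-KV head layout; the sketch only shows why the cache seeded at prefill is bounded by max_capacity regardless of prompt length.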
8 changes: 7 additions & 1 deletion snapkv/monkeypatch/monkeypatch.py
@@ -1,9 +1,11 @@
from importlib.metadata import version
from transformers.utils import is_flash_attn_2_available
import warnings
import transformers
from snapkv.monkeypatch.llama_hijack_4_37 import llama_flash_attn2_forward as llama_flash_attn2_forward_4_37, prepare_inputs_for_generation_llama as prepare_inputs_for_generation_llama_4_37
from snapkv.monkeypatch.mistral_hijack_4_37 import mistral_flash_attn2_forward as mistral_flash_attn2_forward_4_37, prepare_inputs_for_generation_mistral as prepare_inputs_for_generation_mistral_4_37
from snapkv.monkeypatch.mixtral_hijack_4_37 import mixtral_flash_attn2_forward as mixtral_flash_attn2_forward_4_37, prepare_inputs_for_generation_mixtral as prepare_inputs_for_generation_mixtral_4_37
from snapkv.monkeypatch.llama_hijack_4_37 import llama_sdpa_attn_forward as llama_sdpa_attn_forward_4_37

def check_version():
    try:
@@ -23,7 +25,11 @@ def replace_llama():
    if warning_flag:
        warnings.warn(f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}.")
    transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_llama_4_37
    # removed by this change: the unconditional FlashAttention2 patch
    # transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward_4_37

    if is_flash_attn_2_available():
        transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward_4_37
    else:  # fall back to the PyTorch SDPA implementation
        transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward = llama_sdpa_attn_forward_4_37

def replace_mistral():
    transformers_version = check_version()
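For reference, a minimal usage sketch of the new code path (the checkpoint name and prompt are illustrative; it assumes flash-attn is not installed, so replace_llama patches LlamaSdpaAttention and transformers selects the SDPA classes):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snapkv.monkeypatch.monkeypatch import replace_llama

replace_llama()  # patch before the model is instantiated

name = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # request the SDPA attention classes
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(name)

inputs = tokenizer("Summarize the following document: ...", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))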