102 changes: 97 additions & 5 deletions src/transformers/integrations/awq.py
@@ -18,6 +18,7 @@
from packaging import version

from ..activations import ACT2FN
from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
from ..utils.quantization_config import (
@@ -46,7 +47,6 @@
"mlp": ["w1", "w3", "w2"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
"rope_theta": 1000000.0,
},
"llama": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
@@ -60,6 +60,18 @@
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
"qwen2": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
"qwen3": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
}

AWQ_SCALES_MAPPINGS = {
@@ -74,6 +86,53 @@
}


if is_auto_awq_available():
from awq.modules.fused.attn import RoPE

class AWQRoPE(RoPE):
"""
AWQRoPE module that overrides the RoPE implementation of AWQ fused attention modules so that more RoPE types (and therefore more models) are supported.

Args:
rope_type (`str`):
The rope type to use.
head_dim (`int`):
The head dimension.
max_seq_len (`int`):
The maximum sequence length.
config (`PretrainedConfig`):
The model config object.
device (`torch.device`):
The device to put the module on.
"""

def __init__(self, rope_type, head_dim, max_seq_len, config, device):
rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
self.inv_freq, self.attention_scaling = rope_init_fn(config, device)
# Use a dummy rope_theta to initialize the parent class; the overridden precompute_freqs_cis below ignores it
super().__init__(head_dim=head_dim, max_seq_len=max_seq_len, device=device, rope_theta=-1)

def precompute_freqs_cis(self, dim: int, end: int, theta=-1):
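# `dim` and `theta` are ignored here: `self.inv_freq` was already computed in __init__ by the rope init function, with any rope scaling taken into account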
t = torch.arange(end, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
del self.inv_freq # free the memory
return freqs_cis

def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
start_pos: int,
seqlen: int,
partial: bool = False,
):
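# Apply the parent's rotary embedding first, then rescale by the attention_scaling factor computed by the rope init function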
xq_out, xk_out = super().forward(xq, xk, start_pos, seqlen, partial)
xq_out = (xq_out * self.attention_scaling).type_as(xq)
xk_out = (xk_out * self.attention_scaling).type_as(xk)
return xq_out, xk_out


def replace_quantization_scales(model, model_type):
from awq.modules.act import ScaledActivation

@@ -219,15 +278,17 @@ def get_modules_to_fuse(model, quantization_config):
# Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
config = model.config.get_text_config(decoder=True)

# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
# Handle hidden_size, num_attention_heads, num_key_value_heads, rope_parameters on our own.
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
rope_parameters = config.rope_parameters

# Fill `current_fused_mapping` with the expected values
current_fused_mapping["hidden_size"] = hidden_size
current_fused_mapping["num_attention_heads"] = num_attention_heads
current_fused_mapping["num_key_value_heads"] = num_key_value_heads
current_fused_mapping["rope_parameters"] = rope_parameters
current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
else:
raise ValueError(
@@ -261,6 +322,15 @@ def fuse_awq_modules(model, quantization_config):
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

# Patch QuantAttentionFused.forward so that it only returns (attn_output, attention_weights) and no longer returns past_key_value
old_quant_attention_fused_forward = QuantAttentionFused.forward

def new_quant_attention_fused_forward(self, *args, **kwargs):
attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
return attn_output, attention_weight

QuantAttentionFused.forward = new_quant_attention_fused_forward
else:
raise ValueError("Fusing is only supported for the AutoAWQ backend")

@@ -376,7 +446,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
The pytorch parent module that has layernorm modules to fuse
modules_to_fuse (`list[str]`):
The module fusing mapping. The dictionary has to contain a field `attention` with attention module names
in the correct order: q, k, v, o layer
in the correct order: q, k, v and o layers, optionally followed by q_norm and k_norm
current_module_name (`str`):
The current submodule name
target_cls (`~autoawq.QuantAttentionFused`):
@@ -415,6 +485,14 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
v_proj = getattr(module, modules_to_fuse["attention"][2])
o_proj = getattr(module, modules_to_fuse["attention"][3])

# Some models (e.g. Qwen3) also define q_norm and k_norm layers
if len(modules_to_fuse["attention"]) > 4:
q_norm = getattr(module, modules_to_fuse["attention"][4])
k_norm = getattr(module, modules_to_fuse["attention"][5])
else:
q_norm = None
k_norm = None

bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

qkv_layer = linear_target_cls(
@@ -445,16 +523,30 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
modules_to_fuse["max_seq_len"],
use_alibi=modules_to_fuse["use_alibi"],
# The default value in autoawq is set to 10000.0
rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
rope_theta=modules_to_fuse["rope_parameters"].get("rope_theta", 10000.0),
q_norm=q_norm,
k_norm=k_norm,
)

# Hack the rope module if not using alibi and rope_type is not default
# As the default rope implementation in autoawq only supports the "default" rope type
rope_type = modules_to_fuse["rope_parameters"].get("rope_type", "default")
if not modules_to_fuse["use_alibi"] and rope_type != "default":
fused_attention_layer.rope = AWQRoPE(
rope_type,
modules_to_fuse["hidden_size"] // modules_to_fuse["num_attention_heads"],
modules_to_fuse["max_seq_len"],
model.config.get_text_config(decoder=True),
previous_device,
)

fused_attention_layer.is_hf_transformers = True

parent_name, child_name = current_module_name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
setattr(parent, child_name, fused_attention_layer.to(previous_device))

del q_proj, k_proj, v_proj, o_proj
del q_proj, k_proj, v_proj, o_proj, q_norm, k_norm
module_has_been_fused = True

return module_has_been_fused
1 change: 1 addition & 0 deletions src/transformers/models/apertus/modeling_apertus.py
@@ -416,6 +416,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/arcee/modeling_arcee.py
@@ -421,6 +421,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -747,6 +747,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/bitnet/modeling_bitnet.py
@@ -420,6 +420,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -453,6 +453,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/csm/modeling_csm.py
@@ -752,6 +752,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@@ -533,6 +533,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@@ -618,6 +618,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -673,6 +673,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/emu3/modeling_emu3.py
@@ -1245,6 +1245,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -419,6 +419,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/glm/modeling_glm.py
@@ -437,6 +437,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/glm4/modeling_glm4.py
@@ -441,6 +441,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -562,6 +562,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/helium/modeling_helium.py
@@ -420,6 +420,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@@ -445,6 +445,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@@ -523,6 +523,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -425,6 +425,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
Member:
I don't mind passing this, but I didn't find where this is used in the decoder layer -> attention layer.

Author:
A very good question.

On one hand, when using model.generate, use_cache is set to True, which enables the model to utilize past_key_values. At this point, the logic in autoawq checks whether the forward call originates from generate by inspecting use_cache, and accordingly adjusts the starting position of its precomputed RoPE embeddings. If use_cache is not passed down to the decoder layer and subsequently to the attention module, autoawq cannot determine whether it is inside a generate call. Consequently, it assumes the forward pass is always a regular one (i.e., without any cache), keeping the starting position fixed at 0, which leads to garbled output during inference.

autoawq:

https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/modules/fused/attn.py#L218-L241

On the other hand, similar to the implementations in Qwen2 and Qwen3, use_cache is indeed passed to the decoder layer and then forwarded to the attention module—but it is not actually used within the attention module itself.
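For reference, a minimal sketch of the behavior described above (a paraphrase of the linked AutoAWQ logic, not the exact source; the attribute names is_hf_transformers and start_pos are assumptions taken from that file):

# Sketch only: how the fused attention picks the offset into its precomputed
# RoPE table depending on whether use_cache was forwarded.
class FusedAttnRopePositionSketch:
    def __init__(self):
        self.is_hf_transformers = True  # transformers sets this flag on the fused module
        self.start_pos = 0              # read offset into the precomputed RoPE table

    def next_start(self, seqlen: int, use_cache: bool) -> int:
        if self.is_hf_transformers and not use_cache:
            # use_cache was not passed down (or is False): assume a plain
            # forward pass without a KV cache, so positions restart at 0.
            self.start_pos = 0
        start = self.start_pos
        # With use_cache=True (i.e. inside generate()), the offset keeps
        # advancing so each decoding step gets the correct rotary phase.
        self.start_pos = start + seqlen
        return start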

Member:
Hmmm, I see, thanks for the extensive explanation!

cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -423,6 +423,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/olmo2/modeling_olmo2.py
@@ -428,6 +428,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
1 change: 1 addition & 0 deletions src/transformers/models/seed_oss/modeling_seed_oss.py
@@ -426,6 +426,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)