From 63d7ca3790d5dadc1f9aae65e1922c71042baa73 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <1848839264@qq.com>
Date: Tue, 28 Oct 2025 16:17:19 +0800
Subject: [PATCH 1/5] fix awq bc due to attention refactor

---
 src/transformers/integrations/awq.py | 41 ++++++++++++++++++++++++----
 1 file changed, 36 insertions(+), 5 deletions(-)

diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py
index c09da6c92e6c..1ec0aba143b1 100644
--- a/src/transformers/integrations/awq.py
+++ b/src/transformers/integrations/awq.py
@@ -60,6 +60,18 @@
         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
         "use_alibi": False,
     },
+    "qwen2": {
+        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
+        "use_alibi": False,
+    },
+    "qwen3": {
+        "attention": ["q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
+        "use_alibi": False,
+    }
 }
 
 AWQ_SCALES_MAPPINGS = {
@@ -219,15 +231,18 @@ def get_modules_to_fuse(model, quantization_config):
         # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
         config = model.config.get_text_config(decoder=True)
 
-        # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
+        # Handle hidden_size, num_attention_heads, num_key_value_heads and rope_theta on our own.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
+        # The default value in autoawq is set to 10000.0
+        rope_theta = getattr(config, "rope_theta", 10000.0)
 
         # Fill `current_fused_mapping` with the expected values
         current_fused_mapping["hidden_size"] = hidden_size
         current_fused_mapping["num_attention_heads"] = num_attention_heads
         current_fused_mapping["num_key_value_heads"] = num_key_value_heads
+        current_fused_mapping["rope_theta"] = rope_theta
         current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
     else:
         raise ValueError(
@@ -261,6 +276,13 @@ def fuse_awq_modules(model, quantization_config):
         from awq.modules.fused.attn import QuantAttentionFused
         from awq.modules.fused.mlp import QuantFusedMLP
         from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+        # hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
+        old_quant_attention_fused_forward = QuantAttentionFused.forward
+        def new_quant_attention_fused_forward(self, *args, **kwargs):
+            attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
+            return attn_output, attention_weight
+        QuantAttentionFused.forward = new_quant_attention_fused_forward
     else:
         raise ValueError("Fusing is only supported for the AutoAWQ backend")
 
@@ -376,7 +398,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
             The pytorch parent module that has layernorm modules to fuse
         modules_to_fuse (`list[str]`):
             The module fusing mapping. The dictionary has to contain a field `attention` with attention module names
-            in the correct order: q, k, v, o layer
+            in the correct order: q, k, v, o layer, optionally followed by q_norm and k_norm
         current_module_name (`str`):
             The current submodule name
         target_cls (`~autoawq.QuantAttentionFused`):
@@ -415,6 +437,14 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
         v_proj = getattr(module, modules_to_fuse["attention"][2])
         o_proj = getattr(module, modules_to_fuse["attention"][3])
 
+        # Some models (e.g. Qwen3) also have q_norm and k_norm layers
+        if len(modules_to_fuse["attention"]) > 4:
+            q_norm = getattr(module, modules_to_fuse["attention"][4])
+            k_norm = getattr(module, modules_to_fuse["attention"][5])
+        else:
+            q_norm = None
+            k_norm = None
+
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
 
         qkv_layer = linear_target_cls(
@@ -444,8 +474,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
             previous_device,
             modules_to_fuse["max_seq_len"],
             use_alibi=modules_to_fuse["use_alibi"],
-            # The default value in autoawq is set to 10000.0
-            rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
+            rope_theta=modules_to_fuse["rope_theta"],
+            q_norm=q_norm,
+            k_norm=k_norm,
         )
 
         fused_attention_layer.is_hf_transformers = True
@@ -454,7 +485,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
         parent = model.get_submodule(parent_name)
         setattr(parent, child_name, fused_attention_layer.to(previous_device))
 
-        del q_proj, k_proj, v_proj, o_proj
+        del q_proj, k_proj, v_proj, o_proj, q_norm, k_norm
         module_has_been_fused = True
 
     return module_has_been_fused

From 1f49232b08b4c3c7899910ec04b9b22a84d2e053 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <1848839264@qq.com>
Date: Fri, 7 Nov 2025 13:49:58 +0800
Subject: [PATCH 2/5] feat: support more rope_types for awq fusion

---
 src/transformers/integrations/awq.py | 72 +++++++++++++++++--
 .../models/llama/modeling_llama.py   |  1 +
 2 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py
index 1ec0aba143b1..d06f503e03ad 100644
--- a/src/transformers/integrations/awq.py
+++ b/src/transformers/integrations/awq.py
@@ -18,6 +18,7 @@
 from packaging import version
 
 from ..activations import ACT2FN
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ..modeling_utils import PreTrainedModel
 from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
 from ..utils.quantization_config import (
@@ -46,7 +47,6 @@
         "mlp": ["w1", "w3", "w2"],
         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
         "use_alibi": False,
-        "rope_theta": 1000000.0,
     },
     "llama": {
         "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
@@ -86,6 +86,52 @@
 }
 
 
+if is_auto_awq_available():
+    from awq.modules.fused.attn import RoPE
+
+    class AWQRoPE(RoPE):
+        """
+        AWQRoPE module that overrides the rope implementation in AWQ fused attention modules to support more rope types.
+
+        Args:
+            rope_type (`str`):
+                The rope type to use.
+            head_dim (`int`):
+                The head dimension.
+            max_seq_len (`int`):
+                The maximum sequence length.
+            config (`PreTrainedConfig`):
+                The model config object.
+            device (`torch.device`):
+                The device to put the module on.
+        """
+        def __init__(self, rope_type, head_dim, max_seq_len, config, device):
+            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+            self.inv_freq, self.attention_scaling = rope_init_fn(config, device)
+            # Use fake rope_theta to initialize the parent class
+            super().__init__(head_dim=head_dim, max_seq_len=max_seq_len, device=device, rope_theta=-1)
+
+        def precompute_freqs_cis(self, dim: int, end: int, theta=-1):
+            t = torch.arange(end, device=self.inv_freq.device)
+            freqs = torch.outer(t, self.inv_freq).float()
+            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+            del self.inv_freq  # free the memory
+            return freqs_cis
+
+        def forward(
+            self,
+            xq: torch.Tensor,
+            xk: torch.Tensor,
+            start_pos: int,
+            seqlen: int,
+            partial: bool = False,
+        ):
+            xq_out, xk_out = super().forward(xq, xk, start_pos, seqlen, partial)
+            xq_out = (xq_out * self.attention_scaling).type_as(xq)
+            xk_out = (xk_out * self.attention_scaling).type_as(xk)
+            return xq_out, xk_out
+
+
 def replace_quantization_scales(model, model_type):
     from awq.modules.act import ScaledActivation
 
@@ -231,18 +277,17 @@ def get_modules_to_fuse(model, quantization_config):
         # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
         config = model.config.get_text_config(decoder=True)
 
-        # Handle hidden_size, num_attention_heads, num_key_value_heads and rope_theta on our own.
+        # Handle hidden_size, num_attention_heads, num_key_value_heads, rope_parameters on our own.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
-        # The default value in autoawq is set to 10000.0
-        rope_theta = getattr(config, "rope_theta", 10000.0)
+        rope_parameters = config.rope_parameters
 
         # Fill `current_fused_mapping` with the expected values
         current_fused_mapping["hidden_size"] = hidden_size
         current_fused_mapping["num_attention_heads"] = num_attention_heads
         current_fused_mapping["num_key_value_heads"] = num_key_value_heads
-        current_fused_mapping["rope_theta"] = rope_theta
+        current_fused_mapping["rope_parameters"] = rope_parameters
         current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
     else:
         raise ValueError(
@@ -277,7 +322,7 @@ def fuse_awq_modules(model, quantization_config):
         from awq.modules.fused.mlp import QuantFusedMLP
         from awq.modules.fused.norm import FasterTransformerRMSNorm
 
-        # hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
+        # Hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
         old_quant_attention_fused_forward = QuantAttentionFused.forward
         def new_quant_attention_fused_forward(self, *args, **kwargs):
             attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
             return attn_output, attention_weight
@@ -474,11 +519,24 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
             previous_device,
             modules_to_fuse["max_seq_len"],
             use_alibi=modules_to_fuse["use_alibi"],
-            rope_theta=modules_to_fuse["rope_theta"],
+            # The default value in autoawq is set to 10000.0
+            rope_theta=modules_to_fuse["rope_parameters"].get("rope_theta", 10000.0),
             q_norm=q_norm,
             k_norm=k_norm,
         )
 
+        # Replace the rope module when alibi is not used and the rope_type is not "default",
+        # as the rope implementation in autoawq only supports the "default" rope type
+        rope_type = modules_to_fuse["rope_parameters"].get("rope_type", "default")
+        if not modules_to_fuse["use_alibi"] and rope_type != "default":
+            fused_attention_layer.rope = AWQRoPE(
+                rope_type,
+                modules_to_fuse["hidden_size"] // modules_to_fuse["num_attention_heads"],
+                modules_to_fuse["max_seq_len"],
+                model.config.get_text_config(decoder=True),
+                previous_device,
+            )
+
         fused_attention_layer.is_hf_transformers = True
 
         parent_name, child_name = current_module_name.rsplit(".", 1)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3d8340091bee..acd719b489ca 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -425,6 +425,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

From 9cdc80809e0dbf6cfb91d946751d93e458b54d74 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <1848839264@qq.com>
Date: Wed, 12 Nov 2025 14:33:41 +0800
Subject: [PATCH 3/5] feat: add test for llama3

---
 tests/quantization/autoawq/test_awq.py | 30 ++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index 78c694a848fc..f8bfeae1818c 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -305,6 +305,9 @@ class AwqFusedTest(unittest.TestCase):
     multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
     multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"
 
+    awq_rope_model_name = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
+    awq_rope_model_revision = "db1f81ad4b8c7e39777509fac66c652eb0a52f91"
+
     prompt = (
         "You're standing on the surface of the Earth. "
         "You walk one mile south, one mile west and one mile north. "
@@ -314,6 +317,7 @@ class AwqFusedTest(unittest.TestCase):
     EXPECTED_GENERATION = prompt + "\n\nYou're at the center of a square."
     EXPECTED_GENERATION_CUSTOM_MODEL = "Hello,\n\nI have a problem with my 20"
     EXPECTED_GENERATION_MIXTRAL = prompt + " You're on the North Pole.\n\nThe"
+    EXPECTED_GENERATION_AWQ_ROPE = prompt + " [Note: You can't be in a city, and"
 
     def tearDown(self):
         gc.collect()
@@ -513,6 +517,32 @@ def test_generation_mixtral_fused(self):
         outputs = model.generate(**inputs, max_new_tokens=12)
         self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
 
+    @pytest.mark.flash_attn_test
+    @require_flash_attn
+    @require_torch_multi_gpu
+    @unittest.skipIf(
+        get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
+        "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
+    )
+    def test_generation_awq_rope_fused(self):
+        """
+        Text generation test for a fused AWQ model that uses a special RoPE implementation
+        (e.g. Llama 3.1)
+        """
+        quantization_config = AwqConfig(bits=4, fuse_max_seq_len=1024, do_fuse=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.awq_rope_model_name,
+            quantization_config=quantization_config,
+            device_map="auto",
+            revision=self.awq_rope_model_revision,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.awq_rope_model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device)
+
+        outputs = model.generate(**inputs, max_new_tokens=12, do_sample=False)
+        self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_AWQ_ROPE)
     @slow
     @require_torch_accelerator

From 6d2fa278948b13a7ecc95a963af31fd40f87e597 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <1848839264@qq.com>
Date: Wed, 12 Nov 2025 14:49:51 +0800
Subject: [PATCH 4/5] fix ruff format

---
 src/transformers/integrations/awq.py   | 5 ++++-
 tests/quantization/autoawq/test_awq.py | 1 +
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py
index d06f503e03ad..b541083a571f 100644
--- a/src/transformers/integrations/awq.py
+++ b/src/transformers/integrations/awq.py
@@ -71,7 +71,7 @@
         "mlp": ["gate_proj", "up_proj", "down_proj"],
         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
         "use_alibi": False,
-    }
+    },
 }
 
 AWQ_SCALES_MAPPINGS = {
@@ -105,6 +105,7 @@ class AWQRoPE(RoPE):
             device (`torch.device`):
                 The device to put the module on.
         """
+
         def __init__(self, rope_type, head_dim, max_seq_len, config, device):
             rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
             self.inv_freq, self.attention_scaling = rope_init_fn(config, device)
@@ -324,9 +325,11 @@ def fuse_awq_modules(model, quantization_config):
 
         # Hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
         old_quant_attention_fused_forward = QuantAttentionFused.forward
+
         def new_quant_attention_fused_forward(self, *args, **kwargs):
             attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
             return attn_output, attention_weight
+
         QuantAttentionFused.forward = new_quant_attention_fused_forward
     else:
         raise ValueError("Fusing is only supported for the AutoAWQ backend")
diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index f8bfeae1818c..5770bdbed13a 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -544,6 +544,7 @@ def test_generation_awq_rope_fused(self):
 
         outputs = model.generate(**inputs, max_new_tokens=12, do_sample=False)
         self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_AWQ_ROPE)
+
     @slow
     @require_torch_accelerator
     @require_auto_awq

From e8d3a21502f369a541deefc375c054a7d984a1f5 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <1848839264@qq.com>
Date: Wed, 19 Nov 2025 15:00:51 +0800
Subject: [PATCH 5/5] propagate changes in modeling_llama

---
 src/transformers/models/apertus/modeling_apertus.py          | 1 +
 src/transformers/models/arcee/modeling_arcee.py              | 1 +
 src/transformers/models/aria/modeling_aria.py                | 1 +
 src/transformers/models/bitnet/modeling_bitnet.py            | 1 +
 src/transformers/models/cohere/modeling_cohere.py            | 1 +
 src/transformers/models/csm/modeling_csm.py                  | 1 +
 src/transformers/models/deepseek_v2/modeling_deepseek_v2.py  | 1 +
 src/transformers/models/deepseek_v3/modeling_deepseek_v3.py  | 1 +
 src/transformers/models/diffllama/modeling_diffllama.py      | 1 +
 src/transformers/models/emu3/modeling_emu3.py                | 1 +
 src/transformers/models/ernie4_5/modeling_ernie4_5.py        | 1 +
 src/transformers/models/glm/modeling_glm.py                  | 1 +
 src/transformers/models/glm4/modeling_glm4.py                | 1 +
 src/transformers/models/glm4_moe/modeling_glm4_moe.py        | 1 +
 src/transformers/models/helium/modeling_helium.py            | 1 +
 .../models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py     | 1 +
 .../models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py         | 1 +
 src/transformers/models/olmo/modeling_olmo.py                | 1 +
 src/transformers/models/olmo2/modeling_olmo2.py              | 1 +
 src/transformers/models/seed_oss/modeling_seed_oss.py        | 1 +
 20 files changed, 20 insertions(+)

diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py
index e92e87a3c280..cfad78f246e7 100644
--- a/src/transformers/models/apertus/modeling_apertus.py
+++ b/src/transformers/models/apertus/modeling_apertus.py
@@ -416,6 +416,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 619e72b7a11b..d4b92e0dabf4 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -421,6 +421,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index e702077bf930..28bca32c06ce 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -747,6 +747,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index d3972946a203..fb18b003710a 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -420,6 +420,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 71eb4870fbf2..417f4799a6b1 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -453,6 +453,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
index 7d3f87b2953d..a7498c5a1fb5 100644
--- a/src/transformers/models/csm/modeling_csm.py
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -752,6 +752,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
index ae3b6c4431bf..80baa8944eb7 100644
--- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
+++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -533,6 +533,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index cbb63c5216be..812382b0716c 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -618,6 +618,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index d82430b623e1..51ee0c8b0980 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -673,6 +673,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index e2d1b1c98535..a7171b13a62a 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -1245,6 +1245,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
index 5658c7691c3c..00551f1835af 100644
--- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -419,6 +419,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index f72268465ece..b20023a87f4b 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -437,6 +437,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index 935a722fd1db..dafa4fffa7d0 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -441,6 +441,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
index 00afc27bf236..d46b367706d5 100644
--- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -562,6 +562,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index a1d0a09e848f..1db9db8ef4d9 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -420,6 +420,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
index e3a55c296f6f..4485fb3e94c9 100644
--- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
+++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
@@ -445,6 +445,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index a1fded6bdf77..822adcf8d22c 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -523,6 +523,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 6a3432c31d18..e2a260db0632 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -423,6 +423,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 7315661282c9..90ded36eb7a9 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -428,6 +428,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py
index 7e645e3ce052..58b538866602 100644
--- a/src/transformers/models/seed_oss/modeling_seed_oss.py
+++ b/src/transformers/models/seed_oss/modeling_seed_oss.py
@@ -426,6 +426,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
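For reviewers who want to exercise the fused path end to end, the sketch below mirrors the new `test_generation_awq_rope_fused` test: loading the AWQ-quantized Llama 3.1 checkpoint pinned in the test with `do_fuse=True` routes attention through the fused `QuantAttentionFused` module and, since Llama 3.1 uses a non-default rope type, through the new `AWQRoPE` module. This is an illustrative snippet rather than part of the patch series; it assumes `autoawq` is installed and a CUDA GPU is available, and the only names it relies on are the ones appearing in the test above.

```python
# Illustrative sketch (not part of the patches): fused AWQ generation with a non-default rope type.
# Assumes `autoawq` is installed and a CUDA device is available.
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"  # checkpoint pinned in the test above

# do_fuse=True triggers fuse_awq_modules(); fuse_max_seq_len sizes the fused cache.
quantization_config = AwqConfig(bits=4, fuse_max_seq_len=1024, do_fuse=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

prompt = "You're standing on the surface of the Earth."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding, matching the do_sample=False setting used in the test.
outputs = model.generate(**inputs, max_new_tokens=12, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```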