diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 2ddc4d656530..1172ebf90919 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1722,9 +1722,8 @@ def forward( past_key_values=past_key_values, ) - def _deepstack_process( - self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor - ): + def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds): + visual_pos_masks = visual_pos_masks[..., 0] visual_pos_masks = visual_pos_masks.to(hidden_states.device) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds @@ -2151,7 +2150,7 @@ def forward( use_cache=use_cache, output_router_logits=output_router_logits, cache_position=cache_position, - deepstack_visual_embeds_multiscale=visual_embeds_multiscale, + deepstack_visual_embeds=visual_embeds_multiscale, visual_pos_masks=visual_pos_masks, **kwargs, ) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 8a7ba792f846..4d1c30f0a4c3 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1228,6 +1228,10 @@ def __init__(self, config: Qwen3OmniMoeTextConfig): ) self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config) + def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds): + visual_pos_masks = visual_pos_masks[..., 0] + return super()._deepstack_process(hidden_states, visual_pos_masks, visual_embeds) + @dataclass class Qwen3OmniMoeThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast): @@ -1408,7 +1412,7 @@ def forward( use_cache=use_cache, output_router_logits=output_router_logits, cache_position=cache_position, - deepstack_visual_embeds_multiscale=visual_embeds_multiscale, + deepstack_visual_embeds=visual_embeds_multiscale, visual_pos_masks=visual_pos_masks, **kwargs, )