Skip to content

Commit db802aa

Browse files
BakerBunkerlvyuanjun.lyj
andauthored
Modify Qwen3Omni parameter name since VL changed it (#41045)
Modify parameter name since VL changed it Co-authored-by: lvyuanjun.lyj <[email protected]>
1 parent 8a2f24a commit db802aa

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,9 +1722,8 @@ def forward(
17221722
past_key_values=past_key_values,
17231723
)
17241724

1725-
def _deepstack_process(
1726-
self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
1727-
):
1725+
def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
1726+
visual_pos_masks = visual_pos_masks[..., 0]
17281727
visual_pos_masks = visual_pos_masks.to(hidden_states.device)
17291728
visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
17301729
local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
@@ -2151,7 +2150,7 @@ def forward(
21512150
use_cache=use_cache,
21522151
output_router_logits=output_router_logits,
21532152
cache_position=cache_position,
2154-
deepstack_visual_embeds_multiscale=visual_embeds_multiscale,
2153+
deepstack_visual_embeds=visual_embeds_multiscale,
21552154
visual_pos_masks=visual_pos_masks,
21562155
**kwargs,
21572156
)

src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,10 @@ def __init__(self, config: Qwen3OmniMoeTextConfig):
12281228
)
12291229
self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(config)
12301230

1231+
def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
1232+
visual_pos_masks = visual_pos_masks[..., 0]
1233+
return super()._deepstack_process(hidden_states, visual_pos_masks, visual_embeds)
1234+
12311235

12321236
@dataclass
12331237
class Qwen3OmniMoeThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
@@ -1408,7 +1412,7 @@ def forward(
14081412
use_cache=use_cache,
14091413
output_router_logits=output_router_logits,
14101414
cache_position=cache_position,
1411-
deepstack_visual_embeds_multiscale=visual_embeds_multiscale,
1415+
deepstack_visual_embeds=visual_embeds_multiscale,
14121416
visual_pos_masks=visual_pos_masks,
14131417
**kwargs,
14141418
)

0 commit comments

Comments
 (0)