From 937960bcc92ac807c5d73bb4000f02435d27c87f Mon Sep 17 00:00:00 2001 From: Crazyang Date: Sun, 16 Nov 2025 23:51:48 +0800 Subject: [PATCH] Add openpangu_moe model --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/openpangu_moe.md | 71 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/openpangu_moe/__init__.py | 29 + .../configuration_openpangu_moe.py | 81 +++ .../openpangu_moe/modeling_openpangu_moe.py | 610 ++++++++++++++++++ .../openpangu_moe/modular_openpangu_moe.py | 559 ++++++++++++++++ tests/models/openpangu_moe/__init__.py | 0 .../test_modeling_openpangu_moe.py | 577 +++++++++++++++++ 11 files changed, 1934 insertions(+) create mode 100644 docs/source/en/model_doc/openpangu_moe.md create mode 100644 src/transformers/models/openpangu_moe/__init__.py create mode 100644 src/transformers/models/openpangu_moe/configuration_openpangu_moe.py create mode 100644 src/transformers/models/openpangu_moe/modeling_openpangu_moe.py create mode 100644 src/transformers/models/openpangu_moe/modular_openpangu_moe.py create mode 100644 tests/models/openpangu_moe/__init__.py create mode 100644 tests/models/openpangu_moe/test_modeling_openpangu_moe.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 19180a7ef7f6..c46b16bb84f8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -648,6 +648,8 @@ title: OLMoE - local: model_doc/open-llama title: Open-Llama + - local: model_doc/openpangu_moe + title: OpenPanguMoE - local: model_doc/opt title: OPT - local: model_doc/pegasus diff --git a/docs/source/en/model_doc/openpangu_moe.md b/docs/source/en/model_doc/openpangu_moe.md new file mode 100644 index 000000000000..fc558c3bfc21 --- /dev/null +++ b/docs/source/en/model_doc/openpangu_moe.md @@ -0,0 +1,71 @@ + + + +# OpenPanguMoE + +## Overview + +The OpenPanguMoE model was proposed in []() by . + + +The abstract from the paper is the following: + + + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + +## Usage examples + + + +## OpenpanguMoeConfig + +[[autodoc]] OpenpanguMoeConfig + +## OpenpanguMoeForCausalLM + +[[autodoc]] OpenpanguMoeForCausalLM + +## OpenpanguMoeModel + +[[autodoc]] OpenpanguMoeModel + - forward + +## OpenpanguMoePreTrainedModel + +[[autodoc]] OpenpanguMoePreTrainedModel + - forward + +## OpenpanguMoeForSequenceClassification + +[[autodoc]] OpenpanguMoeForSequenceClassification + +## OpenpanguMoeForQuestionAnswering + +[[autodoc]] OpenpanguMoeForQuestionAnswering + +## OpenpanguMoeForTokenClassification + +[[autodoc]] OpenpanguMoeForTokenClassification \ No newline at end of file diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 76b7a9a32ac6..68d1a819f5c6 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -256,6 +256,7 @@ from .omdet_turbo import * from .oneformer import * from .openai import * + from .openpangu_moe import * from .opt import * from .ovis2 import * from .owlv2 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e17a41263504..0b9fda57dbac 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -301,6 +301,7 @@ ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), ("openai-gpt", "OpenAIGPTConfig"), + ("openpangu_moe", "OpenPanguMoEConfig"), ("opt", "OPTConfig"), ("ovis2", "Ovis2Config"), ("owlv2", "Owlv2Config"), @@ -763,6 +764,7 @@ ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), ("openai-gpt", "OpenAI GPT"), + ("openpangu_moe", "OpenPanguMoE"), ("opt", "OPT"), ("ovis2", "Ovis2"), ("owlv2", "OWLv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index acdf6c9db280..73f29bb9a256 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -300,6 +300,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), ("openai-gpt", "OpenAIGPTModel"), + ("openpangu_moe", "OpenPanguMoEModel"), ("opt", "OPTModel"), ("ovis2", "Ovis2Model"), ("owlv2", "Owlv2Model"), @@ -733,6 +734,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("openpangu_moe", "OpenPanguMoEForCausalLM"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), ("persimmon", "PersimmonForCausalLM"), diff --git a/src/transformers/models/openpangu_moe/__init__.py b/src/transformers/models/openpangu_moe/__init__.py new file mode 100644 index 000000000000..2cdfa0540017 --- /dev/null +++ b/src/transformers/models/openpangu_moe/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_openpangu_moe import * + from .modeling_openpangu_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/openpangu_moe/configuration_openpangu_moe.py b/src/transformers/models/openpangu_moe/configuration_openpangu_moe.py new file mode 100644 index 000000000000..21d2f12f50ab --- /dev/null +++ b/src/transformers/models/openpangu_moe/configuration_openpangu_moe.py @@ -0,0 +1,81 @@ +# coding=utf-8 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + +"""openPanguUltraMoE 718B model configuration""" + +from ...configuration_utils import PreTrainedConfig + +class OpenPanguMoEConfig(PreTrainedConfig): + + model_type = "pangu_ultra_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=153600, + hidden_size=7680, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_mtp_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + num_shared_experts=1, + num_routed_experts=256, + routed_scaling_factor=2.5, + attention_kv_lora_dim=512, + attention_q_lora_dim=1536, + attention_qk_rope_dim=64, + attention_v_dim=128, + attention_qk_dim=128, + num_experts_per_tok=8, + num_dense_layers=3, + norm_topk_prob=True, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=25600000, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + + self.num_dense_layers = num_dense_layers + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_shared_experts = num_shared_experts + self.num_routed_experts = num_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.attention_kv_lora_dim = attention_kv_lora_dim + self.attention_q_lora_dim = attention_q_lora_dim + self.attention_qk_rope_dim = attention_qk_rope_dim + self.attention_v_dim = attention_v_dim + self.attention_qk_dim = attention_qk_dim + self.attention_dropout = attention_dropout + self.num_mtp_layers = num_mtp_layers + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/src/transformers/models/openpangu_moe/modeling_openpangu_moe.py b/src/transformers/models/openpangu_moe/modeling_openpangu_moe.py new file mode 100644 index 000000000000..dc005e069275 --- /dev/null +++ b/src/transformers/models/openpangu_moe/modeling_openpangu_moe.py @@ -0,0 +1,610 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/openpangu_moe/modular_openpangu_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_openpangu_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from .configuration_openpangu_moe import OpenPanguMoEConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class OpenPanguMoERMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + OpenPanguMoERMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class OpenPanguMoERotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=131072, base=25600000.0, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self._set_cache( + seq_len=max_position_embeddings, + device=device, + dtype=torch.get_default_dtype(), + ) + + def _set_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, kv_len, max_seq_len=None): + if max_seq_len is None: + self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype) + elif max_seq_len > self.max_seq_len_cached: + self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype) + + batch_size = x.shape[0] + seq_len = x.shape[1] + if seq_len == 1: + cos = torch.index_select(self.cos_cached, dim=0, index=kv_len).unsqueeze(1).unsqueeze(1) + sin = torch.index_select(self.sin_cached, dim=0, index=kv_len).unsqueeze(1).unsqueeze(1) + else: + cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, 1, 1) + sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, 1, 1) + + cos = cos[0, :, 0, :] + sin = sin[0, :, 0, :] + return ( + cos.to(dtype=x.dtype), + sin.to(dtype=x.dtype), + ) + + +class OpenPanguMoEMLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class OpenPanguMoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = config.routed_scaling_factor + + self.norm_topk_prob = config.norm_topk_prob + self.weight = nn.Parameter(torch.empty((config.num_routed_experts, config.hidden_size))) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.to(torch.float32), self.weight.to(torch.float32), None) + scores = logits.sigmoid() + scores_for_choice = scores.view(bsz * seq_len, -1) + _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor + + return topk_idx, topk_weight + + +class OpenPanguMoE(nn.Module): + def __init__(self, config): + super().__init__() + self.num_shared_experts = config.num_shared_experts + self.num_routed_experts = config.num_routed_experts + self.experts = nn.ModuleList( + [ + OpenPanguMoEMLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(self.num_routed_experts) + ] + ) + self.gate = OpenPanguMoEGate(config) + if self.num_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * self.num_shared_experts + self.shared_experts = OpenPanguMoEMLP(config=config, intermediate_size=intermediate_size) + + def forward(self, hidden_states): + if self.num_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + input_shape = hidden_states.shape + topk_ids, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + counts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + counts.scatter_(1, topk_ids, 1) + tokens_per_expert = counts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = hidden_states[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + output_hidden_states = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + output_hidden_states.append(expert_out) + start_idx = end_idx + + if len(output_hidden_states) > 0: + cat_hidden_states = torch.cat(output_hidden_states, dim=0) + else: + cat_hidden_states = sorted_tokens.new_empty(0) + + final_hidden_states = torch.empty_like(cat_hidden_states) + final_hidden_states[idxs] = cat_hidden_states + final_out = final_hidden_states.view(*topk_ids.shape, -1).to(topk_weight.dtype) + final_out = (final_out.mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).to(final_hidden_states.dtype)).view( + *input_shape + ) + if self.num_shared_experts is not None: + final_out = final_out + shared_output + return final_out + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class OpenPanguMoEAttention(nn.Module): + def __init__(self, config: OpenPanguMoEConfig, layer_idx: Optional[int] = None): + super().__init__() + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.attention_q_lora_dim = config.attention_q_lora_dim + self.attention_qk_rope_dim = config.attention_qk_rope_dim + self.attention_kv_lora_dim = config.attention_kv_lora_dim + self.attention_v_dim = config.attention_v_dim + self.attention_qk_dim = config.attention_qk_dim + self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim + + if self.attention_q_lora_dim is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, config.attention_q_lora_dim, bias=False) + self.q_a_layernorm = OpenPanguMoERMSNorm(config.attention_q_lora_dim) + self.q_b_proj = nn.Linear( + config.attention_q_lora_dim, + self.num_heads * self.q_head_dim, + bias=False, + ) + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.attention_kv_lora_dim + config.attention_qk_rope_dim, + bias=False, + ) + self.kv_a_layernorm = OpenPanguMoERMSNorm(config.attention_kv_lora_dim) + self.kv_b_proj = nn.Linear( + config.attention_kv_lora_dim, + self.num_heads * (config.attention_qk_dim + self.attention_v_dim), + bias=False, + ) + self.o_proj = nn.Linear( + self.num_heads * self.attention_v_dim, + self.hidden_size, + bias=False, + ) + self.rotary_emb = OpenPanguMoERotaryEmbedding( + self.attention_qk_rope_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + self.softmax_scale = self.q_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len, _ = hidden_states.size() + + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1) + + latent_kv = self.kv_a_proj_with_mqa(hidden_states) + kv_a, k_pe = torch.split(latent_kv, [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(kv_a)) + .view(bsz, q_len, self.num_heads, self.attention_qk_dim + self.attention_v_dim) + .transpose(1, 2) + ) + kv_seq_len = kv.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(kv, kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + k_nope, value = torch.split(kv, [self.attention_qk_dim, self.attention_v_dim], dim=-1) + + def concat_nope_pe(nope, pe): + states = torch.empty( + [bsz, self.num_heads, q_len, self.q_head_dim], + dtype=nope.dtype, + device=nope.device, + ) + states[:, :, :, : self.attention_qk_dim] = nope + states[:, :, :, self.attention_qk_dim :] = pe + return states + + query = concat_nope_pe(q_nope, q_pe) + key = concat_nope_pe(k_nope, k_pe) + + if past_key_value is not None: + key, value = past_key_value.update(key, value, self.layer_idx, {"sin": sin, "cos": cos}) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.softmax_scale + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + + +class OpenPanguMoEDecoderLayer(nn.Module): + def __init__(self, config: OpenPanguMoEConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = OpenPanguMoEAttention(config=config, layer_idx=layer_idx) + + self.mlp = ( + OpenPanguMoE(config) + if (config.num_routed_experts is not None and layer_idx >= config.num_dense_layers) + else OpenPanguMoEMLP(config) + ) + self.input_layernorm = OpenPanguMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenPanguMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if getattr(config, "sandwich_norm", False): + self.sandwich_norm = True + self.pre_mlp_layernorm = OpenPanguMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = OpenPanguMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.sandwich_norm = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + if self.sandwich_norm: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + if self.sandwich_norm: + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value) + + +@auto_docstring +class OpenPanguMoEPreTrainedModel(PreTrainedModel): + config: OpenPanguMoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OpenPanguMoEDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": OpenPanguMoEDecoderLayer, + "attentions": OpenPanguMoEAttention, + } + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + self._initialize_linear(module, std) + self._initialize_embedding(module, std) + + def _initialize_linear(self, module, std): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def _initialize_embedding(self, module, std): + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class OpenPanguMoEModel(OpenPanguMoEPreTrainedModel): + def __init__(self, config: OpenPanguMoEConfig): + super().__init__(config) + + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.padding_idx = config.pad_token_id + self.layer_num = config.num_hidden_layers + self.epsilon = config.rms_norm_eps + + self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([OpenPanguMoEDecoderLayer(config, idx) for idx in range(self.layer_num)]) + self.norm = OpenPanguMoERMSNorm(self.hidden_size, eps=self.epsilon) + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPast]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You have to specify input_ids or inputs_embeds.") + + if input_ids is not None: + hidden_states = self.embed_tokens(input_ids) + batch_size, seq_length = input_ids.size() + else: + hidden_states = inputs_embeds + batch_size, seq_length = inputs_embeds.size() + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).unsqueeze(0) + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + position_ids += past_key_values_length + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + for decoder_layer in self.layers: + hidden_states, present_key_value = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + ) + + hidden_states = self.norm(hidden_states) + + if use_cache and use_legacy_cache: + present_key_value = present_key_value.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=present_key_value, + ) + + +@auto_docstring +class OpenPanguMoEForCausalLM(OpenPanguMoEPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = OpenPanguMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, OpenPanguMoEForCausalLM + + >>> model = OpenPanguMoEForCausalLM.from_pretrained("meta-open_pangu_mo_e/OpenPanguMoE-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-open_pangu_mo_e/OpenPanguMoE-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["OpenPanguMoEForCausalLM", "OpenPanguMoEModel", "OpenPanguMoEPreTrainedModel"] diff --git a/src/transformers/models/openpangu_moe/modular_openpangu_moe.py b/src/transformers/models/openpangu_moe/modular_openpangu_moe.py new file mode 100644 index 000000000000..f65c08a2a5b9 --- /dev/null +++ b/src/transformers/models/openpangu_moe/modular_openpangu_moe.py @@ -0,0 +1,559 @@ +# coding=utf-8 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from typing import Optional, Tuple, List, Union + +from ...cache_utils import Cache, DynamicCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_outputs import BaseModelOutputWithPast +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...processing_utils import Unpack + + +from ..llama.modeling_llama import ( + LlamaForCausalLM, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + rotate_half, +) + +from .configuration_openpangu_moe import OpenPanguMoEConfig + +logger = logging.get_logger(__name__) + +class OpenPanguMoERMSNorm(LlamaRMSNorm): + pass + +class OpenPanguMoERotaryEmbedding(nn.Module): + def __init__( + self, dim, max_position_embeddings=131072, base=25600000.0, device=None + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self._set_cache( + seq_len=max_position_embeddings, + device=device, + dtype=torch.get_default_dtype(), + ) + + def _set_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, kv_len, max_seq_len=None): + if max_seq_len is None: + self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype) + elif max_seq_len > self.max_seq_len_cached: + self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype) + + batch_size = x.shape[0] + seq_len = x.shape[1] + if seq_len == 1: + cos = ( + torch.index_select(self.cos_cached, dim=0, index=kv_len) + .unsqueeze(1) + .unsqueeze(1) + ) + sin = ( + torch.index_select(self.sin_cached, dim=0, index=kv_len) + .unsqueeze(1) + .unsqueeze(1) + ) + else: + cos = ( + self.cos_cached[:seq_len] + .unsqueeze(0) + .unsqueeze(2) + .repeat(batch_size, 1, 1, 1) + ) + sin = ( + self.sin_cached[:seq_len] + .unsqueeze(0) + .unsqueeze(2) + .repeat(batch_size, 1, 1, 1) + ) + + cos = cos[0, :, 0, :] + sin = sin[0, :, 0, :] + return ( + cos.to(dtype=x.dtype), + sin.to(dtype=x.dtype), + ) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class OpenPanguMoEMLP(LlamaMLP): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__(config) + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + +class OpenPanguMoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = config.routed_scaling_factor + + self.norm_topk_prob = config.norm_topk_prob + self.weight = nn.Parameter( + torch.empty((config.num_routed_experts, config.hidden_size)) + ) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.to(torch.float32), self.weight.to(torch.float32), None + ) + scores = logits.sigmoid() + scores_for_choice = scores.view(bsz * seq_len, -1) + _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor + + return topk_idx, topk_weight + +class OpenPanguMoE(nn.Module): + def __init__(self, config): + super().__init__() + self.num_shared_experts = config.num_shared_experts + self.num_routed_experts = config.num_routed_experts + self.experts = nn.ModuleList( + [ + OpenPanguMoEMLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(self.num_routed_experts) + ] + ) + self.gate = OpenPanguMoEGate(config) + if self.num_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * self.num_shared_experts + self.shared_experts = OpenPanguMoEMLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + if self.num_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + input_shape = hidden_states.shape + topk_ids, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + counts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + counts.scatter_(1, topk_ids, 1) + tokens_per_expert = counts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = hidden_states[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + output_hidden_states = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + output_hidden_states.append(expert_out) + start_idx = end_idx + + if len(output_hidden_states) > 0: + cat_hidden_states = torch.cat(output_hidden_states, dim=0) + else: + cat_hidden_states = sorted_tokens.new_empty(0) + + final_hidden_states = torch.empty_like(cat_hidden_states) + final_hidden_states[idxs] = cat_hidden_states + final_out = final_hidden_states.view(*topk_ids.shape, -1).to(topk_weight.dtype) + final_out = ( + final_out.mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .to(final_hidden_states.dtype) + ).view(*input_shape) + if self.num_shared_experts is not None: + final_out = final_out + shared_output + return final_out + +class OpenPanguMoEAttention(nn.Module): + def __init__(self, config: OpenPanguMoEConfig, layer_idx: Optional[int] = None): + super().__init__() + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.attention_q_lora_dim = config.attention_q_lora_dim + self.attention_qk_rope_dim = config.attention_qk_rope_dim + self.attention_kv_lora_dim = config.attention_kv_lora_dim + self.attention_v_dim = config.attention_v_dim + self.attention_qk_dim = config.attention_qk_dim + self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim + + if self.attention_q_lora_dim is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.attention_q_lora_dim, bias=False + ) + self.q_a_layernorm = OpenPanguMoERMSNorm(config.attention_q_lora_dim) + self.q_b_proj = nn.Linear( + config.attention_q_lora_dim, + self.num_heads * self.q_head_dim, + bias=False, + ) + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.attention_kv_lora_dim + config.attention_qk_rope_dim, + bias=False, + ) + self.kv_a_layernorm = OpenPanguMoERMSNorm(config.attention_kv_lora_dim) + self.kv_b_proj = nn.Linear( + config.attention_kv_lora_dim, + self.num_heads * (config.attention_qk_dim + self.attention_v_dim), + bias=False, + ) + self.o_proj = nn.Linear( + self.num_heads * self.attention_v_dim, + self.hidden_size, + bias=False, + ) + self.rotary_emb = OpenPanguMoERotaryEmbedding( + self.attention_qk_rope_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + self.softmax_scale = self.q_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len, _ = hidden_states.size() + + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1 + ) + + latent_kv = self.kv_a_proj_with_mqa(hidden_states) + kv_a, k_pe = torch.split( + latent_kv, [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(kv_a)) + .view( + bsz, q_len, self.num_heads, self.attention_qk_dim + self.attention_v_dim + ) + .transpose(1, 2) + ) + kv_seq_len = kv.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(kv, kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + k_nope, value = torch.split( + kv, [self.attention_qk_dim, self.attention_v_dim], dim=-1 + ) + + def concat_nope_pe(nope, pe): + states = torch.empty( + [bsz, self.num_heads, q_len, self.q_head_dim], + dtype=nope.dtype, + device=nope.device, + ) + states[:, :, :, : self.attention_qk_dim] = nope + states[:, :, :, self.attention_qk_dim :] = pe + return states + + query = concat_nope_pe(q_nope, q_pe) + key = concat_nope_pe(k_nope, k_pe) + + if past_key_value is not None: + key, value = past_key_value.update( + key, value, self.layer_idx, {"sin": sin, "cos": cos} + ) + + attn_weights = ( + torch.matmul(query, key.transpose(2, 3)) * self.softmax_scale + + attention_mask + ) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + +class OpenPanguMoEDecoderLayer(nn.Module): + def __init__(self, config: OpenPanguMoEConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = OpenPanguMoEAttention(config=config, layer_idx=layer_idx) + + self.mlp = ( + OpenPanguMoE(config) + if ( + config.num_routed_experts is not None + and layer_idx >= config.num_dense_layers + ) + else OpenPanguMoEMLP(config) + ) + self.input_layernorm = OpenPanguMoERMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = OpenPanguMoERMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + if getattr(config, "sandwich_norm", False): + self.sandwich_norm = True + self.pre_mlp_layernorm = OpenPanguMoERMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = OpenPanguMoERMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.sandwich_norm = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + if self.sandwich_norm: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + if self.sandwich_norm: + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return (hidden_states, present_key_value) + +class OpenPanguMoEPreTrainedModel(LlamaPreTrainedModel): + _supports_cache_class = True + _can_compile_fullgraph = False + + def _init_weights(self, module): + std = self.config.initializer_range + self._initialize_linear(module, std) + self._initialize_embedding(module, std) + + def _initialize_linear(self, module, std): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def _initialize_embedding(self, module, std): + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class OpenPanguMoEModel(OpenPanguMoEPreTrainedModel): + def __init__(self, config: OpenPanguMoEConfig): + super().__init__(config) + + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.padding_idx = config.pad_token_id + self.layer_num = config.num_hidden_layers + self.epsilon = config.rms_norm_eps + + self.embed_tokens = nn.Embedding( + self.vocab_size, self.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [OpenPanguMoEDecoderLayer(config, idx) for idx in range(self.layer_num)] + ) + self.norm = OpenPanguMoERMSNorm(self.hidden_size, eps=self.epsilon) + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You have to specify input_ids or inputs_embeds.") + + if input_ids is not None: + hidden_states = self.embed_tokens(input_ids) + batch_size, seq_length = input_ids.size() + else: + hidden_states = inputs_embeds + batch_size, seq_length = inputs_embeds.size() + + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).unsqueeze(0) + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + position_ids += past_key_values_length + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + for decoder_layer in self.layers: + hidden_states, present_key_value = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + ) + + hidden_states = self.norm(hidden_states) + + if use_cache and use_legacy_cache: + present_key_value = present_key_value.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=present_key_value, + ) + +class OpenPanguMoEForCausalLM(LlamaForCausalLM): + pass + + +__all__ = [ + "OpenPanguMoEForCausalLM", + "OpenPanguMoEModel", + "OpenPanguMoEPreTrainedModel", +] \ No newline at end of file diff --git a/tests/models/openpangu_moe/__init__.py b/tests/models/openpangu_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/openpangu_moe/test_modeling_openpangu_moe.py b/tests/models/openpangu_moe/test_modeling_openpangu_moe.py new file mode 100644 index 000000000000..7f240025cd87 --- /dev/null +++ b/tests/models/openpangu_moe/test_modeling_openpangu_moe.py @@ -0,0 +1,577 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch LLaMA model.""" + +import unittest + +import pytest +from packaging import version + +from transformers import AutoTokenizer, StaticCache, is_torch_available +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + Expectations, + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + run_test_using_subprocess, + slow, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + import torch + + from transformers import ( + OpenpanguMoeForCausalLM, + OpenpanguMoeModel, + OpenpanguMoeTokenizer, + ) + + +class OpenpanguMoeModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = OpenpanguMoeModel + + +@require_torch +class OpenpanguMoeModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = OpenpanguMoeModelTester + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = OpenpanguMoeForCausalLM if is_torch_available() else None + + +@require_torch_accelerator +@require_read_token +class OpenpanguMoeIntegrationTest(unittest.TestCase): + def setup(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves + # some memory allocated in the cache, which means some object is not being released properly. This causes some + # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. + # Investigate the root cause. + cleanup(torch_device, gc_collect=True) + + @slow + def test_llama_3_1_hard(self): + """ + An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences + from llama 3.1.'s RoPE can be detected + """ + expected_texts = Expectations( + { + ("rocm", (9, 5)): 'Tell me about the french revolution. The french revolution was a period of radical social and political upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. This marked the beginning of the end of the absolute monarchy and the rise of the middle class.\n', + ("cuda", None): 'Tell me about the french revolution. The french revolution was a period of radical political and social upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative assembly that had not met since 1614. The Third Estate, which represented the common people, demanded greater representation and eventually broke away to form the National Assembly. The National Assembly adopted the Declaration of the Rights of Man and of the Citizen, which enshr', + } + ) # fmt: skip + EXPECTED_TEXT = expected_texts.get_expectation() + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-OpenpanguMoe-3.1-8B-Instruct") + model = OpenpanguMoeForCausalLM.from_pretrained( + "meta-llama/Meta-OpenpanguMoe-3.1-8B-Instruct", device_map="auto", dtype=torch.bfloat16 + ) + input_text = ["Tell me about the french revolution."] + model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(generated_text, EXPECTED_TEXT) + + @slow + def test_model_7b_logits_bf16(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = OpenpanguMoeForCausalLM.from_pretrained( + "meta-llama/OpenpanguMoe-2-7b-hf", device_map="auto", dtype=torch.bfloat16, attn_implementation="eager" + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + # Expected mean on dim = -1 + + # fmt: off + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), + ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("rocm", (9, 4)): torch.tensor([[-6.5094, -4.1329, -4.9754, -3.5042, 0.8082, -2.9443, 1.2830, -3.3539]]), + }) + + expected_mean = expected_means.get_expectation().to(torch_device) + actual_mean = out.logits.float().mean(-1) + self.assertTrue( + torch.allclose( + expected_mean, + actual_mean, + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), + ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("rocm", (9, 4)): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9375, -6.0312, -7.0312, -1.8594, 1.8438, -8.5000]]) + }) + # fmt: on + expected_slice = expected_slices.get_expectation().to(torch_device) + actual_slice = out.logits[0, 0, :15].float() + self.assertTrue(torch.allclose(expected_slice, actual_slice, atol=1e-2, rtol=1e-2)) + + @slow + def test_model_7b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = OpenpanguMoeForCausalLM.from_pretrained( + "meta-llama/OpenpanguMoe-2-7b-hf", device_map="auto", dtype=torch.float16 + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + + # fmt: off + # Expected mean on dim = -1 + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), + ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), + ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), + ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) + }) + # fmt: on + + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + # TODO joao, manuel: remove this in v4.62.0 + # TODO: check why we have the following strange situation. + # without running in subprocess, this test causes subsequent tests failing with `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!` + @run_test_using_subprocess + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = OpenpanguMoeTokenizer.from_pretrained("meta-llama/OpenpanguMoe-2-7b-chat-hf") + model = OpenpanguMoeForCausalLM.from_pretrained( + "meta-llama/OpenpanguMoe-2-7b-chat-hf", device_map="sequential", dtype=torch.float16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, + max_new_tokens=64, + top_p=None, + temperature=1, + do_sample=False, + dola_layers="low", + trust_remote_code=True, + custom_generate="transformers-community/dola", + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_torch_accelerator + @pytest.mark.torch_compile_test + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = OpenpanguMoeTokenizer.from_pretrained( + "meta-llama/OpenpanguMoe-2-7b-hf", pad_token="", padding_side="right" + ) + model = OpenpanguMoeForCausalLM.from_pretrained( + "meta-llama/OpenpanguMoe-2-7b-hf", device_map=torch_device, dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + @slow + @pytest.mark.torch_export_test + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + llama_models = { + "meta-llama/OpenpanguMoe-3.2-1B": [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all " + "observers, regardless of their location, and 2) the laws of physics are the same for all observers" + ], + } + + for llama_model_ckp, EXPECTED_TEXT_COMPLETION in llama_models.items(): + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(llama_model_ckp, pad_token="", padding_side="right") + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" # TODO (joao / export experts): should be on `torch_device`, but causes GPU OOM + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = OpenpanguMoeForCausalLM.from_pretrained( + llama_model_ckp, + device_map=device, + dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": device, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM + + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), + ) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + model_name = "TinyOpenpanguMoe/TinyOpenpanguMoe-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = OpenpanguMoeTokenizer.from_pretrained(model_name) + self.model = OpenpanguMoeForCausalLM.from_pretrained(model_name, dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, + attention_mask=mask_1b, + position_ids=position_ids_1b, + past_key_values=past_key_values_a, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + def test_stacked_causal_mask_static_cache(self): + """same as above but with StaticCache""" + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len) + + padded_attention_mask = torch.nn.functional.pad( + input=mask_shared_prefix, + pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, + attention_mask=padded_attention_mask, + position_ids=position_ids_shared_prefix, + cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), + past_key_values=past_key_values, + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask_static_cache(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len) + + # forward run for the first part of input + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + padded_mask_1a = torch.nn.functional.pad( + input=mask_1a, + pad=(0, max_cache_len - mask_1a.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + _ = self.model.forward( + input_1a, + attention_mask=padded_mask_1a, + position_ids=position_ids_1a, + cache_position=torch.arange(part_a, device=torch_device), + past_key_values=past_key_values, + ) + + # forward run for the second part of input + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + + padded_mask_1b = torch.nn.functional.pad( + input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 + ) + + outs_1b = self.model.forward( + input_1b, + attention_mask=padded_mask_1b, + position_ids=position_ids_1b, + cache_position=torch.arange( + part_a, + input_ids_shared_prefix.shape[-1], + device=torch_device, + ), + past_key_values=past_key_values, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b)