# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

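"""Replace the fused GPT-OSS MoE expert block with sequential per-expert modules."""
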
import torch
from torch import nn
from transformers.modeling_utils import no_init_weights as skip_weights_initialize
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP

__all__ = ["get_replacement_info"]


def _update_parameter(
    module: torch.nn.Module,
    name: str,
    data: torch.Tensor,
) -> None:
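    # Copy into the existing parameter in place so the Parameter object
    # (and its device placement) is preserved rather than re-created.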
    param = getattr(module, name)
    param.data.copy_(data)


class GPTOssSingleExpert(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.alpha = 1.702
        self.limit = 7.0
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
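        # Clamped gated activation used by the GPT-OSS experts: the gate is clamped
        # from above, the up projection symmetrically, and the gate is passed through
        # a sigmoid scaled by alpha before multiplying the shifted up branch.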
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        act = (up + 1) * glu
        return self.down_proj(act)


class SequentialGPTOSSMoE(nn.Module):
    """
    Replaces GPT-OSS fused-expert MoE with per-expert `GPTOssSingleExpert` modules.
    Copies weights from fused tensors and reuses the original router and optional shared_expert.
    """

    def __init__(self, config: GptOssConfig, original: GptOssMLP):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None)
        dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32
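        # Note: only bfloat16 is detected explicitly; any other configured dtype falls back to float32.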
        top_k = config.num_experts_per_tok
        self.hidden_size = hidden_size
        self.intermediate = intermediate_size
        self.top_k = top_k
        self.router = original.router
        self.shared_expert = getattr(original, "shared_expert", None)

        # Number of experts
        E = original.experts.gate_up_proj.shape[0]
        self.num_experts = E

        # Build per-expert MLPs
        self.experts = nn.ModuleList()
        target_device = next(original.experts.parameters()).device
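        # Weights are copied from the fused tensors below, so skip random initialization
        # and construct the per-expert modules directly on the source device.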
        with skip_weights_initialize(), torch.device(target_device):
            for _ in range(E):
                self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))

        gup = original.experts.gate_up_proj  # [E, H, 2I]
        gup_b = original.experts.gate_up_proj_bias  # [E, 2I]
        dwn = original.experts.down_proj  # [E, I, H]
        dwn_b = original.experts.down_proj_bias  # [E, H]

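        # The fused gate_up_proj interleaves gate and up columns: even columns (::2) feed
        # gate_proj, odd columns (1::2) feed up_proj. nn.Linear stores weights as [out, in],
        # hence the transposes below.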
        for i, mlp in enumerate(self.experts):
            _update_parameter(mlp.gate_proj, "weight", gup[i, :, ::2].T)
            _update_parameter(mlp.up_proj, "weight", gup[i, :, 1::2].T)
            _update_parameter(mlp.down_proj, "weight", dwn[i].T)

            _update_parameter(mlp.gate_proj, "bias", gup_b[i, ::2])
            _update_parameter(mlp.up_proj, "bias", gup_b[i, 1::2])
            _update_parameter(mlp.down_proj, "bias", dwn_b[i])  # [H]

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        B, T, H = hidden_states.shape
        x = hidden_states.reshape(-1, H)

        # Use the original router (it returns scores and indices already softmaxed over top-k)
        router_scores, router_indices = self.router(x)  # scores: [tokens, E], indices: [tokens, k]

        out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)

        # Accumulate expert outputs for chosen experts only
        for j in range(self.top_k):
            idx = router_indices[:, j]
            w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
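            # Group tokens by expert so each selected expert runs once over just its tokens.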
            unique_experts = torch.unique(idx)
            for e in unique_experts:
                mask = idx == e
                out[mask] += self.experts[e](x[mask]) * w[mask]

        out = out.view(B, T, H)
        router_scores = router_scores.view(B * T, -1)  # shape doesn't matter much; it's ignored by the decoder
        return out, router_scores


def get_replacement_info(config):
    return (
        SequentialGPTOSSMoE,
        config.get_text_config(),
        GptOssMLP.__name__,
    )
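
# Illustrative sketch only (the actual replacement driver lives outside this file):
# a caller could use the returned info to swap every fused MoE block, assuming a
# standard named_modules() walk over the loaded model:
#
#   replacement_cls, text_config, target_name = get_replacement_info(model.config)
#   targets = [n for n, m in model.named_modules() if m.__class__.__name__ == target_name]
#   for name in targets:
#       parent_name, _, child_name = name.rpartition(".")
#       parent = model.get_submodule(parent_name) if parent_name else model
#       setattr(parent, child_name, replacement_cls(text_config, model.get_submodule(name)))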