Commit 081c92a

Add GPT-OSS quant support (#887)

1 parent db560e0 commit 081c92a

File tree

6 files changed: +306 -63 lines changed

  auto_round/modelling/__init__.py
  auto_round/modelling/gpt_oss.py
  auto_round/modelling/llama4.py
  auto_round/special_model_handler.py
  auto_round/utils.py
  test/test_cpu/test_gpt_oss.py
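This commit routes GPT-OSS through the same expert-to-linear conversion path already used for Llama 4, so the fused MoE experts can be quantized as ordinary nn.Linear layers. For orientation only, the sketch below shows how the new support would typically be driven from user code; it is not part of the diff, it reuses the argument names exercised by the new CPU test, and the checkpoint path and output directory are placeholders.

# Hedged usage sketch, not part of this commit; mirrors the arguments used in
# test/test_cpu/test_gpt_oss.py. The checkpoint path and output_dir are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "openai/gpt-oss-20b"  # placeholder GPT-OSS checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16")
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    scheme="MXFP4",  # the new test also exercises "MXFP8"
    fp_layers="self_attn,router,lm_head,mlp.gate",  # keep attention, router and head un-quantized
)
autoround.quantize_and_save(format="auto_round", output_dir="./gpt-oss-20b-mxfp4")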

auto_round/modelling/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

auto_round/modelling/gpt_oss.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import nn
+from transformers.modeling_utils import no_init_weights as skip_weights_initialize
+from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
+
+__all__ = ["get_replacement_info"]
+
+
+def _update_parameter(
+    module: torch.nn.Module,
+    name: str,
+    data: torch.Tensor,
+) -> None:
+    param = getattr(module, name)
+    param.data.copy_(data)
+
+
+class GPTOssSingleExpert(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.alpha = 1.702
+        self.limit = 7.0
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        gate = gate.clamp(max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        act = (up + 1) * glu
+        return self.down_proj(act)
+
+
+class SequentialGPTOSSMoE(nn.Module):
+    """
+    Replaces GPT-OSS fused-expert MoE with per-expert `GPTOssSingleExpert` modules.
+    Copies weights from fused tensors and reuses the original router and optional shared_expert.
+    """
+
+    def __init__(self, config: GptOssConfig, original: GptOssMLP):
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None)
+        dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32
+        top_k = config.num_experts_per_tok
+        self.hidden_size = hidden_size
+        self.intermediate = intermediate_size
+        self.top_k = top_k
+        self.router = original.router
+        self.shared_expert = getattr(original, "shared_expert", None)
+
+        # Number of experts
+        E = original.experts.gate_up_proj.shape[0]
+        self.num_experts = E
+
+        # Build per-expert MLPs
+        self.experts = nn.ModuleList()
+        target_device = next(original.experts.parameters()).device
+        with skip_weights_initialize(), torch.device(target_device):
+            for _ in range(E):
+                self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))
+
+        gup = original.experts.gate_up_proj  # [E, H, 2I]
+        gup_b = original.experts.gate_up_proj_bias  # [E, 2I]
+        dwn = original.experts.down_proj  # [E, I, H]
+        dwn_b = original.experts.down_proj_bias  # [E, H]
+
+        for i, mlp in enumerate(self.experts):
+            _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
+            _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
+            _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)
+
+            _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
+            _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
+            _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        B, T, H = hidden_states.shape
+        x = hidden_states.reshape(-1, H)
+
+        # Use the original router (it returns scores and indices already softmaxed over top-k)
+        router_scores, router_indices = self.router(x)  # scores: [tokens, E], indices: [tokens, k]
+
+        out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
+
+        # Accumulate expert outputs for chosen experts only
+        for j in range(self.top_k):
+            idx = router_indices[:, j]
+            w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
+            unique_experts = torch.unique(idx)
+            for e in unique_experts:
+                mask = idx == e
+                out[mask] += self.experts[e](x[mask]) * w[mask]
+
+        out = out.view(B, T, H)
+        router_scores = router_scores.view(B * T, -1)  # shape doesn't matter much; it's ignored by the decoder
+        return out, router_scores
+
+
+def get_replacement_info(config):
+    return (
+        SequentialGPTOSSMoE,
+        config.get_text_config(),
+        GptOssMLP.__name__,
+    )
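A detail worth noting in the weight copy inside SequentialGPTOSSMoE.__init__: the fused gate_up_proj tensor stores gate and up columns interleaved, so the even columns ([..., ::2]) become gate_proj and the odd columns ([..., 1::2]) become up_proj, each transposed into nn.Linear's [out_features, in_features] layout. The snippet below is a small self-contained illustration of that slicing, not part of the commit, with made-up sizes.

# Illustration only (not in the commit): how the interleaved fused expert tensor
# splits into separate gate/up weights whose shapes match nn.Linear(H, I).weight.
import torch

H, I = 4, 3  # toy hidden and intermediate sizes
fused = torch.arange(H * 2 * I, dtype=torch.float32).reshape(H, 2 * I)  # stands in for gate_up_proj[e]

gate_w = fused[:, ::2].T   # even columns -> gate projection, shape [I, H]
up_w = fused[:, 1::2].T    # odd columns  -> up projection,   shape [I, H]

assert gate_w.shape == (I, H) and up_w.shape == (I, H)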

auto_round/modelling/llama4.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Note: adapted from # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
+
+__all__ = ["get_replacement_info"]
+
+
+import torch
+from transformers.modeling_utils import no_init_weights
+from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
+
+
+class SequentialLlama4TextExperts(torch.nn.ModuleList):
+    def __init__(self, config, original):
+        self.num_experts = original.gate_up_proj.shape[0]
+        with no_init_weights():
+            super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
+        intermediate_size = original.down_proj.shape[1]
+
+        for i in range(self.num_experts):
+            gate_up = original.gate_up_proj[i]
+            down = original.down_proj[i]
+            gate_proj = gate_up[:, :intermediate_size]
+            up_proj = gate_up[:, intermediate_size:]
+
+            self[i].gate_proj.weight.data = gate_proj.t().contiguous()
+            self[i].up_proj.weight.data = up_proj.t().contiguous()
+            self[i].down_proj.weight.data = down.t().contiguous()
+
+
+class SequentialLlama4TextMoe(torch.nn.Module):
+    def __init__(self, config, original):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.hidden_dim = config.hidden_size
+        self.num_experts = config.num_local_experts
+        self.experts = SequentialLlama4TextExperts(config, original.experts)
+        self.router = original.router
+        self.shared_expert = original.shared_expert
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = self.router(hidden_states)
+        if isinstance(router_logits, tuple):
+            router_scores, router_logits = router_logits
+            router_scores = router_scores.t()
+        else:
+            # transformers < 4.54.0 only returns router_logits
+            router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+            router_scores = (
+                torch.full_like(router_logits, float("-inf"))
+                .scatter_(1, router_indices, router_top_value)
+                .transpose(0, 1)
+            )
+            router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+        out = self.shared_expert(hidden_states)
+        for i in range(self.num_experts):
+            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+        return out, router_logits
+
+
+def get_replacement_info(config):
+    return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"
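Llama 4 packs its fused expert weights differently from GPT-OSS: gate and up sit in contiguous halves of gate_up_proj rather than interleaved columns, which is why this file slices with [:, :intermediate_size] and [:, intermediate_size:] while gpt_oss.py uses [..., ::2] and [..., 1::2]. A minimal sketch of the contiguous split, again illustrative only and not part of the commit:

# Sketch (not in the commit): the contiguous gate/up split used for Llama 4 experts.
import torch

H, I = 4, 3
fused = torch.randn(H, 2 * I)           # stands in for gate_up_proj[i]
gate_w = fused[:, :I].t().contiguous()   # first half  -> gate_proj.weight, [I, H]
up_w = fused[:, I:].t().contiguous()     # second half -> up_proj.weight,   [I, H]
assert gate_w.shape == (I, H) and up_w.shape == (I, H)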

auto_round/special_model_handler.py

Lines changed: 16 additions & 62 deletions
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from auto_round.utils import logger
+import auto_round.modelling as auto_round_modelling
+from auto_round.utils import LazyImport, logger

 mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama")  # Limitations on batch_size

@@ -36,71 +37,24 @@
 }
 SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",)

-CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4"]
+CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]


 def _get_moe_converter(config):
-    import torch
-    from transformers.modeling_utils import no_init_weights
-
-    # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
-    if config.model_type == "llama4":
-        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
-
-        class SequentialLlama4TextExperts(torch.nn.ModuleList):
-            def __init__(self, config, original):
-                self.num_experts = original.gate_up_proj.shape[0]
-                with no_init_weights():
-                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-                intermediate_size = original.down_proj.shape[1]
-
-                for i in range(self.num_experts):
-                    gate_up = original.gate_up_proj[i]
-                    down = original.down_proj[i]
-                    gate_proj = gate_up[:, :intermediate_size]
-                    up_proj = gate_up[:, intermediate_size:]
-
-                    self[i].gate_proj.weight.data = gate_proj.t().contiguous()
-                    self[i].up_proj.weight.data = up_proj.t().contiguous()
-                    self[i].down_proj.weight.data = down.t().contiguous()
-
-        class SequentialLlama4TextMoe(torch.nn.Module):
-            def __init__(self, config, original):
-                super().__init__()
-                self.top_k = config.num_experts_per_tok
-                self.hidden_dim = config.hidden_size
-                self.num_experts = config.num_local_experts
-                self.experts = SequentialLlama4TextExperts(config, original.experts)
-                self.router = original.router
-                self.shared_expert = original.shared_expert
-
-            def forward(self, hidden_states: torch.Tensor):
-                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-                router_logits = self.router(hidden_states)
-                if isinstance(router_logits, tuple):
-                    router_scores, router_logits = router_logits
-                    router_scores = router_scores.t()
-                else:
-                    # transformers < 4.54.0 only returns router_logits
-                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-
-                    router_scores = (
-                        torch.full_like(router_logits, float("-inf"))
-                        .scatter_(1, router_indices, router_top_value)
-                        .transpose(0, 1)
-                    )
-                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-
-                out = self.shared_expert(hidden_states)
-                for i in range(self.num_experts):
-                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
-
-                return out, router_logits
-
-        return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"
-
+    # Dispatch table for model_type to replacement_info functions
+    moe_converters = {
+        "gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"),
+        "llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"),
+    }
+
+    # Retrieve the appropriate function based on model_type
+    if config.model_type in moe_converters:
+        return moe_converters[config.model_type](config)
     else:
-        raise ValueError(f"Currently moe converter only supports llama4 model_type, but get {config.model_type}")
+        raise ValueError(
+            f"Unsupported model_type '{config.model_type}'. "
+            f"Currently, MoE converter only supports: {', '.join(moe_converters.keys())}."
+        )


 def _handle_special_model(model):
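_get_moe_converter now only dispatches on config.model_type and returns the (replacement class, text config, target class name) triple produced by the per-model get_replacement_info helpers; the code that consumes the triple is unchanged and not shown in this diff. As a rough sketch of how such a triple is typically applied (the function name below is hypothetical, not an AutoRound API):

# Hypothetical illustration, not AutoRound code: swap every module whose class name
# matches target_name for the sequential replacement built from the original module.
import torch

def apply_replacement_info(model: torch.nn.Module, replacement_cls, text_config, target_name: str):
    for parent in model.modules():
        for child_name, child in list(parent.named_children()):
            if child.__class__.__name__ == target_name:
                setattr(parent, child_name, replacement_cls(text_config, child))
    return model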

auto_round/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1100,7 +1100,7 @@ def get_fp_layer_names(model, fp_layers):
         for name in all_layer_names:
             if fp_layer in name:
                 not_to_quantized_layers.append(name)
-
+    logger.trace(f"not_to_quantized_layers: {not_to_quantized_layers}")
     return not_to_quantized_layers

test/test_cpu/test_gpt_oss.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import pytest
+from transformers import AutoConfig, AutoTokenizer
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
+
+from auto_round import AutoRound
+
+
[email protected]
+def setup_gpt_oss():
+    """Fixture to set up the GPT-OSS model and tokenizer."""
+    model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16"
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config.num_hidden_layers = 1  # Reduce layers for testing
+    model = GptOssForCausalLM(config)
+    output_dir = "/tmp/test_quantized_gpt_oss"
+    return model, tokenizer, output_dir, config
+
+
+def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
+    """Helper function to quantize the model with the given scheme."""
+    autoround = AutoRound(
+        model,
+        tokenizer,
+        scheme=scheme,
+        nsamples=2,
+        iters=iters,
+        fp_layers="self_attn,router,lm_head,mlp.gate",
+    )
+    quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
+    return quantized_model
+
+
+def count_modules_by_type(model, target_module_name_or_class):
+    """Helper function to count modules of a specific type in the model."""
+    cnt = 0
+    for name, module in model.named_modules():
+        if isinstance(target_module_name_or_class, str):
+            if target_module_name_or_class == module.__class__.__name__:
+                cnt += 1
+        else:
+            if isinstance(module, target_module_name_or_class):
+                cnt += 1
+    return cnt
+
+
[email protected]("scheme", ["MXFP4", "MXFP8"])
+def test_quantization(setup_gpt_oss, scheme):
+    """Test quantization with the scheme."""
+    model, tokenizer, output_dir, config = setup_gpt_oss
+    quantized_model = quantize_model(model, tokenizer, output_dir, scheme)
+
+    # Ensure the quantized model is not None
+    assert quantized_model is not None, "Quantized model should not be None."
+    from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
+    from auto_round.modelling.gpt_oss import GPTOssSingleExpert
+
+    single_expert_cnt = count_modules_by_type(quantized_model, GPTOssSingleExpert)
+    quant_linear_cnt = count_modules_by_type(quantized_model, QuantLinear)
+    assert (
+        single_expert_cnt == config.num_local_experts
+    ), f"Expected {config.num_local_experts} GPTOssSingleExpert modules, found {single_expert_cnt}."
+    assert (
+        quant_linear_cnt == config.num_hidden_layers * 3 * config.num_local_experts
+    ), f"Expected {config.num_hidden_layers * 3 * config.num_local_experts} QuantLinear modules, found {quant_linear_cnt}."
+
+    print(f"[{scheme}] Total {GPTOssSingleExpert.__name__} modules: {single_expert_cnt}")
+    print(f"[{scheme}] Total {QuantLinear.__name__} modules: {quant_linear_cnt}")
+    # clean the output directory after test
+    import shutil
+
+    shutil.rmtree(output_dir, ignore_errors=True)
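The second assertion expects three QuantLinear modules per expert, one each for gate_proj, up_proj and down_proj, in every retained decoder layer. With num_hidden_layers forced to 1 in the fixture and a gpt-oss-20b-style config reporting 32 local experts (an assumption about the checkpoint, not something stated in this diff), that works out as below; attention, router, lm_head and mlp.gate stay un-quantized via fp_layers.

# Assumed values for a gpt-oss-20b-style config; the test reads them from the actual config.
num_hidden_layers = 1    # forced in the fixture
num_local_experts = 32   # assumption for gpt-oss-20b
linears_per_expert = 3   # gate_proj, up_proj, down_proj
expected_quant_linear = num_hidden_layers * linears_per_expert * num_local_experts  # 96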
