Commits (47)
ac5b6d5
add new examples
yiliu30 Nov 6, 2025
2a2e834
fix mxfp4 moe for qwen
yiliu30 Nov 6, 2025
f0f0e1d
add mxfp8
yiliu30 Oct 31, 2025
ed23ef7
rename test
yiliu30 Oct 31, 2025
edba1ee
fix
yiliu30 Oct 31, 2025
d3d13b8
clean
yiliu30 Oct 31, 2025
dd909e4
fix linear
yiliu30 Nov 6, 2025
d2ed6a7
fix gate_up proj match
yiliu30 Nov 6, 2025
790a720
add mxfp4
yiliu30 Nov 6, 2025
2ad5558
add recipes
yiliu30 Nov 10, 2025
553529a
fix qwen mxfp4
yiliu30 Nov 10, 2025
284c41e
add mxfp4 moe
yiliu30 Nov 10, 2025
936ec4e
fix skip layers
yiliu30 Nov 10, 2025
218f564
update example
yiliu30 Nov 10, 2025
347d680
clean code
yiliu30 Nov 10, 2025
7eb9974
add mxfp4-mxfp8-moe
yiliu30 Nov 11, 2025
493f2df
fix moe mxfp8
yiliu30 Nov 12, 2025
fbc04ae
fix
yiliu30 Nov 12, 2025
bb4d90c
fix
yiliu30 Nov 12, 2025
edd3e9e
add readme
yiliu30 Nov 12, 2025
5f799b8
fix
yiliu30 Nov 12, 2025
84f3dbe
update
yiliu30 Nov 12, 2025
c9dbac0
update
yiliu30 Nov 12, 2025
3d21a74
add gene
yiliu30 Nov 12, 2025
01665c9
update
yiliu30 Nov 12, 2025
ebe9d79
update
yiliu30 Nov 12, 2025
7b986e3
update
yiliu30 Nov 12, 2025
efd3b1d
fix
yiliu30 Nov 12, 2025
e5044b4
format
yiliu30 Nov 12, 2025
68424c5
update
yiliu30 Nov 12, 2025
55a4e52
update example
yiliu30 Nov 12, 2025
59ff18d
correct mxfp8 usage
yiliu30 Nov 13, 2025
b8961e1
update example
yiliu30 Nov 14, 2025
3d89bb3
clean
yiliu30 Nov 14, 2025
6acd7ea
add eval cmd
yiliu30 Nov 14, 2025
6b720f0
remove examples
yiliu30 Nov 14, 2025
f880a1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2025
ed856fc
fix mxfp4
yiliu30 Nov 17, 2025
70ce2d0
add readme
yiliu30 Nov 17, 2025
919e954
fix
yiliu30 Nov 17, 2025
1cec73a
fix
yiliu30 Nov 17, 2025
3af54cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
121211a
Merge branch 'main' into more-ar-ext
yiliu30 Nov 17, 2025
d9f1d79
Merge branch 'main' into more-ar-ext
yiliu30 Nov 18, 2025
a29fd0a
add moe mxfp8
yiliu30 Nov 20, 2025
ecedb0c
Merge branch 'more-ar-ext' of https://github.com/intel/auto-round int…
yiliu30 Nov 20, 2025
2d37e63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
16 changes: 16 additions & 0 deletions auto_round_extension/vllm_ext/README.md
@@ -0,0 +1,16 @@
- Build and install vLLM from the `fused-moe-ar` fork branch: https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar

```bash
VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv
```
- Apply the vLLM-Ext patches (so that Python can discover them)
```bash
cd auto-round/auto_round_extension/vllm_ext
source apply_ext.sh
```

- Enable vLLM-Ext at Runtime
```bash
VLLM_ENABLE_AR_EXT=1 vllm serve ...
```
3 changes: 2 additions & 1 deletion auto_round_extension/vllm_ext/auto_round_ext.py
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig

from auto_round.schemes import QuantizationScheme
from auto_round_extension.vllm_ext.quant_method_linear import AutoRoundQuantLinearMethod
from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod

logger = init_logger(__name__)
@@ -36,7 +37,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str):
quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
return quant_method
elif isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return AutoRoundQuantLinearMethod.get_method(self, layer, prefix)
else:
return None

2 changes: 2 additions & 0 deletions auto_round_extension/vllm_ext/envs_ext.py
@@ -22,8 +22,10 @@
# Define extra environment variables
extra_environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"),
"VLLM_MXFP4_PRE_UNPACK_TO_FP8": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "0") in ("1", "true", "True"),
"VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"),
"VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"),
"VLLM_AR_POST_PROCESS_GPTOSS": lambda: os.getenv("VLLM_AR_POST_PROCESS_GPTOSS", "0") in ("1", "true", "True"),
}
# Add the extra environment variables to vllm.envs
import vllm.envs as envs
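All of these flags follow the same pattern: a lambda that reads an environment variable and treats `"1"`, `"true"`, or `"True"` as enabled, anything else as disabled. A minimal sketch of that pattern (the `read_bool_flag` helper below is illustrative, not part of the extension):

```python
import os

# Illustrative helper mirroring the lambdas registered above (not part of the extension).
def read_bool_flag(name: str, default: str = "0") -> bool:
    return os.getenv(name, default) in ("1", "true", "True")

os.environ["VLLM_MXFP4_PRE_UNPACK_TO_FP8"] = "true"
print(read_bool_flag("VLLM_MXFP4_PRE_UNPACK_TO_FP8"))  # True
print(read_bool_flag("VLLM_AR_MXFP4_MODULAR_MOE"))     # False when unset (defaults to "0")
```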
2 changes: 1 addition & 1 deletion auto_round_extension/vllm_ext/fp4_utils.py
@@ -51,7 +51,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}"
# assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}"

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)
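For context, `pack_fp4_to_uint8` stores two 4-bit FP4 codes per byte, so the flattened index tensor must have an even number of elements; the assert above (whose condition was inverted relative to its message) is commented out by this change. A hedged sketch of the packing round trip, assuming a low-nibble-first layout that may not match the exact bit order used in `fp4_utils.py`:

```python
import torch

def pack_nibbles(indices: torch.Tensor) -> torch.Tensor:
    # Two 4-bit codes per uint8: first code in the low nibble (assumed order).
    flat = indices.reshape(-1)
    assert flat.numel() % 2 == 0, "expected an even number of 4-bit codes"
    pairs = flat.reshape(-1, 2).to(torch.uint8)
    return pairs[:, 0] | (pairs[:, 1] << 4)

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=-1).reshape(-1)

codes = torch.randint(0, 16, (64,), dtype=torch.uint8)
assert torch.equal(unpack_nibbles(pack_nibbles(codes)), codes)
```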
134 changes: 134 additions & 0 deletions auto_round_extension/vllm_ext/linear_impl_mxfp4.py
@@ -0,0 +1,134 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional

import torch
import vllm.envs as envs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.parameter import GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter
from vllm.platforms import current_platform

from auto_round_extension.vllm_ext.mxfp4_qdq_utils import (
dequant_mxfp4_to_fp8,
mxfp4_gemm_with_unpacked_weight,
run_mxfp4_emulations,
)

logger = init_logger(__name__)

__all__ = ["AutoRoundMXFP4LinearImpl"]

from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl


class AutoRoundMXFP4LinearImpl(AutoRoundQuantImpl):
def __init__(self, quant_scheme):
self.quant_scheme = quant_scheme
self.group_size = 32

@classmethod
def get_min_capability(cls) -> int:
if envs.VLLM_USE_MXFP4_CT_EMULATIONS:
return 80
return 100

def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

# Weight
weight = ModelWeightParameter(
data=torch.empty(sum(output_partition_sizes), input_size_per_partition // 2, dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_packed", weight)

# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)

layer.register_parameter("weight_scale", weight_scale)

def process_weights_after_loading(self, layer) -> None:
# FIXME: may dequant to bf16
if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
from auto_round_extension.vllm_ext.mxfp4_qdq_utils import (
dequant_mxfp4_to_fp8,
mxfp4_gemm_with_unpacked_weight,
run_mxfp4_emulations,
)

weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8(
data_lp=layer.weight_packed,
scale_e8m0=layer.weight_scale,
)
del layer.weight_packed
del layer.weight_scale
layer.weight_packed = None
layer.weight_scale = None
layer.register_parameter(
"weight_unpacked_fp8",
torch.nn.Parameter(
weight_fp8,
requires_grad=False,
),
)
layer.register_parameter(
"weight_scale_bf16",
torch.nn.Parameter(
scale_bf16,
requires_grad=False,
),
)

def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale)
if bias is not None:
out = out + bias
return out
else:
out = mxfp4_gemm_with_unpacked_weight(
x=x,
weight_fp8=layer.weight_unpacked_fp8,
weight_scale_bf16=layer.weight_scale_bf16,
bias=bias,
)
return out
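`apply_weights` picks between two paths: without pre-unpacking it runs an emulated GEMM directly on the packed MXFP4 weights, and with `VLLM_MXFP4_PRE_UNPACK_WEIGHTS` it multiplies against the FP8 weights produced in `process_weights_after_loading`. As background, here is a hedged reference dequantizer for MXFP4 data (E2M1 codes with one shared E8M0 scale per group of 32); it is illustrative only, not the extension's `run_mxfp4_emulations` / `dequant_mxfp4_to_fp8` implementation:

```python
import torch

# E2M1 code table: sign bit plus magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
_E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

def dequant_mxfp4_reference(codes: torch.Tensor, scale_e8m0: torch.Tensor,
                            group_size: int = 32) -> torch.Tensor:
    """codes: (out, in) uint8 holding already-unpacked 4-bit indices;
    scale_e8m0: (out, in // group_size) uint8 biased exponents."""
    values = _E2M1_VALUES[codes.long()]
    # An E8M0 scale is a pure power of two: 2 ** (exponent - 127).
    scales = torch.pow(2.0, scale_e8m0.to(torch.float32) - 127.0)
    return values * scales.repeat_interleave(group_size, dim=-1)
```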
112 changes: 112 additions & 0 deletions auto_round_extension/vllm_ext/linear_impl_mxfp8.py
@@ -0,0 +1,112 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
import vllm.envs as envs
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)

from auto_round_extension.vllm_ext.mxfp8_qdq_utils import dequant_mx_fp8, quant_mx_fp8
from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl


class AutoRoundMXFP8LinearImpl(AutoRoundQuantImpl):
def __init__(self, quant_scheme):
self.quant_scheme = quant_scheme
self.strategy = "TENSOR_GROUP"
self.out_dtype = torch.get_default_dtype()
self.group_size = 32

@classmethod
def get_min_capability(cls) -> int:
return 80

def process_weights_after_loading(self, layer) -> None:
return

def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
# maybe_create_device_identity()

output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes

# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)

# WEIGHT SCALE
# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.uint8, # E8M0 for MXFP8 scale
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# dequant weight
weight = layer.weight
weight_scale = layer.weight_scale
dequant_weight = dequant_mx_fp8(
Copilot AI (Nov 17, 2025): Corrected spelling of 'dequnat_weight' to 'dequant_weight'.
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size,
target_dtype=x.dtype,
)
dequant_weight = dequant_weight.to(x.dtype)
# if not envs.VLLM_AR_MXFP8_DISABLE_INPUT_QDQ:
# q-dq input
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size,
target_dtype=x.dtype,
)
x = dequant_x.to(x.dtype)

out = x @ dequant_weight.t()
return out.to(x.dtype) + (bias if bias is not None else 0)
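This `apply_weights` is a pure emulation path: the FP8 weight is dequantized with its E8M0 group scales, the activation goes through a quantize-dequantize round trip, and the matmul then runs in the activation dtype. Below is a hedged sketch of such an MXFP8 q-dq round trip (groups of 32, one power-of-two shared scale per group); the scale derivation and clamping are assumptions, and `quant_mx_fp8` / `dequant_mx_fp8` in `mxfp8_qdq_utils` may differ in rounding details:

```python
import torch

def mxfp8_qdq_reference(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Quantize-dequantize emulation: each block of 32 values shares one E8M0
    # (power-of-two) scale and is stored as float8_e4m3fn.
    orig_shape = x.shape
    blocks = x.reshape(-1, group_size).to(torch.float32)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # Assumed scale rule: 2 ** (floor(log2(amax)) - 8), 8 being the E4M3 max exponent.
    scale = torch.pow(2.0, torch.floor(torch.log2(amax)) - 8.0)
    q = (blocks / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)  # quantize
    dq = q.to(torch.float32) * scale                                   # dequantize
    return dq.reshape(orig_shape).to(x.dtype)

x = torch.randn(4, 64, dtype=torch.bfloat16)
print((x - mxfp8_qdq_reference(x)).abs().max())  # small quantization error
```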