-
Notifications
You must be signed in to change notification settings - Fork 60
Add MXFP8 MOE/Linear and MXFP4 Linear #1034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yiliu30
wants to merge
47
commits into
main
Choose a base branch
from
more-ar-ext
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,101
−67
Open
Changes from all commits
Commits
Show all changes
47 commits
Select commit
Hold shift + click to select a range
ac5b6d5
add new examples
yiliu30 2a2e834
fix mxfp4 moe for qwen
yiliu30 f0f0e1d
add mxfp8
yiliu30 ed23ef7
rename test
yiliu30 edba1ee
fix
yiliu30 d3d13b8
clean
yiliu30 dd909e4
fix linear
yiliu30 d2ed6a7
fix gate_up proj match
yiliu30 790a720
add mxfp4
yiliu30 2ad5558
add recipes
yiliu30 553529a
fix qwen mxfp4
yiliu30 284c41e
add mxfp4 moe
yiliu30 936ec4e
fix skip layers
yiliu30 218f564
update example
yiliu30 347d680
clean code
yiliu30 7eb9974
add mxfp4-mxfp8-moe
yiliu30 493f2df
fix moe mxfp8
yiliu30 fbc04ae
fix
yiliu30 bb4d90c
fix
yiliu30 edd3e9e
add readme
yiliu30 5f799b8
fix
yiliu30 84f3dbe
update
yiliu30 c9dbac0
update
yiliu30 3d21a74
add gene
yiliu30 01665c9
update
yiliu30 ebe9d79
update
yiliu30 7b986e3
update
yiliu30 efd3b1d
fix
yiliu30 e5044b4
format
yiliu30 68424c5
update
yiliu30 55a4e52
update example
yiliu30 59ff18d
correct mxfp8 usage
yiliu30 b8961e1
update example
yiliu30 3d89bb3
clean
yiliu30 6acd7ea
add eval cmd
yiliu30 6b720f0
remove examples
yiliu30 f880a1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ed856fc
fix mxfp4
yiliu30 70ce2d0
add readme
yiliu30 919e954
fix
yiliu30 1cec73a
fix
yiliu30 3af54cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 121211a
Merge branch 'main' into more-ar-ext
yiliu30 d9f1d79
Merge branch 'main' into more-ar-ext
yiliu30 a29fd0a
add moe mxfp8
yiliu30 ecedb0c
Merge branch 'more-ar-ext' of https://github.com/intel/auto-round int…
yiliu30 2d37e63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| - Build and Install vLLM | ||
|
|
||
| ``` | ||
| https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar | ||
| VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv | ||
| ``` | ||
| - Apply vLLM-Ext Patches(allow python recognize them) | ||
| ``` | ||
| cd auto-round/auto_round_extension/vllm_ext | ||
| source apply_ext.sh | ||
| ``` | ||
|
|
||
| - Enable vLLM-Ext at Runtime | ||
| ```bash | ||
| VLLM_ENABLE_AR_EXT=1 vllm serve ... | ||
| ``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # SPDX-License-Identifier: Apache-2.0 | ||
| from typing import Callable, Optional | ||
|
|
||
| import torch | ||
| import vllm.envs as envs | ||
| from torch.nn.parameter import Parameter | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.parameter import GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( | ||
| dequant_mxfp4_to_fp8, | ||
| mxfp4_gemm_with_unpacked_weight, | ||
| run_mxfp4_emulations, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| __all__ = ["AutoRoundMXFP4LinearImpl"] | ||
|
|
||
| from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl | ||
|
|
||
|
|
||
| class AutoRoundMXFP4LinearImpl(AutoRoundQuantImpl): | ||
| def __init__(self, quant_scheme): | ||
| self.quant_scheme = quant_scheme | ||
| self.group_size = 32 | ||
|
|
||
| @classmethod | ||
| def get_min_capability(cls) -> int: | ||
| if envs.VLLM_USE_MXFP4_CT_EMULATIONS: | ||
| return 80 | ||
| return 100 | ||
|
|
||
| def create_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| output_partition_sizes: list[int], | ||
| input_size_per_partition: int, | ||
| params_dtype: torch.dtype, | ||
| weight_loader: Callable, | ||
| **kwargs, | ||
| ): | ||
| output_size_per_partition = sum(output_partition_sizes) | ||
| layer.logical_widths = output_partition_sizes | ||
| layer.input_size_per_partition = input_size_per_partition | ||
| layer.output_size_per_partition = output_size_per_partition | ||
|
|
||
| # Weight | ||
| weight = ModelWeightParameter( | ||
| data=torch.empty(sum(output_partition_sizes), input_size_per_partition // 2, dtype=torch.uint8), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=weight_loader, | ||
| ) | ||
| layer.register_parameter("weight_packed", weight) | ||
|
|
||
| # Per Group Weight Scale | ||
| weight_scale = GroupQuantScaleParameter( | ||
| data=torch.empty( | ||
| sum(output_partition_sizes), | ||
| input_size_per_partition // self.group_size, | ||
| # dtype=torch.uint8, | ||
| dtype=torch.uint8, | ||
| ), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=weight_loader, | ||
| ) | ||
|
|
||
| layer.register_parameter("weight_scale", weight_scale) | ||
|
|
||
| def process_weights_after_loading(self, layer) -> None: | ||
| # FIXME: may dequant to bf16 | ||
| if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: | ||
| from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( | ||
| dequant_mxfp4_to_fp8, | ||
| mxfp4_gemm_with_unpacked_weight, | ||
| run_mxfp4_emulations, | ||
| ) | ||
|
|
||
| weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( | ||
| data_lp=layer.weight_packed, | ||
| scale_e8m0=layer.weight_scale, | ||
| ) | ||
| del layer.weight_packed | ||
| del layer.weight_scale | ||
| layer.weight_packed = None | ||
| layer.weight_scale = None | ||
| layer.register_parameter( | ||
| "weight_unpacked_fp8", | ||
| torch.nn.Parameter( | ||
| weight_fp8, | ||
| requires_grad=False, | ||
| ), | ||
| ) | ||
| layer.register_parameter( | ||
| "weight_scale_bf16", | ||
| torch.nn.Parameter( | ||
| scale_bf16, | ||
| requires_grad=False, | ||
| ), | ||
| ) | ||
|
|
||
| def apply_weights( | ||
| self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None | ||
| ) -> torch.Tensor: | ||
| if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: | ||
| out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale) | ||
| if bias is not None: | ||
| out = out + bias | ||
| return out | ||
| else: | ||
| out = mxfp4_gemm_with_unpacked_weight( | ||
| x=x, | ||
| weight_fp8=layer.weight_unpacked_fp8, | ||
| weight_scale_bf16=layer.weight_scale_bf16, | ||
| bias=bias, | ||
| ) | ||
| return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Callable, Optional | ||
|
|
||
| import torch | ||
| import vllm.envs as envs | ||
| from vllm.model_executor.parameter import ( | ||
| GroupQuantScaleParameter, | ||
| ModelWeightParameter, | ||
| PerTensorScaleParameter, | ||
| ) | ||
|
|
||
| from auto_round_extension.vllm_ext.mxfp8_qdq_utils import dequant_mx_fp8, quant_mx_fp8 | ||
| from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl | ||
|
|
||
|
|
||
| class AutoRoundMXFP8LinearImpl(AutoRoundQuantImpl): | ||
| def __init__(self, quant_scheme): | ||
| self.quant_scheme = quant_scheme | ||
| self.strategy = "TENSOR_GROUP" | ||
| self.out_dtype = torch.get_default_dtype() | ||
| self.group_size = 32 | ||
|
|
||
| @classmethod | ||
| def get_min_capability(cls) -> int: | ||
| return 80 | ||
|
|
||
| def process_weights_after_loading(self, layer) -> None: | ||
| return | ||
|
|
||
| def create_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| output_partition_sizes: list[int], | ||
| input_size_per_partition: int, | ||
| params_dtype: torch.dtype, | ||
| weight_loader: Callable, | ||
| **kwargs, | ||
| ): | ||
| # maybe_create_device_identity() | ||
|
|
||
| output_size_per_partition = sum(output_partition_sizes) | ||
| layer.logical_widths = output_partition_sizes | ||
|
|
||
| # WEIGHT | ||
| weight = ModelWeightParameter( | ||
| data=torch.empty( | ||
| output_size_per_partition, | ||
| input_size_per_partition, | ||
| dtype=torch.float8_e4m3fn, | ||
| ), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=weight_loader, | ||
| ) | ||
| layer.register_parameter("weight", weight) | ||
|
|
||
| # WEIGHT SCALE | ||
| # Per Group Weight Scale | ||
| weight_scale = GroupQuantScaleParameter( | ||
| data=torch.empty( | ||
| sum(output_partition_sizes), | ||
| input_size_per_partition // self.group_size, | ||
| dtype=torch.uint8, # E8M0 for MXFP8 scale | ||
| ), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=weight_loader, | ||
| ) | ||
| layer.register_parameter("weight_scale", weight_scale) | ||
|
|
||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| # dequant weight | ||
| weight = layer.weight | ||
| weight_scale = layer.weight_scale | ||
| dequnat_weight = dequant_mx_fp8( | ||
| weight_fp8=weight.data, | ||
| scale_e8m0=weight_scale.data, | ||
| block_size=self.group_size, | ||
| target_dtype=x.dtype, | ||
| ) | ||
| dequnat_weight = dequnat_weight.to(x.dtype) | ||
| # if not envs.VLLM_AR_MXFP8_DISABLE_INPUT_QDQ: | ||
| # q-dq input | ||
| x_scale, x_quant = quant_mx_fp8(x) | ||
| dequant_x = dequant_mx_fp8( | ||
| weight_fp8=x_quant, | ||
| scale_e8m0=x_scale, | ||
| block_size=self.group_size, | ||
| target_dtype=x.dtype, | ||
| ) | ||
| x = dequant_x.to(x.dtype) | ||
|
|
||
| out = x @ dequnat_weight.t() | ||
| return out.to(x.dtype) + (bias if bias is not None else 0) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected spelling of 'dequnat_weight' to 'dequant_weight'.