From fbefce5f7b7422a4045ab303a31323a81911e1d5 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Thu, 2 Oct 2025 07:15:27 +0000
Subject: [PATCH 1/4] Support PaddlePaddle with compatible API and tvm-ffi

---
 flashinfer/fp4_quantization.py |  9 +++++---
 flashinfer/fused_moe/core.py   | 41 +++++++++++++++++++++++++---------
 flashinfer/jit/core.py         |  4 ++++
 flashinfer/jit/cpp_ext.py      |  5 ++++-
 flashinfer/utils.py            | 21 +++++++++++------
 5 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py
index 29127f06ac..6f238382c4 100644
--- a/flashinfer/fp4_quantization.py
+++ b/flashinfer/fp4_quantization.py
@@ -180,7 +180,8 @@ def fp4_quantize_sm100(
         - Scale factors tensor with shape determined by layout and sf_vec_size
     """
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     out_val = torch.empty(
         (*input.shape[:-1], input.shape[-1] // 2),
         dtype=torch.uint8,
@@ -669,9 +670,11 @@ def fp4_quantize(
     assert input.shape[-1] % sf_vec_size == 0
 
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     # get input device sm version
-    major, minor = get_compute_capability(input.device)
+    # major, minor = get_compute_capability(input.device)
+    major, minor = get_compute_capability(input.place)
     x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
         input,
         global_scale,
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 2ce2a8b6d0..e7e8d7cda9 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -20,6 +20,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 
 from ..autotuner import (
     AutoTuner,
@@ -350,11 +354,15 @@ def __init__(
         )
         self.activation_type = activation_type
 
+        def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
+            dtype_str = str(dtype).split(".", 1)[-1]
+            return tvm_ffi.dtype(dtype_str)
+
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[instance_key] = module.init(
-                x_dtype,
-                weight_dtype,
-                output_dtype,
+                paddle_dtype_to_tvm_ffi_dtype(x_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(output_dtype),
                 use_deepseek_fp8_block_scale,
                 use_w4_group_scaling,
                 use_mxfp8_act_scaling,
@@ -454,7 +462,8 @@ def cutlass_fused_moe(
     activation_type: ActivationType = ActivationType.Swiglu,
 ) -> List[torch.Tensor]:
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
 
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
@@ -513,17 +522,22 @@ def cutlass_fused_moe(
         else moe_runner.fused_moe_runner.run_moe
     )
     num_active_experts_per_node = torch.empty(
-        (1,), dtype=torch.int32, device=input.device
+        # (1,), dtype=torch.int32, device=input.device
+        (1,),
+        dtype=torch.int32,
+        device=input.place,
     )
     experts_to_token_score = torch.empty(
         (fc2_expert_weights.shape[0], input.shape[0]),
         dtype=torch.float32,
-        device=input.device,
+        # device=input.device,
+        device=input.place,
     )
     active_expert_global_ids = torch.empty(
         (fc2_expert_weights.shape[0],),
         dtype=torch.int32,
-        device=input.device,
+        # device=input.device,
+        device=input.place,
     )
     min_latency_output = (
         [
@@ -799,7 +813,8 @@ def cutlass_fused_moe(
     )
 
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
 
     num_rows = input.shape[0]
     if min_latency_mode:
@@ -808,10 +823,16 @@ def cutlass_fused_moe(
     output_shape = (num_rows, hidden_size)
 
     if output is None:
-        output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
     else:
         check_shape_dtype_device(
-            output, output_shape, output_dtype, input.device, "output"
+            # output, output_shape, output_dtype, input.device, "output"
+            output,
+            output_shape,
+            output_dtype,
+            input.place,
+            "output",
         )
 
     return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe(
diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py
index 2eec7ac2ce..d2fbf5687a 100644
--- a/flashinfer/jit/core.py
+++ b/flashinfer/jit/core.py
@@ -1,6 +1,10 @@
 import dataclasses
 import logging
 import os
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 from contextlib import nullcontext
 from datetime import datetime
 from pathlib import Path
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index 2c3a56d92b..0dab4fc584 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import List, Optional
 
-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 import torch
 
 from . import env as jit_env
diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index 936d08380c..f48d98dbe3 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -16,13 +16,12 @@
 
 import functools
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.version
-from torch.torch_version import TorchVersion
-from torch.torch_version import __version__ as torch_version
 
 from .jit.spdlog import gen_spdlog_module
 
@@ -249,6 +248,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
 
 @functools.cache
 def get_compute_capability(device: torch.device) -> Tuple[int, int]:
+    return torch.device.cuda.get_device_capability(device.gpu_device_id())
     if device.type != "cuda":
         raise ValueError("device must be a cuda device")
     return torch.cuda.get_device_capability(device.index)
@@ -267,7 +267,13 @@ def _check_cached_qkv_data_type(
         )
 
 
-if TorchVersion(torch_version) < TorchVersion("2.4"):
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
+if use_paddle_compatible_api() or torch.torch_version.TorchVersion(
+    torch.torch_version.__version__
+) < torch.torch_version.TorchVersion("2.4"):
 
     def register_custom_op(
         name: str,
@@ -522,7 +528,7 @@ def check_shape_dtype_device(
     expected_device: Optional[torch.device],
     name: str,
 ) -> None:
-    if expected_shape and x.shape != torch.Size(expected_shape):
+    if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
@@ -530,7 +536,8 @@ def check_shape_dtype_device(
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if expected_device and x.device != expected_device:
+    # if expected_device and x.device != expected_device:
+    if expected_device and x.place != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
@@ -566,8 +573,8 @@ def set_log_level(lvl_str: str) -> None:
 
 @functools.cache
 def device_support_pdl(device: torch.device) -> bool:
-    if device.type != "cuda":
-        return False
+    # if device.type != "cuda":
+    #     return False
     major, _ = get_compute_capability(device)
     return major >= 9
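
Note: every `input.device` access above becomes `input.place` because, under
Paddle's compatible API, a tensor reports its device through `place` (a
`paddle.CUDAPlace` / `paddle.CPUPlace`) rather than through torch's `device`
attribute. A minimal framework-neutral accessor could look like the sketch
below (the helper name `tensor_device` is illustrative, not part of the
patch):

    def tensor_device(t):
        # Paddle tensors (including ones seen through the compat proxy)
        # carry `.place`; native torch tensors carry `.device`.
        place = getattr(t, "place", None)
        return place if place is not None else t.device

The same representation mismatch motivates the `tuple(x.shape)` change in
`check_shape_dtype_device`: Paddle reports shapes as plain lists, which never
compare equal to a `torch.Size`, so the shape is normalized before comparison.
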
From 95f1bf52a00d82537bb4721dbcf75a90f3a5bf05 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Mon, 13 Oct 2025 02:36:22 +0000
Subject: [PATCH 2/4] remove torch from requirements.txt

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index a31b6ebdc8..a71e497d28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,5 +9,5 @@ nvidia-ml-py
 packaging>=24.2
 requests
 tabulate
-torch
+# torch
 tqdm
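
Note: with torch commented out of requirements.txt, an installation can run
against Paddle alone, and the `import torch` surface is served by Paddle's
proxy module. The switch is the `PADDLE_COMPATIBLE_API` environment variable
added to `flashinfer/utils.py` in patch 1; because `use_paddle_compatible_api()`
is evaluated at module import time (it gates which `register_custom_op` gets
defined), the flag must be set before flashinfer is first imported. A sketch:

    import os

    # Any of "1", "on", "true" (case-insensitive) enables the
    # Paddle-compatible code paths; set it before the first import.
    os.environ["PADDLE_COMPATIBLE_API"] = "1"

    import flashinfer
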
From 955aedf4bd70739b7ee1179ed4addbc0f4121d04 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Thu, 23 Oct 2025 05:01:38 +0000
Subject: [PATCH 3/4] remove changes about import tvm_ffi

---
 flashinfer/fused_moe/core.py | 8 +++-----
 flashinfer/jit/core.py       | 4 ----
 flashinfer/jit/cpp_ext.py    | 5 +----
 3 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index e7e8d7cda9..14d1170f01 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -20,10 +20,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import paddle
-
-with paddle.compat.use_torch_proxy_guard(enable=False):
-    import tvm_ffi
 
 from ..autotuner import (
     AutoTuner,
@@ -354,7 +350,9 @@ def __init__(
         )
         self.activation_type = activation_type
 
-        def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
+        def paddle_dtype_to_tvm_ffi_dtype(dtype):
+            import tvm_ffi
+
             dtype_str = str(dtype).split(".", 1)[-1]
             return tvm_ffi.dtype(dtype_str)
 
diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py
index d2fbf5687a..2eec7ac2ce 100644
--- a/flashinfer/jit/core.py
+++ b/flashinfer/jit/core.py
@@ -1,10 +1,6 @@
 import dataclasses
 import logging
 import os
-import paddle
-
-with paddle.compat.use_torch_proxy_guard(enable=False):
-    import tvm_ffi
 from contextlib import nullcontext
 from datetime import datetime
 from pathlib import Path
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index 0dab4fc584..2c3a56d92b 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -10,10 +10,7 @@
 from pathlib import Path
 from typing import List, Optional
 
-import paddle
-
-with paddle.compat.use_torch_proxy_guard(enable=False):
-    import tvm_ffi
+import tvm_ffi
 import torch
 
 from . import env as jit_env
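
Note: the guard blocks removed here existed because Paddle's compatible API
can install a proxy that makes `import torch` resolve to Paddle, while
`tvm_ffi` imports torch internally. Patch 1 therefore wrapped the import like
this (reproduced from patch 1 for context):

    import paddle

    # Temporarily disable the torch proxy so tvm_ffi binds to the real
    # torch module instead of Paddle's stand-in.
    with paddle.compat.use_torch_proxy_guard(enable=False):
        import tvm_ffi

Deferring `import tvm_ffi` into `paddle_dtype_to_tvm_ffi_dtype` keeps the jit
modules importable without Paddle installed and delays the tvm_ffi import cost
until a MoE runner is actually constructed.
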
From 7476b5a28acf36355fbe724620c79cab85b7ee41 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Tue, 4 Nov 2025 11:17:42 +0000
Subject: [PATCH 4/4] remove dtype conversion

---
 flashinfer/fused_moe/core.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 14d1170f01..a34b89597d 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -350,17 +350,11 @@ def __init__(
         )
         self.activation_type = activation_type
 
-        def paddle_dtype_to_tvm_ffi_dtype(dtype):
-            import tvm_ffi
-
-            dtype_str = str(dtype).split(".", 1)[-1]
-            return tvm_ffi.dtype(dtype_str)
-
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[instance_key] = module.init(
-                paddle_dtype_to_tvm_ffi_dtype(x_dtype),
-                paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
-                paddle_dtype_to_tvm_ffi_dtype(output_dtype),
+                x_dtype,
+                weight_dtype,
+                output_dtype,
                 use_deepseek_fp8_block_scale,
                 use_w4_group_scaling,
                 use_mxfp8_act_scaling,
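
Note: after this patch `module.init` receives the dtype objects unchanged, so
the series relies on the tvm-ffi call boundary accepting framework dtypes
directly; the deleted helper was a purely lexical bridge. For reference, it
worked as below (assuming `str(dtype)` has the form "paddle.float16"):

    import tvm_ffi

    def paddle_dtype_to_tvm_ffi_dtype(dtype):
        # "paddle.float16" -> "float16", which tvm_ffi.dtype() parses.
        dtype_str = str(dtype).split(".", 1)[-1]
        return tvm_ffi.dtype(dtype_str)

If the ffi layer did not convert these dtypes itself, `module.init` would be
the first call to fail, so dropping the shim also serves as a check that the
conversion now happens inside tvm-ffi.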