9 changes: 6 additions & 3 deletions flashinfer/fp4_quantization.py
@@ -180,7 +180,8 @@ def fp4_quantize_sm100(
             - Scale factors tensor with shape determined by layout and sf_vec_size
     """
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     out_val = torch.empty(
         (*input.shape[:-1], input.shape[-1] // 2),
         dtype=torch.uint8,
@@ -669,9 +670,11 @@ def fp4_quantize(
 
     assert input.shape[-1] % sf_vec_size == 0
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     # get input device sm version
-    major, minor = get_compute_capability(input.device)
+    # major, minor = get_compute_capability(input.device)
+    major, minor = get_compute_capability(input.place)
     x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
         input,
         global_scale,
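The pattern above swaps torch's `Tensor.device` attribute for `Tensor.place`, the attribute paddle tensors expose. As a point of comparison, here is a minimal sketch (not part of this PR; the helper name `tensor_device` is made up) of how a call site could stay agnostic to which attribute exists:

import torch


def tensor_device(t):
    # Torch tensors carry `.device`; paddle tensors carry `.place`.
    # Prefer `.place` when present, otherwise fall back to `.device`.
    return t.place if hasattr(t, "place") else t.device


x = torch.zeros(4, 8)
print(tensor_device(x))  # with stock torch this prints device(type='cpu')

The diff instead rewrites each call site to use `.place` directly, which assumes the code only runs under the paddle-compatible torch proxy.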
27 changes: 20 additions & 7 deletions flashinfer/fused_moe/core.py
@@ -454,7 +454,8 @@ def cutlass_fused_moe(
         activation_type: ActivationType = ActivationType.Swiglu,
     ) -> List[torch.Tensor]:
         if enable_pdl is None:
-            enable_pdl = device_support_pdl(input.device)
+            # enable_pdl = device_support_pdl(input.device)
+            enable_pdl = device_support_pdl(input.place)
         tuner = AutoTuner.get()
         MoERunner.refine_tuning_config(tune_max_num_tokens)
 
@@ -513,17 +514,22 @@ def cutlass_fused_moe(
             else moe_runner.fused_moe_runner.run_moe
         )
         num_active_experts_per_node = torch.empty(
-            (1,), dtype=torch.int32, device=input.device
+            # (1,), dtype=torch.int32, device=input.device
+            (1,),
+            dtype=torch.int32,
+            device=input.place,
         )
         experts_to_token_score = torch.empty(
             (fc2_expert_weights.shape[0], input.shape[0]),
             dtype=torch.float32,
-            device=input.device,
+            # device=input.device,
+            device=input.place,
         )
         active_expert_global_ids = torch.empty(
             (fc2_expert_weights.shape[0],),
             dtype=torch.int32,
-            device=input.device,
+            # device=input.device,
+            device=input.place,
         )
         min_latency_output = (
             [
@@ -799,7 +805,8 @@ cutlass_fused_moe(
         )
 
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
 
     num_rows = input.shape[0]
     if min_latency_mode:
@@ -808,10 +815,16 @@
     output_shape = (num_rows, hidden_size)
 
     if output is None:
-        output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
     else:
         check_shape_dtype_device(
-            output, output_shape, output_dtype, input.device, "output"
+            # output, output_shape, output_dtype, input.device, "output"
+            output,
+            output_shape,
+            output_dtype,
+            input.place,
+            "output",
         )
 
     return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe(
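Each of the `torch.empty` call sites above now allocates with `device=input.place`. A small, self-contained sketch of the same allocation pattern that runs under stock torch (the `alloc_on_same_device` helper is illustrative and not part of the PR):

import torch


def alloc_on_same_device(ref: torch.Tensor, shape, dtype):
    # Mirror the diff's pattern: allocate scratch buffers on the same device
    # as the reference tensor. Under the paddle-compatible path the diff
    # passes `ref.place`; with stock torch the equivalent attribute is `ref.device`.
    dev = getattr(ref, "place", ref.device)
    return torch.empty(shape, dtype=dtype, device=dev)


inp = torch.randn(16, 128)
num_active = alloc_on_same_device(inp, (1,), torch.int32)
print(num_active.shape, num_active.device)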
21 changes: 14 additions & 7 deletions flashinfer/utils.py
@@ -16,13 +16,12 @@
 
 import functools
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.version
-from torch.torch_version import TorchVersion
-from torch.torch_version import __version__ as torch_version
 
 from .jit.spdlog import gen_spdlog_module
 
@@ -249,6 +248,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
 
 @functools.cache
 def get_compute_capability(device: torch.device) -> Tuple[int, int]:
+    return torch.device.cuda.get_device_capability(device.gpu_device_id())
     if device.type != "cuda":
         raise ValueError("device must be a cuda device")
     return torch.cuda.get_device_capability(device.index)
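The early return added to `get_compute_capability` queries the device id through `gpu_device_id()`, a paddle-style place method, rather than `torch.device.index`. A hedged sketch of a dispatching variant that keeps the original torch path alive (the name `compute_capability_compat` and the dual-backend handling are assumptions, not code from this PR):

import torch


def compute_capability_compat(device):
    # Paddle-style place objects expose gpu_device_id(); torch.device exposes .index.
    if hasattr(device, "gpu_device_id"):
        return torch.cuda.get_device_capability(device.gpu_device_id())
    if device.type != "cuda":
        raise ValueError("device must be a cuda device")
    return torch.cuda.get_device_capability(device.index)

Note that the function keeps its `@functools.cache` decorator in the diff, so whatever device object is passed in must remain hashable.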
@@ -267,7 +267,13 @@ def _check_cached_qkv_data_type(
         )
 
 
-if TorchVersion(torch_version) < TorchVersion("2.4"):
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
+if use_paddle_compatible_api() or torch.torch_version.TorchVersion(
+    torch.torch_version.__version__
+) < torch.torch_version.TorchVersion("2.4"):
 
     def register_custom_op(
         name: str,
@@ -522,15 +528,16 @@ def check_shape_dtype_device(
     expected_device: Optional[torch.device],
     name: str,
 ) -> None:
-    if expected_shape and x.shape != torch.Size(expected_shape):
+    if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
     if expected_dtype and x.dtype != expected_dtype:
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if expected_device and x.device != expected_device:
+    # if expected_device and x.device != expected_device:
+    if expected_device and x.place != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
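The change from `x.shape` to `tuple(x.shape)` matters because paddle tensors report their shape as a Python list, and a list never compares equal to the `torch.Size` tuple built from `expected_shape`; converting to a tuple first makes the check behave the same for both backends. A quick illustration of the comparison semantics this relies on:

import torch

expected = torch.Size((2, 3))
print([2, 3] == expected)         # False: a list and a tuple never compare equal
print(tuple([2, 3]) == expected)  # True: torch.Size is a tuple subclass
print(tuple(torch.zeros(2, 3).shape) == expected)  # True: the torch path is unaffected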
@@ -566,8 +573,8 @@ def set_log_level(lvl_str: str) -> None:
 
 @functools.cache
 def device_support_pdl(device: torch.device) -> bool:
-    if device.type != "cuda":
-        return False
+    # if device.type != "cuda":
+    # return False
     major, _ = get_compute_capability(device)
     return major >= 9
 
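The new `use_paddle_compatible_api()` helper gates the fallback branch on the `PADDLE_COMPATIBLE_API` environment variable, so the pre-2.4 `register_custom_op` shim is used even on newer torch versions when the flag is set. Because the check runs at module import time, the variable has to be set before flashinfer is imported. A usage sketch (assumes a flashinfer build that includes this patch):

import os

# Any of "1", "on", or "true" enables the paddle-compatible code paths.
os.environ["PADDLE_COMPATIBLE_API"] = "1"

import flashinfer.utils as fi_utils

print(fi_utils.use_paddle_compatible_api())  # -> True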
2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,5 +9,5 @@ nvidia-ml-py
 packaging>=24.2
 requests
 tabulate
-torch
+# torch
 tqdm