2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 41edb0 to 7fd585
2 changes: 1 addition & 1 deletion bitblas/__init__.py
@@ -147,7 +147,7 @@ def remove_tvm_path(path):
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
)
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .utils import auto_detect_target, auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
14 changes: 7 additions & 7 deletions bitblas/base/arch/__init__.py
@@ -6,6 +6,7 @@
from .cdna import CDNA
from typing import Union
from tvm.target import Target
from bitblas.utils.target_detector import auto_detect_target


def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
@@ -22,12 +22,6 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
raise ValueError(f"Unsupported target: {target.kind.name}")


def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
return get_arch("cuda")


from .cpu import is_cpu_arch # noqa: F401
from .cuda import (
is_cuda_arch, # noqa: F401
@@ -38,4 +33,9 @@ def auto_infer_current_arch() -> TileDevice:
is_tensorcore_supported_precision, # noqa: F401
has_mma_support, # noqa: F401
)
from .cdna import is_cdna_arch # noqa: F401
from .cdna import is_cdna_arch, is_matrixcore_supported_precision # noqa: F401


def auto_infer_current_arch() -> TileDevice:
target = auto_detect_target()
return get_arch(target)
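Reviewer note: a minimal sketch of how the relocated auto_infer_current_arch is expected to behave after this change, assuming auto_detect_target() returns a target string or Target accepted by get_arch(); its exact return type is not shown in this diff.

# Hedged usage sketch, not part of the PR: composes the two helpers touched here.
from bitblas.base.arch import get_arch, auto_infer_current_arch
from bitblas.utils.target_detector import auto_detect_target

arch = auto_infer_current_arch()            # CUDA tile device on NVIDIA, CDNA on AMD (assumed)
same_arch = get_arch(auto_detect_target())  # equivalent composition of the two helpers
print(type(arch).__name__, type(same_arch).__name__)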
15 changes: 15 additions & 0 deletions bitblas/base/arch/cdna.py
@@ -11,6 +11,21 @@ def is_cdna_arch(arch: TileDevice) -> bool:
return isinstance(arch, CDNA)


# AMD Matrix Core Configurations
cdna_matrixcore_supported = [
("float16", "float32"),
("int8", "int32"),
]


def is_matrixcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
if is_cdna_arch(arch):
return (in_dtype, accum_dtype) in cdna_matrixcore_supported
else:
raise ValueError(f"Unsupported architecture: {arch}")


class CDNA(TileDevice):

def __init__(self, target: Union[Target, str]):
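A short usage sketch for the new helper; constructing CDNA from the plain "hip" target string is an assumption based on get_arch() above, not something this diff shows.

# Hypothetical check; ("float16", "float32") and ("int8", "int32") are the only
# pairs listed in cdna_matrixcore_supported.
from bitblas.base.arch.cdna import CDNA, is_matrixcore_supported_precision

arch = CDNA("hip")  # assumed constructor argument
print(is_matrixcore_supported_precision("float16", "float32", arch))  # True -> Matrix Core path
print(is_matrixcore_supported_precision("float32", "float32", arch))  # False -> SIMT fallback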
17 changes: 6 additions & 11 deletions bitblas/base/arch/cuda.py
@@ -18,40 +18,35 @@ def is_cuda_arch(arch: TileDevice) -> bool:

def is_volta_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 70)
conditions.append(arch.sm_version < 80)
conditions.append(is_cuda_arch(arch) and arch.sm_version >= 70 and arch.sm_version < 80)
return all(conditions)


def is_ampere_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
conditions.append(is_cuda_arch(arch) and arch.sm_version >= 80 and arch.sm_version < 90)
return all(conditions)


def is_ada_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 89)
conditions.append(is_cuda_arch(arch) and arch.sm_version == 89)
return all(conditions)


def is_hopper_arch(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version == 90)
conditions.append(is_cuda_arch(arch) and arch.sm_version == 90)
return all(conditions)


def has_mma_support(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80)
conditions.append(is_cuda_arch(arch) and arch.sm_version >= 80)
return all(conditions)


# NVIDIA Tensor Core Configurations
volta_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
4 changes: 2 additions & 2 deletions bitblas/benchmark/operator/__init__.py
@@ -7,7 +7,7 @@
from typing import Dict, List, Tuple, Optional
from bitblas.ops import Operator, OperatorConfig
from bitblas.utils import get_default_cache_path
from bitblas import auto_detect_nvidia_target
from bitblas import auto_detect_target
from bitblas import tvm as tvm
from bitblas.cache import OperatorCache
import logging
@@ -21,7 +21,7 @@ class BitblasOperatorBenchmarkBase(ABC):
benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig, Optional[int]]]] = {}

# Benchmark target, auto-detected from the local GPU by default
benchmark_target: str = auto_detect_nvidia_target()
benchmark_target: str = auto_detect_target()

# Benchmark results: a list of tuples, each containing latency and tuning time
benchmark_results: Dict[str, List[Tuple[Optional[float], Optional[float]]]] = {}
11 changes: 7 additions & 4 deletions bitblas/builder/wrapper/tl.py
@@ -87,6 +87,9 @@ def get_cuda_init_func(self):
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs

def get_stream_argument(self) -> Dict:
return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}

def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
@@ -115,7 +118,7 @@ def update_lib_code(self, code: str):
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})

function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
function_args.append(self.get_stream_argument())
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

@@ -223,7 +226,7 @@ def create_dispatch_func(self, code, function_informations):
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})

function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
function_args.append(self.get_stream_argument())

# Format the argument definitions for function declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
@@ -392,8 +395,8 @@ def get_hip_init_func(self):
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs

def get_stream_type(self, function_args):
function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
def get_stream_argument(self) -> Dict:
return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}


class TLWrapper(BaseWrapper):
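The refactor above replaces two duplicated stream-argument literals with a single overridable hook. A dependency-free sketch of the pattern follows; the class and method names other than get_stream_argument are placeholders, not the real wrappers in bitblas/builder/wrapper/tl.py.

from typing import Dict, List


class CudaWrapperSketch:
    # Placeholder for the CUDA TL wrapper; only the hook pattern is taken from the diff.

    def get_stream_argument(self) -> Dict:
        return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}

    def format_def_args(self, function_args: List[Dict]) -> str:
        # Both dispatch-function builders now append the stream argument via the hook.
        args = function_args + [self.get_stream_argument()]
        return ", ".join(f"{arg['type']} {arg['name']}" for arg in args)


class HipWrapperSketch(CudaWrapperSketch):
    # The HIP wrapper only overrides the stream type and default value.

    def get_stream_argument(self) -> Dict:
        return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}


print(CudaWrapperSketch().format_def_args([{"name": "A", "type": "half*"}]))
# half* A, cudaStream_t stream=cudaStreamDefault
print(HipWrapperSketch().format_def_args([{"name": "A", "type": "half*"}]))
# half* A, hipStream_t stream=hipStreamDefault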
4 changes: 2 additions & 2 deletions bitblas/cache/operator.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
from bitblas.utils import get_default_cache_path
from bitblas.utils import get_default_cache_path, auto_detect_target
from bitblas.ops.operator import OperatorConfig, Operator
from dataclasses import asdict
import os
@@ -186,7 +186,7 @@ def load_global_ops_cache(database_path=None, target=None):
if database_path is None:
database_path = get_database_path()
if target is None:
target = bitblas.auto_detect_nvidia_target()
target = auto_detect_target()
logger.info(f"Loading operators from database {database_path} for target {target}")
global_operator_cache.load_from_database(database_path, target)
return global_operator_cache
4 changes: 2 additions & 2 deletions bitblas/module/__init__.py
@@ -16,7 +16,7 @@
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas.quantization.utils import general_compress
from bitblas import auto_detect_nvidia_target
from bitblas import auto_detect_target

BITBLAS_DATABASE_PATH = get_database_path()

@@ -240,7 +240,7 @@ def _configure_bitblas_matmul(
self.source_format = self.bitblas_matmul.source_format

def _get_or_create_bitblas_operator(self, config, enable_tuning):
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_TARGET = auto_detect_target()

if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
4 changes: 2 additions & 2 deletions bitblas/ops/general_flashatten/__init__.py
@@ -6,7 +6,7 @@
from bitblas.base.base_scheduler import BaseScheduler
from ..operator import OperatorConfig, Operator, BaseKernelNameGenerator
from ...base.arch.cuda import CUDA
from ...utils import auto_detect_nvidia_target
from ...utils import auto_detect_target
from dataclasses import dataclass
from typing import Union, Tuple, Literal, Optional, Any
import logging
@@ -93,7 +93,7 @@ def __init__(
backend: str = "tl",
):
if target is None:
target = auto_detect_nvidia_target()
target = auto_detect_target()
logger.info(f"Auto detected target: {target}")

assert (config.Q_dtype
4 changes: 2 additions & 2 deletions bitblas/ops/general_matmul/__init__.py
@@ -16,7 +16,7 @@
from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils import retrieve_func_from_module
from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas.utils.target_detector import auto_detect_target
from dataclasses import dataclass
from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
from ..quant_compress import QuantCompress, QuantCompressConfig
@@ -356,7 +356,7 @@ def __init__(
# if from database, we should disable default schedule
# to save compilation time
if target is None:
target = auto_detect_nvidia_target()
target = auto_detect_target()
logger.info(f"Auto detected target: {target}")

assert (config.A_dtype
43 changes: 37 additions & 6 deletions bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -5,12 +5,8 @@
from tvm.tir import PrimFunc
from bitblas.base.operator_common import TransformKind
from bitblas.base.base_scheduler import BaseScheduler
from bitblas.base.arch import (
TileDevice,
is_ampere_arch,
is_volta_arch,
is_tensorcore_supported_precision,
)
from bitblas.base.arch import (TileDevice, is_ampere_arch, is_volta_arch, is_cdna_arch,
is_tensorcore_supported_precision, is_matrixcore_supported_precision)
from dataclasses import dataclass
from bitblas.tl.base_hint import BaseTLHint

@@ -128,11 +124,46 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
else:
return self.matmul_simt_scheduler

def dispatch_cdna_scheduler(self, arch: TileDevice) -> BaseScheduler:
M = self.maybe_dynamic(self.M, "m")
N, K = self.N, self.K
assert isinstance(N, int) and isinstance(K, int), "Dynamic N and K are currently not supported"

is_dynamic = self.is_dynamic
in_dtype, accum_dtype = (
self.in_dtype,
self.accum_dtype,
)
if self.weight_transform_kind != TransformKind.NonTransform:
raise ValueError(
f"Weight propagation {self.weight_transform_kind} is not supported for CDNA")
if in_dtype not in ["int8", "float16", "float32", "float64"]:
raise ValueError(f"Unsupported input data type: {in_dtype}")

if is_dynamic:
# Dynamic Dispatcher
if is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
return self.matmul_block_scheduler
else:
return self.matmul_simt_scheduler
else:
minimal_tensorcore_threshold: List[int, int, int] = [8, 16, 16]
if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[
1] > N or minimal_tensorcore_threshold[2] > K:
return self.gemv_scheduler
elif is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
# Fine-grained scheduler (mma) is not implemented for CDNA
return self.matmul_block_scheduler
else:
return self.matmul_simt_scheduler

def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
if is_ampere_arch(arch):
return self.dispatch_ampere_scheduler(arch)
elif is_volta_arch(arch):
return self.dispatch_volta_scheduler(arch)
elif is_cdna_arch(arch):
return self.dispatch_cdna_scheduler(arch)
else:
raise ValueError(f"Unsupported architecture: {arch}")

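For review convenience, a condensed, dependency-free restatement of the decision tree dispatch_cdna_scheduler adds above; the real method returns scheduler objects, and the [8, 16, 16] threshold is taken directly from the diff.

def sketch_dispatch_cdna(M, N, K, is_dynamic, matrixcore_supported):
    # Mirrors the branches above with strings in place of scheduler objects.
    if is_dynamic:
        return "block" if matrixcore_supported else "simt"
    if M < 8 or N < 16 or K < 16:
        return "gemv"
    return "block" if matrixcore_supported else "simt"


assert sketch_dispatch_cdna(1, 1024, 1024, False, True) == "gemv"
assert sketch_dispatch_cdna(128, 1024, 1024, False, True) == "block"
assert sketch_dispatch_cdna(128, 1024, 1024, False, False) == "simt"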
@@ -9,7 +9,9 @@
TileDevice,
is_ampere_arch,
is_volta_arch,
is_cdna_arch,
is_tensorcore_supported_precision,
is_matrixcore_supported_precision,
)
from dataclasses import dataclass
from bitblas.tl.base_hint import BaseTLHint
@@ -143,11 +145,57 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
else:
return self.matmul_dequantize_simt_scheduler

def dispatch_cdna_scheduler(self, arch: TileDevice) -> BaseScheduler:
M = self.maybe_dynamic(self.M, "m")
N, K = self.N, self.K
assert isinstance(N, int) and isinstance(K, int), "Dynamic N and K are currently not supported"

is_dynamic = self.is_dynamic
in_dtype, accum_dtype = (
self.in_dtype,
self.accum_dtype,
)
weight_transform_kind = self.weight_transform_kind
if is_dynamic:
# Dynamic Dispatcher
if is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
if weight_transform_kind != TransformKind.NonTransform:
raise NotImplementedError(
"Weight propagation is not supported for MatrixCore with Dequantization")
return self.matmul_dequantize_fine_grained_scheduler
else:
if weight_transform_kind != TransformKind.NonTransform:
raise ValueError(
"Weight propagation is not supported for non-TensorCore architectures")
return self.matmul_dequantize_simt_scheduler
else:
minimal_tensorcore_threshold: List[int, int, int] = ([8, 16, 32] if accum_dtype
== "int32" else [8, 16, 16])
if (minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[1] > N or
minimal_tensorcore_threshold[2] > K):
if in_dtype == "int4":
raise ValueError("INT4 is not supported for non-TensorCore architectures")
if weight_transform_kind != TransformKind.NonTransform:
raise ValueError(
"Weight propagation is not supported for non-TensorCore architectures")
return self.gemv_dequantize_simt_scheduler
elif is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
if self.weight_transform_kind != TransformKind.NonTransform:
raise NotImplementedError(
"Weight propagation is not supported for MatrixCore with Dequantization")
else:
return self.matmul_dequantize_fine_grained_scheduler
else:
return self.matmul_dequantize_simt_scheduler

def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
if is_ampere_arch(arch):
return self.dispatch_ampere_scheduler(arch)
elif is_volta_arch(arch):
return self.dispatch_volta_scheduler(arch)
elif is_cdna_arch(arch):
return self.dispatch_cdna_scheduler(arch)
else:
raise ValueError(f"Unsupported architecture: {arch}")

2 changes: 1 addition & 1 deletion bitblas/ops/operator.py
@@ -278,7 +278,7 @@ def _build_default_module(self, target: Target):
self._update_optimized_mod(scheduled_ir_module)
except Exception as apply_schedule_error:
self.scheduled_ir_module = None
logger.warning(
logger.exception(
APPLY_SCHEDULE_FAILED_MESSAGE.format(self.__class__.__name__, target, "default",
apply_schedule_error))

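Context for the warning-to-exception switch: logging.Logger.exception logs at ERROR level and attaches the active traceback, so failed default schedules now surface the full stack instead of only the formatted message. A minimal illustration, not code from this PR:

import logging

logger = logging.getLogger(__name__)

try:
    raise RuntimeError("apply schedule failed")
except RuntimeError as apply_schedule_error:
    # warning() would emit only the message; exception() also records the traceback.
    logger.exception("Apply default schedule failed: %s", apply_schedule_error)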
2 changes: 1 addition & 1 deletion bitblas/utils/__init__.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401
from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401
from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401
from .target_detector import auto_detect_target, get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401
from .rtmod_analysis import get_annotated_device_mod # noqa: F401
from .weight_propagate import apply_transform_on_input # noqa: F401

Expand Down