From fd3017cf4ac233be9e69eea2e2e2523f4050ef17 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Sun, 9 Nov 2025 17:14:09 -0800 Subject: [PATCH 1/4] Move floatx_tensor_core to prototype/dtypes --- test/dtypes/test_floatx.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- .../floatx/floatx_tensor_core_layout.py | 688 +----------------- torchao/prototype/dtypes/__init__.py | 2 + torchao/prototype/dtypes/floatx/__init__.py | 12 + .../floatx/floatx_tensor_core_layout.py | 666 +++++++++++++++++ 6 files changed, 715 insertions(+), 657 deletions(-) create mode 100644 torchao/prototype/dtypes/floatx/__init__.py create mode 100644 torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index ab4a13d24c..9e79138671 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -19,7 +19,7 @@ from_scaled_tc_floatx, to_scaled_tc_floatx, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import ( +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( FloatxTensorCoreAQTTensorImpl, _pack_tc_floatx, _pack_tc_fp6, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 21f13729dd..27e58eee93 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,7 +21,7 @@ _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import ( +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index c7fb1e1a7c..37103f7dfc 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -3,664 +3,42 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from functools import reduce -from typing import Optional, Tuple -import torch -from torch import Tensor -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.utils import ( - AQTTensorImpl, - Layout, -) -from torchao.prototype.custom_fp_utils import ( - _f32_to_floatx_unpacked, - _floatx_unpacked_to_f32, - _n_ones, +warnings.warn( + "Importing from torchao.dtypes.floatx.floatx_tensor_core_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ...' instead. " + "This import path will be removed in a future torchao release. " + "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. 
", + DeprecationWarning, + stacklevel=2, ) -aten = torch.ops.aten -_ONES_TABLE = [_n_ones(i) for i in range(8)] - - -def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce( - torch.bitwise_or, - [ - x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) - for i in range(8 // n_bits) - ], - ) - - -def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack( - [ - (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) - for i in range(8 // n_bits) - ], - dim=-1, - ).flatten(-2) - - -# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 -def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: - # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 - # thus, we need to reverse byte order within a uint32 word. - x = x.reshape(-1, 4).flip(1) - - x = _unpack(x, n_bits) - x = x.view(-1, 4 * (8 // n_bits)) - - if not undo: - bit_order = { - 1: [ - 1, - 5, - 9, - 13, - 17, - 21, - 25, - 29, - 3, - 7, - 11, - 15, - 19, - 23, - 27, - 31, - 0, - 4, - 8, - 12, - 16, - 20, - 24, - 28, - 2, - 6, - 10, - 14, - 18, - 22, - 26, - 30, - ], - 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], - 4: [1, 5, 3, 7, 0, 4, 2, 6], - }[n_bits] - - else: - # this is inverse of the above, obtained by running - # [v.index(i) for i in range(len(v))] - bit_order = { - 1: [ - 16, - 0, - 24, - 8, - 17, - 1, - 25, - 9, - 18, - 2, - 26, - 10, - 19, - 3, - 27, - 11, - 20, - 4, - 28, - 12, - 21, - 5, - 29, - 13, - 22, - 6, - 30, - 14, - 23, - 7, - 31, - 15, - ], - 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], - 4: [4, 0, 6, 2, 5, 1, 7, 3], - }[n_bits] - - x = x[:, bit_order] - x = _pack(x, n_bits) - - # reverse byte order within a uint32 word again. 
- x = x.reshape(-1, 4).flip(1) - return x.flatten() - - -# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h -def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - assert tensor.ndim == 2, tensor.dtype == torch.uint8 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - # Pass 1 from original code - tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) - tensor = tensor.reshape(-1, 32, 2) - tensor = tensor.permute(1, 0, 2) - tensor = tensor.flatten() - - used_bits = 0 - fragments = [] - - for y in [1, 2, 4]: - if nbits & y: - mask = (1 << y) - 1 - tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask - tensor_ybit = _pack(tensor_ybit, y) - - tensor_ybit = ( - tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) - ) # Pass 2 from original code - tensor_ybit = _bit_interleave( - tensor_ybit.flatten(), y - ) # Pass 3 from original code - fragments.append(tensor_ybit) - used_bits += y - - return torch.cat(fragments, dim=0).view(M, -1) - - -# more optimized version of _pack_tc_floatx() for FP6 by merging ops -def _pack_tc_fp6(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2, tensor.dtype == torch.uint8 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) - tensor = tensor.flip(3) - - tensor_2bit = (tensor >> 4) & 0b11 - tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) - tensor_2bit = _pack(tensor_2bit.flatten(), 2) - - tensor_4bit = tensor & 0b1111 - tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) - tensor_4bit = _pack(tensor_4bit.flatten(), 4) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) - - -# currently only optimize for TC-FP6 packing -def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - if nbits == 6: - return _pack_tc_fp6(tensor) - return _pack_tc_floatx(tensor, nbits) - - -def to_scaled_tc_floatx( - tensor: Tensor, ebits: int, mbits: int -) -> Tuple[Tensor, Tensor]: - # _n_ones() is not compatible with torch.compile() due to << operator - # https://github.com/pytorch/pytorch/issues/119152 - # exp_bias = _n_ones(ebits - 1) - # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) - - # workaround: global lookup table - exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( - _ONES_TABLE[mbits + 1] / (2**mbits) - ) - - dtype = tensor.dtype - tensor = tensor.float() - scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) - tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) - return tensor_tc_floatx, scale.to(dtype) - - -# inverse of _pack_tc_floatx() -def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.uint8 - M = tensor.shape[0] - size = tensor.numel() - tensor = tensor.flatten() - offset = 0 - used_bits = 0 - - tensor_floatx = None - - for y in [1, 2, 4]: - if nbits & y: - size_ybit = size // nbits * y - tensor_ybit = tensor[offset : offset + size_ybit] - offset += size_ybit - - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = ( - tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) - ) # undo Pass 2 - - tensor_ybit = _unpack(tensor_ybit.flatten(), y) - tensor_ybit = tensor_ybit << 
(nbits - used_bits - y) - used_bits += y - - if tensor_floatx is None: - tensor_floatx = tensor_ybit - else: - tensor_floatx |= tensor_ybit - - # undo Pass 1 - tensor_floatx = tensor_floatx.view(32, -1, 2).permute(1, 0, 2) - tensor_floatx = tensor_floatx.reshape(M // 64, -1, 4, 2, 2, 8, 8) - tensor_floatx = tensor_floatx.permute(0, 2, 4, 5, 1, 3, 6) - tensor_floatx = tensor_floatx.reshape(M, -1) - return tensor_floatx - - -# more optimized version of _unpack_tc_floatx() for FP6 by merging ops -# inverse of _unpack_tc_fp6() -def _unpack_tc_fp6(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.uint8 - M = tensor.shape[0] - N = tensor.shape[1] // 3 * 4 - assert (M % 64 == 0) and (N % 64 == 0) - size_2bit = M * N // 4 - size_4bit = M * N // 2 - tensor = tensor.view(-1) - assert tensor.numel() == size_2bit + size_4bit - - tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - - tensor_2bit = _unpack(tensor_2bit, 2) - tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) - tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) - - tensor_4bit = _unpack(tensor_4bit, 4) - tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) - tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) - - tensor_fp6 = (tensor_2bit << 4) | tensor_4bit - tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return tensor_fp6 - - -def unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - if nbits == 6: - return _unpack_tc_fp6(tensor) - return _unpack_tc_floatx(tensor, nbits) - - -def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: - floatx_unpacked = unpack_tc_floatx(tensor, 1 + ebits + mbits) - tensor = _floatx_unpacked_to_f32(floatx_unpacked, ebits, mbits) - if scale is not None: - tensor = tensor * scale.float().view(-1, 1) - return tensor - +# Re-export all public symbols from the new location for backward compatibility +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( + FloatxTensorCoreAQTTensorImpl, + FloatxTensorCoreLayout, + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, + _pack_tc_floatx, + _pack_tc_fp6, + from_scaled_tc_floatx, + pack_tc_floatx, + to_scaled_tc_floatx, + unpack_tc_floatx, +) -# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py -_SPLIT_K_MAP = [ - { # tokens: [1, 64] - 3072: 18, - 4096: 13, - 5120: 10, - 6144: 9, - 8192: 6, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 7, - }, - { # tokens: [65:128] - 3072: 9, - 4096: 6, - 5120: 5, - 6144: 9, - 8192: 3, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 6, - }, - { # tokens: [129:192] - 3072: 6, - 4096: 4, - 5120: 7, - 6144: 3, - 8192: 2, - 10240: 5, - 14336: 5, - 28672: 5, - 57344: 4, - }, - { # tokens: [193:256] - 3072: 9, - 4096: 3, - 5120: 5, - 6144: 2, - 8192: 5, - 10240: 4, - 14336: 8, - 28672: 6, - 57344: 4, - }, - { # tokens: [257:320] - 3072: 7, - 4096: 5, - 5120: 2, - 6144: 5, - 8192: 4, - 10240: 1, - 14336: 3, - 28672: 3, - 57344: 4, - }, - { # tokens: [321:384] - 3072: 3, - 4096: 2, - 5120: 5, - 6144: 3, - 8192: 1, - 10240: 8, - 14336: 3, - 28672: 4, - 57344: 3, - }, - { # tokens: [385:448] - 3072: 5, - 4096: 7, - 5120: 3, - 6144: 5, - 8192: 7, - 10240: 3, - 14336: 1, - 28672: 1, - 57344: 3, - }, - { # tokens: [449:512] - 3072: 2, - 4096: 5, - 5120: 4, - 6144: 1, - 8192: 5, - 10240: 2, - 14336: 6, - 28672: 4, - 57344: 1, - }, - { # tokens: [513:576] - 3072: 2, - 4096: 3, - 
5120: 1, - 6144: 1, - 8192: 3, - 10240: 3, - 14336: 3, - 28672: 1, - 57344: 1, - }, - { # tokens: [577:640] - 3072: 5, - 4096: 4, - 5120: 1, - 6144: 4, - 8192: 2, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1, - }, - { # tokens: [641:704] - 3072: 3, - 4096: 1, - 5120: 2, - 6144: 2, - 8192: 1, - 10240: 2, - 14336: 1, - 28672: 1, - 57344: 1, - }, - { # tokens: [705:768] - 3072: 3, - 4096: 1, - 5120: 3, - 6144: 2, - 8192: 1, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1, - }, +__all__ = [ + "FloatxTensorCoreAQTTensorImpl", + "FloatxTensorCoreLayout", + "_linear_f16_bf16_act_floatx_weight_check", + "_linear_f16_bf16_act_floatx_weight_impl", + "_pack_tc_floatx", + "_pack_tc_fp6", + "from_scaled_tc_floatx", + "pack_tc_floatx", + "to_scaled_tc_floatx", + "unpack_tc_floatx", ] - - -# quantization api integrations -@dataclass(frozen=True) -class FloatxTensorCoreLayout(Layout): - """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). - This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. - """ - - ebits: int - mbits: int - - -@register_layout(FloatxTensorCoreLayout) -class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): - """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), - it has a internal tensor field of "packed_floatx_data", which is packed from the - uint8 unpacked data (the output of `_quantize_affine_floatx` operator) - - The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 - github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - - At a high level packing is done by grouping bits into 1 bit fragments (shards), 2 bit fragments and - 4 bit fragments each fragments are packed separately and concatenated together. - For example for 6 bit dtype, we can extract the first 4 bits for all elements and pack them together - in a fragment, and extract the last 2 bits for all elements and pack them into fragment, in the end - we concatenate the fragments together. - - If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be - (M, N // 8 * nbit) - - FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of - (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 
00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor - FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) - """ - - def __new__( - cls, - packed_floatx_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - assert packed_floatx_data.ndim == 2 - assert packed_floatx_data.dtype == torch.uint8 - shape = ( - packed_floatx_data.shape[0], - packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, - ) - kwargs = {} - kwargs["device"] = packed_floatx_data.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_floatx_data.layout - ) - kwargs["dtype"] = packed_floatx_data.dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_floatx_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - self.packed_floatx_data = packed_floatx_data - self.scale = scale - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_floatx_data", "scale"], [self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_floatx_data, scale = ( - tensor_data_dict["packed_floatx_data"], - tensor_data_dict["scale"], - ) - (_layout,) = tensor_attributes - return cls(packed_floatx_data, scale, _layout) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx( - self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits - ) - return unpacked_floatx_data, self.scale - - @classmethod - def from_plain( - cls, - unpacked_floatx_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - """ - Format for `unpacked_floatx_data` will be: - zero padding bits | sign bit | exponent bits | mantissa bits - - For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent - bit, M is mantissa bit - """ - assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx( - unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits - ) - return cls(packed_floatx_data, scale, _layout) - - def __repr__(self): - unpacked_floatx_data, scale = self.get_plain() - _layout = self.get_layout() - return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, _layout={_layout})" - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_floatx_data), - fn(self.scale), - self._layout, - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.packed_floatx_data.to(device), - self.scale.to(device), - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten._to_copy.default: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: x.to(device=kwargs.pop("device", None)) - ), - ) - - raise NotImplementedError( - f"FloatxTensorCoreAQTTensorImpl 
dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_layout(self) -> Layout: - return self._layout - - -def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - return ( - # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and - # weight is floatx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) - and ( - # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) - or - # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) - ) - ) - - -def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.ops import quant_llm_linear - - act = input_tensor - weight = weight_tensor - - out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim) - - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - bsize = act_reshaped.shape[0] - splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - out = quant_llm_linear( - weight._layout.ebits, - weight._layout.mbits, - act_reshaped, - weight.tensor_impl.packed_floatx_data, - weight.tensor_impl.scale, - splitK=splitK, - ) - - if bias is not None: - out += bias - - return out.view(*act.shape[:-1], out_dim).to(act.dtype) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 294c7d0b15..86c5d25b41 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -12,6 +12,7 @@ MarlinQQQTensor, to_marlinqqq_quantized_intx, ) +from .floatx import FloatxTensorCoreLayout __all__ = [ "BlockSparseLayout", @@ -20,4 +21,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "FloatxTensorCoreLayout", ] diff --git a/torchao/prototype/dtypes/floatx/__init__.py b/torchao/prototype/dtypes/floatx/__init__.py new file mode 100644 index 0000000000..76dec3050e --- /dev/null +++ b/torchao/prototype/dtypes/floatx/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .floatx_tensor_core_layout import FloatxTensorCoreLayout + + +__all__ = [ + "FloatxTensorCoreLayout" +] diff --git a/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py new file mode 100644 index 0000000000..c7fb1e1a7c --- /dev/null +++ b/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py @@ -0,0 +1,666 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from functools import reduce +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import ( + AQTTensorImpl, + Layout, +) +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + _n_ones, +) + +aten = torch.ops.aten +_ONES_TABLE = [_n_ones(i) for i in range(8)] + + +def _pack(x: Tensor, n_bits: int) -> Tensor: + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) + + +def _unpack(x: Tensor, n_bits: int) -> Tensor: + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) + + +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 +def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: + # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 + # thus, we need to reverse byte order within a uint32 word. + x = x.reshape(-1, 4).flip(1) + + x = _unpack(x, n_bits) + x = x.view(-1, 4 * (8 // n_bits)) + + if not undo: + bit_order = { + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6], + }[n_bits] + + else: + # this is inverse of the above, obtained by running + # [v.index(i) for i in range(len(v))] + bit_order = { + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], + 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], + 4: [4, 0, 6, 2, 5, 1, 7, 3], + }[n_bits] + + x = x[:, bit_order] + x = _pack(x, n_bits) + + # reverse byte order within a uint32 word again. 
+ x = x.reshape(-1, 4).flip(1) + return x.flatten() + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h +def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + # Pass 1 from original code + tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) + tensor = tensor.reshape(-1, 32, 2) + tensor = tensor.permute(1, 0, 2) + tensor = tensor.flatten() + + used_bits = 0 + fragments = [] + + for y in [1, 2, 4]: + if nbits & y: + mask = (1 << y) - 1 + tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask + tensor_ybit = _pack(tensor_ybit, y) + + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code + fragments.append(tensor_ybit) + used_bits += y + + return torch.cat(fragments, dim=0).view(M, -1) + + +# more optimized version of _pack_tc_floatx() for FP6 by merging ops +def _pack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor = tensor.flip(3) + + tensor_2bit = (tensor >> 4) & 0b11 + tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) + tensor_2bit = _pack(tensor_2bit.flatten(), 2) + + tensor_4bit = tensor & 0b1111 + tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) + tensor_4bit = _pack(tensor_4bit.flatten(), 4) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) + + +# currently only optimize for TC-FP6 packing +def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _pack_tc_fp6(tensor) + return _pack_tc_floatx(tensor, nbits) + + +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: + # _n_ones() is not compatible with torch.compile() due to << operator + # https://github.com/pytorch/pytorch/issues/119152 + # exp_bias = _n_ones(ebits - 1) + # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + # workaround: global lookup table + exp_bias = _ONES_TABLE[ebits - 1] + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) + + dtype = tensor.dtype + tensor = tensor.float() + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) + return tensor_tc_floatx, scale.to(dtype) + + +# inverse of _pack_tc_floatx() +def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + size = tensor.numel() + tensor = tensor.flatten() + offset = 0 + used_bits = 0 + + tensor_floatx = None + + for y in [1, 2, 4]: + if nbits & y: + size_ybit = size // nbits * y + tensor_ybit = tensor[offset : offset + size_ybit] + offset += size_ybit + + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 + + tensor_ybit = _unpack(tensor_ybit.flatten(), y) + tensor_ybit = tensor_ybit << 
(nbits - used_bits - y) + used_bits += y + + if tensor_floatx is None: + tensor_floatx = tensor_ybit + else: + tensor_floatx |= tensor_ybit + + # undo Pass 1 + tensor_floatx = tensor_floatx.view(32, -1, 2).permute(1, 0, 2) + tensor_floatx = tensor_floatx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_floatx = tensor_floatx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_floatx = tensor_floatx.reshape(M, -1) + return tensor_floatx + + +# more optimized version of _unpack_tc_floatx() for FP6 by merging ops +# inverse of _unpack_tc_fp6() +def _unpack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + N = tensor.shape[1] // 3 * 4 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + tensor = tensor.view(-1) + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + + tensor_2bit = _unpack(tensor_2bit, 2) + tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) + + tensor_4bit = _unpack(tensor_4bit, 4) + tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) + return tensor_fp6 + + +def unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _unpack_tc_fp6(tensor) + return _unpack_tc_floatx(tensor, nbits) + + +def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: + floatx_unpacked = unpack_tc_floatx(tensor, 1 + ebits + mbits) + tensor = _floatx_unpacked_to_f32(floatx_unpacked, ebits, mbits) + if scale is not None: + tensor = tensor * scale.float().view(-1, 1) + return tensor + + +# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +_SPLIT_K_MAP = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7, + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6, + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4, + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4, + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4, + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3, + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3, + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1, + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1, + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1, + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1, + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 
1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1, + }, +] + + +# quantization api integrations +@dataclass(frozen=True) +class FloatxTensorCoreLayout(Layout): + """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). + This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. + """ + + ebits: int + mbits: int + + +@register_layout(FloatxTensorCoreLayout) +class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): + """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), + it has a internal tensor field of "packed_floatx_data", which is packed from the + uint8 unpacked data (the output of `_quantize_affine_floatx` operator) + + The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 + github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm + + At a high level packing is done by grouping bits into 1 bit fragments (shards), 2 bit fragments and + 4 bit fragments each fragments are packed separately and concatenated together. + For example for 6 bit dtype, we can extract the first 4 bits for all elements and pack them together + in a fragment, and extract the last 2 bits for all elements and pack them into fragment, in the end + we concatenate the fragments together. + + If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be + (M, N // 8 * nbit) + + FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of + (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 00SEEEMM for fp6_e3_m2 + it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor + FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) + """ + + def __new__( + cls, + packed_floatx_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + assert packed_floatx_data.ndim == 2 + assert packed_floatx_data.dtype == torch.uint8 + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) + kwargs = {} + kwargs["device"] = packed_floatx_data.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout + ) + kwargs["dtype"] = packed_floatx_data.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_floatx_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + self.packed_floatx_data = packed_floatx_data + self.scale = scale + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_floatx_data", "scale"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes + return cls(packed_floatx_data, scale, _layout) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) + return unpacked_floatx_data, self.scale + + @classmethod + def from_plain( + cls, + unpacked_floatx_data: torch.Tensor, + scale: torch.Tensor, + 
zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + """ + Format for `unpacked_floatx_data` will be: + zero padding bits | sign bit | exponent bits | mantissa bits + + For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent + bit, M is mantissa bit + """ + assert isinstance(_layout, FloatxTensorCoreLayout) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) + return cls(packed_floatx_data, scale, _layout) + + def __repr__(self): + unpacked_floatx_data, scale = self.get_plain() + _layout = self.get_layout() + return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, _layout={_layout})" + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_floatx_data), + fn(self.scale), + self._layout, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.packed_floatx_data.to(device), + self.scale.to(device), + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten._to_copy.default: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), + ) + + raise NotImplementedError( + f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_layout(self) -> Layout: + return self._layout + + +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + + return ( + # input is native float32 tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and + # weight is floatx Tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( + # weight is using fp6 quantization + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or + # weight is using fp5 quantization + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) + ) + ) + + +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): + from torchao.ops import quant_llm_linear + + act = input_tensor + weight = weight_tensor + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim) + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight._layout.ebits, + weight._layout.mbits, + act_reshaped, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + 
return out.view(*act.shape[:-1], out_dim).to(act.dtype) From bb1a33ebf9c99778c2c4927c92b0edf415696a9a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 10 Nov 2025 10:53:31 -0800 Subject: [PATCH 2/4] updates --- benchmarks/benchmark_fp6.py | 2 +- docs/source/api_ref_dtypes.rst | 2 +- test/dtypes/test_floatx.py | 2 +- torchao/dtypes/affine_quantized_tensor.py | 5 +++-- torchao/quantization/quant_api.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index c22eba9e1a..4aac4b952f 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -9,7 +9,7 @@ from tqdm import tqdm from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreLayout +from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout from torchao.utils import benchmark_torch_function_in_microseconds diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 58ad4ee8a4..727b659259 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -20,7 +20,6 @@ Layouts and Tensor Subclasses TensorCoreTiledLayout Float8Layout FloatxTensor - FloatxTensorCoreLayout MarlinSparseLayout UintxLayout Int4CPULayout @@ -53,6 +52,7 @@ Prototype Int8DynamicActInt4WeightCPULayout MarlinQQQTensor MarlinQQQLayout + FloatxTensorCoreLayout .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 9e79138671..8e96ebe74a 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -14,7 +14,7 @@ run_tests, ) -from torchao.dtypes.floatx import ( +from torchao.prototype.dtypes.floatx import ( FloatxTensorCoreLayout, from_scaled_tc_floatx, to_scaled_tc_floatx, diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 0d7ed8d9e2..d777b9c2ba 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -136,7 +136,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import Float8Layout, FloatxTensorCoreLayout + from torchao.dtypes.floatx import Float8Layout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() @@ -539,7 +540,7 @@ def from_hp_to_fpx( _layout: Layout, ): """Create a floatx AffineQuantizedTensor from a high precision tensor. 
Floatx is represented as ebits and mbits, and supports the representation of float1-float7.""" - from torchao.dtypes.floatx import FloatxTensorCoreLayout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout assert isinstance(_layout, FloatxTensorCoreLayout), ( f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e3a75bbb3e..d0bf725469 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2395,7 +2395,7 @@ def _fpx_weight_only_transform( module = _unwrap_float8_linear(module) from torchao.dtypes import to_affine_quantized_fpx - from torchao.dtypes.floatx import FloatxTensorCoreLayout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" out_dim, in_dim = weight.shape From 727a21730ddb07054656897f169427766751178c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 10 Nov 2025 10:58:43 -0800 Subject: [PATCH 3/4] Lint fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_floatx.py | 8 ++++---- torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- torchao/prototype/dtypes/__init__.py | 2 +- torchao/prototype/dtypes/floatx/__init__.py | 5 +---- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 8e96ebe74a..a3dd4d19e3 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -14,6 +14,10 @@ run_tests, ) +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, +) from torchao.prototype.dtypes.floatx import ( FloatxTensorCoreLayout, from_scaled_tc_floatx, @@ -24,10 +28,6 @@ _pack_tc_floatx, _pack_tc_fp6, ) -from torchao.prototype.custom_fp_utils import ( - _f32_to_floatx_unpacked, - _floatx_unpacked_to_f32, -) from torchao.quantization import ( FPXWeightOnlyConfig, quantize_, diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index d777b9c2ba..3303bd5267 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -136,7 +136,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import Float8Layout + from torchao.dtypes.floatx import Float8Layout from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout if isinstance(self._layout, FloatxTensorCoreLayout): diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 27e58eee93..29e4866fc4 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,10 +21,6 @@ _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, ) -from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( - _linear_f16_bf16_act_floatx_weight_check, - _linear_f16_bf16_act_floatx_weight_impl, -) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -76,6 +72,10 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( + _linear_f16_bf16_act_floatx_weight_check, + 
_linear_f16_bf16_act_floatx_weight_impl, +) from torchao.prototype.dtypes.uintx.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 86c5d25b41..ae4e8c0040 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from .floatx import FloatxTensorCoreLayout from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, @@ -12,7 +13,6 @@ MarlinQQQTensor, to_marlinqqq_quantized_intx, ) -from .floatx import FloatxTensorCoreLayout __all__ = [ "BlockSparseLayout", diff --git a/torchao/prototype/dtypes/floatx/__init__.py b/torchao/prototype/dtypes/floatx/__init__.py index 76dec3050e..5d8eb8dacf 100644 --- a/torchao/prototype/dtypes/floatx/__init__.py +++ b/torchao/prototype/dtypes/floatx/__init__.py @@ -6,7 +6,4 @@ from .floatx_tensor_core_layout import FloatxTensorCoreLayout - -__all__ = [ - "FloatxTensorCoreLayout" -] +__all__ = ["FloatxTensorCoreLayout"] From 8864bb502e0956807ce5aa72a45cba70784d757d Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 11 Nov 2025 19:27:49 -0800 Subject: [PATCH 4/4] minor fixes --- torchao/dtypes/affine_quantized_tensor_ops.py | 8 -------- torchao/prototype/dtypes/floatx/__init__.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 4a13909475..730d33d2c6 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,14 +21,6 @@ _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, ) -from torchao.prototype.dtypes.uintx.gemlite_layout import ( - _linear_fp_act_int4_weight_gemlite_check, - _linear_fp_act_int4_weight_gemlite_impl, -) -from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( - _linear_f16_bf16_act_floatx_weight_check, - _linear_f16_bf16_act_floatx_weight_impl, -) from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, diff --git a/torchao/prototype/dtypes/floatx/__init__.py b/torchao/prototype/dtypes/floatx/__init__.py index 5d8eb8dacf..edd045f8a9 100644 --- a/torchao/prototype/dtypes/floatx/__init__.py +++ b/torchao/prototype/dtypes/floatx/__init__.py @@ -4,6 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from .floatx_tensor_core_layout import FloatxTensorCoreLayout +from .floatx_tensor_core_layout import ( + FloatxTensorCoreLayout, + from_scaled_tc_floatx, + to_scaled_tc_floatx, +) -__all__ = ["FloatxTensorCoreLayout"] +__all__ = [ + "FloatxTensorCoreLayout", + "to_scaled_tc_floatx", + "from_scaled_tc_floatx", +]
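
For reviewers and downstream users, a minimal usage sketch (illustrative only, not part of the patch series): it assumes this series is applied and shows the import migration that the backward-compatibility stub asks for, plus a CPU round-trip through the moved pack/unpack helpers. Only symbols that appear in the diffs above are used.

# Minimal sketch, assuming this patch series is applied.
import warnings

import torch

# Old path: the stub module still works but emits a DeprecationWarning the
# first time it is imported (it may have been imported earlier by torchao itself).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import torchao.dtypes.floatx.floatx_tensor_core_layout  # noqa: F401  (deprecated path)
print("deprecation warnings seen:", sum(issubclass(w.category, DeprecationWarning) for w in caught))

# New path: import directly from the prototype package added by this series.
from torchao.prototype.dtypes.floatx import (
    FloatxTensorCoreLayout,
    from_scaled_tc_floatx,
    to_scaled_tc_floatx,
)

# Round-trip a small fp6 (ebits=3, mbits=2) weight through the tensor-core packing.
# Both dimensions must be multiples of 64; packed shape is (M, N // 8 * nbits).
weight = torch.randn(64, 64)
packed, scale = to_scaled_tc_floatx(weight, ebits=3, mbits=2)
assert packed.shape == (64, 48)  # 64 // 8 * 6
restored = from_scaled_tc_floatx(packed, ebits=3, mbits=2, scale=scale)
assert restored.shape == weight.shape

# Layout object consumed by the quantization API (e.g. fpx weight-only configs).
layout = FloatxTensorCoreLayout(ebits=3, mbits=2)

The round-trip runs on CPU since the pack/unpack helpers are plain PyTorch ops; only the fused quant_llm_linear kernel path requires CUDA.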