Skip to content

Commit f6c9d09

Browse files
committed
up
1 parent 143fe91 commit f6c9d09

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py

Lines changed: 13 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -54,15 +54,16 @@ class IntxUnpackedTensor(TorchAOBaseTensor):
5454
Non-Tensor Attributes:
5555
target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
5656
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
57+
dtype: the dtype of the dequantized Tensor
5758
"""
5859

5960
tensor_data_names = ["qdata", "scale", "zero_point"]
60-
tensor_attribute_names = ["target_dtype", "block_size"]
61+
tensor_attribute_names = ["target_dtype", "block_size", "dtype"]
6162

62-
def __new__(cls, qdata, scale, zero_point, target_dtype, block_size=None):
63+
def __new__(cls, qdata, scale, zero_point, target_dtype, block_size, dtype):
6364
kwargs = {}
6465
kwargs["device"] = qdata.device
65-
kwargs["dtype"] = scale.dtype
66+
kwargs["dtype"] = dtype
6667
kwargs["requires_grad"] = False
6768
shape = qdata.shape
6869
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
@@ -73,7 +74,8 @@ def __init__(
7374
scale,
7475
zero_point,
7576
target_dtype,
76-
block_size: Tuple[int],
77+
block_size,
78+
dtype,
7779
):
7880
assert qdata.dtype == torch.int8, (
7981
f"qdata dtype must be int8, but got {qdata.dtype}"
@@ -97,6 +99,10 @@ def __init__(
9799
scale = scale.reshape(*n_blocks)
98100
zero_point = zero_point.reshape(*n_blocks)
99101

102+
assert dtype in _FLOAT_TYPES, (
103+
f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
104+
)
105+
100106
self.qdata = qdata
101107
self.scale = scale
102108
self.zero_point = zero_point
@@ -123,6 +129,7 @@ def to(self, *args, **kwargs):
123129
else self.zero_point.to(device),
124130
self.target_dtype,
125131
self.block_size,
132+
dtype,
126133
)
127134

128135
@classmethod
@@ -166,6 +173,7 @@ def from_hp(
166173
zero_point=zero_point,
167174
target_dtype=target_dtype,
168175
block_size=block_size,
176+
dtype=hp_tensor.dtype,
169177
)
170178

171179
def dequantize(self):
@@ -259,6 +267,7 @@ def _(func, types, args, kwargs):
259267
zero_point,
260268
self.target_dtype,
261269
new_block_size,
270+
self.dtype,
262271
)
263272
return return_and_correct_aliasing(func, args, kwargs, new)
264273

0 commit comments

Comments (0)