@@ -54,15 +54,16 @@ class IntxUnpackedTensor(TorchAOBaseTensor):
     Non-Tensor Attributes:
         target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
         block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
+        dtype: the dtype of the dequantized Tensor
     """
 
     tensor_data_names = ["qdata", "scale", "zero_point"]
-    tensor_attribute_names = ["target_dtype", "block_size"]
+    tensor_attribute_names = ["target_dtype", "block_size", "dtype"]
 
-    def __new__(cls, qdata, scale, zero_point, target_dtype, block_size=None):
+    def __new__(cls, qdata, scale, zero_point, target_dtype, block_size, dtype):
         kwargs = {}
         kwargs["device"] = qdata.device
-        kwargs["dtype"] = scale.dtype
+        kwargs["dtype"] = dtype
         kwargs["requires_grad"] = False
         shape = qdata.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
@@ -73,7 +74,8 @@ def __init__(
         scale,
         zero_point,
         target_dtype,
-        block_size: Tuple[int],
+        block_size,
+        dtype,
     ):
         assert qdata.dtype == torch.int8, (
             f"qdata dtype must be int8, but got {qdata.dtype}"
@@ -97,6 +99,10 @@ def __init__(
         scale = scale.reshape(*n_blocks)
         zero_point = zero_point.reshape(*n_blocks)
 
+        assert dtype in _FLOAT_TYPES, (
+            f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
+        )
+
         self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
@@ -123,6 +129,7 @@ def to(self, *args, **kwargs):
             else self.zero_point.to(device),
             self.target_dtype,
             self.block_size,
+            dtype,
         )
 
     @classmethod
@@ -166,6 +173,7 @@ def from_hp(
             zero_point=zero_point,
             target_dtype=target_dtype,
             block_size=block_size,
+            dtype=hp_tensor.dtype,
         )
 
     def dequantize(self):
@@ -259,6 +267,7 @@ def _(func, types, args, kwargs):
         zero_point,
         self.target_dtype,
         new_block_size,
+        self.dtype,
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
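For reference, a minimal usage sketch of the behavior this change aims for: the wrapper tensor's dtype is now stored explicitly (and validated against _FLOAT_TYPES) instead of being inferred from scale.dtype, from_hp forwards the high-precision input's dtype, and per the new docstring line this is the dtype that dequantize() returns. The import path and the from_hp argument names below are assumptions for illustration, not taken from this diff.

import torch
# Import path is an assumption; adjust to wherever IntxUnpackedTensor lives in your tree.
from torchao.quantization import IntxUnpackedTensor

hp = torch.randn(128, 256, dtype=torch.bfloat16)

# Group-wise quantization: block_size (1, 64) means one scale/zero_point
# per 64 contiguous values along the last dimension.
q = IntxUnpackedTensor.from_hp(hp, block_size=(1, 64), target_dtype=torch.int4)

# The dtype attribute now comes from the high-precision input rather than scale.dtype,
# and dequantize() is documented to return a tensor of this dtype.
assert q.dtype == torch.bfloat16
assert q.dequantize().dtype == torch.bfloat16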