@@ -630,6 +630,11 @@ def extra_repr(self):
         return f"inner_k_tiles={self.inner_k_tiles}"


+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        # Pass-through: the actual packing happens later in `from_plain`.
+        return input
+
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None
@@ -1616,6 +1621,230 @@ def get_layout(self) -> Layout:
         return self._layout


+@register_layout(Int4CPULayout)
+class Int4CPUAQTTensorImpl(AQTTensorImpl):
1626+ """
1627+ TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
1628+ used by tinygemm kernels `_weight_int4pack_mm_for_cpu`
1629+
1630+ It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
1631+ dimension: [n][k / 2] (uint8 dtype)
1632+ (unpacked Tensor shape is n * k)
1633+
1634+ Note: we also pack scale and zero point together here for tinygemm kernel
1635+
1636+ Note: technically Int4 CPU layout should be the layout for the underlying packed weight
1637+ (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used
1638+ in plain layout, we just created a layout for AQT right now, this could be improved if we split out
1639+ int4 aqt into a separate tensor subclass
1640+
1641+ fields:
1642+ packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout
1643+ scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor
1644+ """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        # record the transposed flag (the aten.t handler below flips it)
+        self.transposed = transposed
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        transposed, _layout = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4CPULayout)
+
+        assert (
+            int_data.dtype == torch.int32
+        ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+            int_data, 1  # TODO: remove
+        )
+        # Flatten the groupwise scale/zero_point to [n, n_groups] and pack
+        # them into the single tensor expected by the tinygemm kernel.
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        return cls(packed_weight, scale_and_zero, False, _layout)
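+    # Sketch (hypothetical shapes): given int32 `int_data` of shape [n, k]
+    # holding values in [0, 15] and groupwise scale/zero_point tensors:
+    #   impl = Int4CPUAQTTensorImpl.from_plain(
+    #       int_data, scale, zero_point, Int4CPULayout()
+    #   )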
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            # We don't need to repack the weight; we rely on the outer
+            # tensor's shape change and just record the transposed status.
+            transposed = Int4CPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # post_process handles any padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                # Each scale/zero_point column covers `ratio` columns of
+                # int_data (i.e. the groupsize), so scale the slice bounds.
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # post_process handles any padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        zero_point_domain = ZeroPointDomain.FLOAT
+        assert len(block_size) == 2 and block_size[0] == 1
+        # Dequantize by multiplying with an identity matrix: eye(k) @ W^T
+        # yields the dequantized weight, transposed.
+        dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            torch.eye(eye_shape, device=device, dtype=original_dtype),
+            self.packed_weight,
+            groupsize,
+            self.scale_and_zero,
+        )
+        dequantized = dequantized.t().contiguous()
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
+        # Re-quantize the dequantized weight to recover the plain int data.
+        int_data = quantize_affine(
+            dequantized,
+            block_size,
+            scale,
+            zero,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        return int_data, scale, zero
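+    # Round trip: `get_plain` dequantizes via the identity matmul above and
+    # re-quantizes with `quantize_affine`, so (up to any padding) the result
+    # should match the tensors originally passed to `from_plain`.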
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
 #####################################################
 # torch functional and aten operator implementation #
 #####################################################
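
Usage sketch (not part of the diff): a minimal example exercising the new
layout, assuming torchao's `quantize_` / `int4_weight_only` API accepts a
`layout` argument the same way the CUDA `TensorCoreTiledLayout` path does:

    import torch
    from torchao.dtypes import Int4CPULayout
    from torchao.quantization import int4_weight_only, quantize_

    # bfloat16 weights quantized to int4 with the new CPU layout
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16))
    quantize_(model, int4_weight_only(group_size=128, layout=Int4CPULayout()))
    # the forward matmul dispatches to _weight_int4pack_mm_for_cpu
    out = model(torch.randn(2, 1024, dtype=torch.bfloat16))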