2424 pack_scales_and_zeros ,
2525)
2626
27+ from torchao .dtypes .utils import is_device
28+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_6
29+
2730
2831logger : logging .Logger = logging .getLogger (__name__ )
2932
@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
122125 input .dtype
123126 ) # cast back to input.dtype
124127 else :
125- c = torch .ops .aten ._weight_int4pack_mm (
126- input ,
127- weight_int4pack ,
128- groupsize ,
129- scales_and_zeros ,
130- )
128+ if TORCH_VERSION_AT_LEAST_2_6 :
129+ c = torch .ops .aten ._weight_int4pack_mm_for_cpu (
130+ input ,
131+ weight_int4pack ,
132+ groupsize ,
133+ scales_and_zeros ,
134+ )
135+ else :
136+ c = torch .ops .aten ._weight_int4pack_mm (
137+ input ,
138+ weight_int4pack ,
139+ groupsize ,
140+ scales_and_zeros ,
141+ )
131142 new_shape = origin_input_size [:- 1 ] + (out_features ,)
132143 c = c .reshape (new_shape )
133144 return c
@@ -178,16 +189,27 @@ def __init__(
178189 ), "must specify both weights and scales_and_zeros, or neither"
179190
180191 if weight is None :
181- weight = torch .empty (
182- (
183- out_features // 8 ,
184- in_features // (inner_k_tiles * 16 ),
185- 32 ,
186- inner_k_tiles // 2 ,
187- ),
188- dtype = torch .int32 ,
189- device = device ,
190- )
192+ if is_device (device , "cpu" ):
193+ weight = torch .empty (
194+ (
195+ out_features ,
196+ in_features // 2 ,
197+ ),
198+ dtype = torch .uint8 ,
199+ device = device ,
200+ )
201+ else :
202+ weight = torch .empty (
203+ (
204+ out_features // 8 ,
205+ in_features // (inner_k_tiles * 16 ),
206+ 32 ,
207+ inner_k_tiles // 2 ,
208+ ),
209+ dtype = torch .int32 ,
210+ device = device ,
211+ )
212+
191213 scales_and_zeros = torch .empty (
192214 (in_features // groupsize , out_features , 2 ),
193215 dtype = get_precision (),
@@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
223245 weight_int32 , scales_and_zeros = group_quantize_tensor (
224246 weight_bf16 , n_bit = 4 , groupsize = groupsize
225247 )
226- weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
227- torch .uint8
228- )
229- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
230- weight_uint8 , inner_k_tiles
231- )
248+ if is_device (weight_int32 .device .type , "cpu" ) and TORCH_VERSION_AT_LEAST_2_6 :
249+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
250+ weight_int32 , inner_k_tiles
251+ )
252+ else :
253+ weight_uint8 = (weight_int32 [::, ::2 ] << 4 | weight_int32 [::, 1 ::2 ]).to (
254+ torch .uint8
255+ )
256+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
257+ weight_uint8 , inner_k_tiles
258+ )
232259 return weight_int4pack , scales_and_zeros
233260
234261 @classmethod
@@ -608,10 +635,15 @@ def load_model_and_state_dict(
608635 if load_state_dict :
609636 q , s , z = Q4_0 .unpack (t )
610637 scales_and_zeros = pack_scales_and_zeros (s , z )
611- q_uint8 = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
612- weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
613- q_uint8 , inner_k_tiles
614- )
638+ if is_device (q .device .type , "cpu" ) and TORCH_VERSION_AT_LEAST_2_6 :
639+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack_for_cpu (
640+ q , inner_k_tiles
641+ )
642+ else :
643+ q_tmp = (q [::, ::2 ] << 4 | q [::, 1 ::2 ]).to (torch .uint8 )
644+ weight_int4pack = torch .ops .aten ._convert_weight_to_int4pack (
645+ q_tmp , inner_k_tiles
646+ )
615647 state_dict [f"{ fqn } .weight" ] = weight_int4pack
616648 state_dict [f"{ fqn } .scales_and_zeros" ] = scales_and_zeros
617649
@@ -623,7 +655,7 @@ def load_model_and_state_dict(
623655 in_features = in_features ,
624656 out_features = out_features ,
625657 bias = False ,
626- device = "meta " ,
658+ device = "cpu " ,
627659 groupsize = Q4_0 .groupsize ,
628660 inner_k_tiles = inner_k_tiles ,
629661 ),
0 commit comments