@@ -215,10 +215,12 @@ def __init__(
         scale_dtype=torch.float32,
         compression_dtype=torch.int32,
         compression_dim=1,
-        gptq_perm=False,
+        g_idx=False,
         device="cpu",
+        use_hf_format=False,
     ):
         super().__init__()
+        self.use_hf_format = use_hf_format
         self.dtype = dtype
         if "int" not in self.dtype:  # for nf4, fp4
             from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
@@ -249,69 +251,105 @@ def __init__(
         assert compression_dim in [0, 1], (
             "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
         )
-        self.register_buffer(
-            "scale",
-            torch.zeros(
-                (out_features, math.ceil(in_features / self.groupsize)),
-                dtype=self.float_type,
-            ).to(device),
-        )
-        if compression_dim == 1:
+        if self.use_hf_format:
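+            # HF/AutoGPTQ-style checkpoints store scales/qweight/qzeros with in_features leading; register the
+            # transposed buffers, then rebind the attributes to .T views so the rest of the class sees out_features first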
             self.register_buffer(
-                "packed_weight",
+                "scales",
                 torch.zeros(
-                    (out_features, math.ceil(in_features / self.n_pack)),
+                    (math.ceil(in_features / self.groupsize), out_features),
+                    dtype=self.float_type,
+                ).to(device),
+            )
+            self.scales = self.scales.T
+            self.register_buffer(
+                "qweight",
+                torch.zeros(
+                    (math.ceil(in_features / self.n_pack), out_features),
                     dtype=self.compressed_dtype,
                 ).to(device),
             )
-            if zp:
-                self.register_buffer(
-                    "packed_zp",
-                    torch.zeros(
-                        (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
-                        dtype=self.compressed_dtype,
-                    ).to(device),
-                )
-        else:
+            self.qweight = self.qweight.T
             self.register_buffer(
-                "packed_weight",
+                "qzeros",
                 torch.zeros(
-                    (math.ceil(out_features / self.n_pack), in_features),
+                    (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
                     dtype=self.compressed_dtype,
                 ).to(device),
             )
-            if zp:
+            self.qzeros = self.qzeros.T
+        else:
+            self.register_buffer(
+                "scales",
+                torch.zeros(
+                    (out_features, math.ceil(in_features / self.groupsize)),
+                    dtype=self.float_type,
+                ).to(device),
+            )
+            if compression_dim == 1:
                 self.register_buffer(
-                    "packed_zp",
+                    "qweight",
                     torch.zeros(
-                        (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
+                        (out_features, math.ceil(in_features / self.n_pack)),
                         dtype=self.compressed_dtype,
                     ).to(device),
                 )
+                if zp:
+                    self.register_buffer(
+                        "qzeros",
+                        torch.zeros(
+                            (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
+                            dtype=self.compressed_dtype,
+                        ).to(device),
+                    )
+            else:
+                self.register_buffer(
+                    "qweight",
+                    torch.zeros(
+                        (math.ceil(out_features / self.n_pack), in_features),
+                        dtype=self.compressed_dtype,
+                    ).to(device),
+                )
+                if zp:
+                    self.register_buffer(
+                        "qzeros",
+                        torch.zeros(
+                            (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
+                            dtype=self.compressed_dtype,
+                        ).to(device),
+                    )
+        if g_idx:
+            self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
+        else:
+            self.g_idx = None
         if bias:
             self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
             self.bias = None
-        if gptq_perm:
-            self.register_buffer("gptq_perm", torch.zeros(in_features, dtype=torch.int32).to(device))
-        else:
-            self.gptq_perm = None

-    def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
+    def pack(self, int_weight, scale, zp, bias, g_idx=None):
         int_weight = int_weight.to(self.device)
+        if self.use_hf_format and zp is None:
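+            # symmetric weights carry no zero point; synthesize one at 2**(bits - 1) so the packed values are unsigned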
+            # to avoid overflow
+            int_weight = int_weight.type(torch.int32)
+            shift_bias = 2 ** (self.bits - 1)
+            int_weight += shift_bias
+            zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias
         if bias is not None:
             assert hasattr(self, "bias"), "bias is not set when initializing."
             self.bias = bias.type(self.float_type).to(self.device)
-        if gptq_perm is not None:
-            assert hasattr(self, "gptq_perm"), "gptq_perm is not set when initializing."
-            self.gptq_perm = gptq_perm.type(torch.int32).to(self.device)
-        assert scale.shape == self.scale.shape, "Scale shape is mismatched."
-        self.scale = scale.type(self.float_type).to(self.device)
-        if self.compression_dim == 0:
+        if g_idx is not None:
+            assert hasattr(self, "g_idx"), "g_idx is not set when initializing."
+            self.g_idx = g_idx.type(torch.int32).to(self.device)
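+            # for HF format, turn the GPTQ permutation into per-channel group indices: argsort gives the inverse
+            # permutation, and integer division by groupsize yields each channel's group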
+            if self.use_hf_format:
+                invperm = torch.argsort(self.g_idx)
+                self.g_idx = invperm // self.groupsize
+                self.g_idx = self.g_idx.type(torch.int32).to(self.device)
+        assert scale.shape == self.scales.shape, "Scale shape is mismatched."
+        self.scales = scale.type(self.float_type).to(self.device)
+        if not self.use_hf_format and self.compression_dim == 0:
             int_weight = int_weight.T
-            self.packed_weight = self.packed_weight.T
+            self.qweight = self.qweight.T
         origin_shape = int_weight.shape
-        target_shape = self.packed_weight.shape
+        target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
         mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device)

@@ -323,121 +361,112 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
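             # pack n_pack values into one compressed word: mask each to self.bits bits, then OR it into slot e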
             for e in range(tmp.shape[1]):
                 tmp[:, e] &= mask
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.packed_weight[:, j] |= tmp[:, e]
-        if self.compression_dim == 0:
-            self.packed_weight = self.packed_weight.T
+                self.qweight[:, j] |= tmp[:, e]
+        if not self.use_hf_format and self.compression_dim == 0:
+            self.qweight = self.qweight.T

         if zp is not None:
             zp = zp.to(self.device)
-            if self.compression_dim == 0:
+            if self.use_hf_format:
+                zp -= 1
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
-                self.packed_zp = self.packed_zp.T
-            assert hasattr(self, "packed_zp"), "zp is not set when initializing."
-            target_shape = self.packed_zp.shape
+                self.qzeros = self.qzeros.T
+            assert hasattr(self, "qzeros"), "zp is not set when initializing."
+            target_shape = self.qzeros.shape
             for j in range(target_shape[1]):
                 start = self.n_pack * j
                 end = self.n_pack * (j + 1)
                 tmp = zp[:, start:end].type(self.compressed_dtype)
                 for e in range(tmp.shape[1]):
                     tmp[:, e] &= mask
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.packed_zp[:, j] |= tmp[:, e]
-            if self.compression_dim == 0:
-                self.packed_zp = self.packed_zp.T
+                    self.qzeros[:, j] |= tmp[:, e]
+            if self.use_hf_format or self.compression_dim == 0:
+                self.qzeros = self.qzeros.T
+        if self.use_hf_format:
+            self.scales = self.scales.T
+            self.qweight = self.qweight.T
+            self.g_idx = self.g_idx
+            self.qzeros = self.qzeros.T

     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        device = self.scale.device
+        if self.use_hf_format:
+            # Prevent broken id links of self.scales, self.qweight, and self.qzeros
+            self.scales = self.scales.T
+            self.qweight = self.qweight.T
+            self.g_idx = self.g_idx
+            self.qzeros = self.qzeros.T
+        device = self.scales.device
+        fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
+        if self.g_idx is None:
+            # used for recovering fp32_weight
+            self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
         mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device)
-        if hasattr(self, "packed_zp"):
+        if hasattr(self, "qzeros"):
             weight_dtype = torch.uint8
         else:
             weight_dtype = torch.int8
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
-        packed_weight = self.packed_weight
-        if self.compression_dim == 0:
+        qweight = self.qweight
+        if not self.use_hf_format and self.compression_dim == 0:
             weight = weight.T
-            packed_weight = packed_weight.T
+            qweight = qweight.T
         origin_shape = weight.shape
-        target_shape = packed_weight.shape
+        target_shape = qweight.shape
         for j in range(target_shape[1]):
             for e in range(self.n_pack):
                 index = j * self.n_pack + e
                 if index >= origin_shape[1]:
                     continue
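                 # isolate field e: shift left to discard higher fields, then shift right to the low bits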
-                tmp = packed_weight[:, j]
+                tmp = qweight[:, j]
                 tmp = tmp << (self.compress_bits - self.bits * (e + 1))
                 tmp = tmp >> self.compress_bits - self.bits
                 if weight_dtype == torch.uint8:
                     tmp &= mask  # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
-        if self.compression_dim == 0:
+        if not self.use_hf_format and self.compression_dim == 0:
             weight = weight.T
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
-        if hasattr(self, "packed_zp"):
+        if hasattr(self, "qzeros"):
             zp_dtype = self.compressed_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(self.scale.shape, dtype=zp_dtype).to(device)
-            packed_zp = self.packed_zp
-            if self.compression_dim == 0:
+            zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device)
+            qzeros = self.qzeros
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
-                packed_zp = packed_zp.T
+                qzeros = qzeros.T
             origin_shape = zp.shape
-            target_shape = packed_zp.shape
+            target_shape = qzeros.shape
             for j in range(target_shape[1]):
                 for e in range(self.n_pack):
                     index = j * self.n_pack + e
                     if index >= origin_shape[1]:
                         continue
-                    tmp = packed_zp[:, j]
+                    tmp = qzeros[:, j]
                     tmp = tmp << (self.compress_bits - self.bits * (e + 1))
                     tmp = tmp >> self.compress_bits - self.bits
                     tmp &= mask
                     zp[:, index] = tmp.type(zp_dtype)
-            if self.compression_dim == 0:
+            if self.use_hf_format or self.compression_dim == 0:
                 zp = zp.T
+            if self.use_hf_format:
+                # zp -= 1 in pack() may produce zp == -1, which unpacks to 2**self.bits - 1; map it back to 0
+                zp += 1
+                zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
-            left_element = self.in_features % self.groupsize
-            if left_element != 0:
-                split_index = self.in_features // self.groupsize * self.groupsize
-                weight1 = weight[:, :-split_index].reshape(-1, self.groupsize)
-                scale1 = self.scale[:, :-1].reshape(-1, 1)
-                zp1 = zp[:, :-1].reshape(-1, 1)
-                weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
-                weight2 = weight[:, -split_index:]
-                scale2 = self.scale[:, -1:]
-                zp2 = zp[:, -1].reshape(-1, 1)
-                weight2 = (weight2 - zp2) * scale2
-                fp32_weight = torch.cat((weight1, weight2), dim=1)
-            else:
-                weight = weight.reshape(-1, self.groupsize)
-                scale = self.scale.reshape(-1, 1)
-                zp = zp.reshape(-1, 1)
-                fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
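+            # dequantize per input channel: g_idx selects the group whose scale and zero point apply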
+            for idx in range(self.in_features):
+                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]]
         else:
             # recover fp32 weight with int_weight, scale
-            left_element = self.in_features % self.groupsize
-            if left_element != 0:
-                split_index = self.in_features // self.groupsize * self.groupsize
-                weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
-                scale1 = self.scale[:, :-1].reshape(-1, 1)
-                weight1 = (weight1 * scale1).reshape(self.out_features, -1)
-                weight2 = weight[:, split_index:]
-                scale2 = self.scale[:, -1:]
-                weight2 = weight2 * scale2
-                fp32_weight = torch.cat((weight1, weight2), dim=1)
-            else:
-                weight = weight.reshape(-1, self.groupsize)
-                scale = self.scale.reshape(-1, 1)
-                fp32_weight = (weight * scale).reshape(self.out_features, -1)
-        if self.gptq_perm is not None:
-            invperm = torch.argsort(self.gptq_perm)
-            fp32_weight = fp32_weight[:, invperm]
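+            # symmetric case: apply the per-group scale only, there is no zero point to subtract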
+            for idx in range(self.in_features):
+                fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]]
         return fp32_weight

     def forward(self, input):
@@ -453,9 +482,16 @@ def forward(self, input):
         return F.linear(input, weight, self.bias)

     def extra_repr(self) -> str:
-        return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
-            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
+            self.in_features,
+            self.out_features,
+            self.bits,
+            self.groupsize,
+            self.bias is not None,
         )
+        if self.use_hf_format:
+            tmp_str += ", use_hf_format=True"
+        return tmp_str


 class FakeAffineTensorQuantFunction(Function):