@@ -454,11 +454,7 @@ def extra_repr(self) -> str:
         tmp_str += ", use_optimum_format=True"
         return tmp_str

-
-# TODO: implement HPUWeightOnlyLinear
-# temporarily let HPUWeightOnlyLinear inherit INCWeightOnlyLinear
-# should be 'class HPUWeightOnlyLinear(WeightOnlyLinear)'
-class HPUWeightOnlyLinear(INCWeightOnlyLinear):
+class HPUWeightOnlyLinear(WeightOnlyLinear):
     def __init__(
         self,
         in_features,
@@ -468,7 +464,7 @@ def __init__(
         group_size=32,
         zp=False,
         bias=False,
-        scale_dtype=torch.float32,
+        scale_dtype=torch.bfloat16,
         compression_dtype=torch.int32,
         compression_dim=1,
         g_idx=False,
@@ -482,17 +478,128 @@ def __init__(
             dtype,
             bits,
             group_size,
-            zp,
-            bias,
-            scale_dtype,
-            compression_dtype,
-            compression_dim,
-            g_idx,
             device,
-            use_optimum_format,
-            **kwargs,
+        )
+        self.float_type = torch.bfloat16
+        self.compression_dim = compression_dim
+        self.compression_dtype = compression_dtype
+
+        if bits != 4:
+            raise NotImplementedError("Only 4 bits are supported.")
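+        # maxq is the largest representable 4-bit value: 2**4 - 1 = 15.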
+        self.maxq = 2**self.bits - 1
+
+        if bias:
+            self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(self.device))
+        else:
+            self.bias = None
+
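+        # qweight packs 32 // bits = 8 int4 values into each int32 entry,
+        # compressed along the output-feature dimension.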
+        self.register_buffer(
+            "qweight",
+            torch.zeros((in_features, out_features // 32 * self.bits), dtype=self.compression_dtype).to(self.device),
         )

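+        # qzeros and scales hold one row per quantization group,
+        # i.e. ceil(in_features / group_size) rows.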
+        self.register_buffer(
+            "qzeros",
+            torch.zeros(
+                (
+                    math.ceil(in_features / self.group_size),
+                    out_features // 32 * self.bits,
+                ),
+                dtype=self.compression_dtype,
+            ),
+        )
+        self.register_buffer(
+            "scales",
+            torch.zeros(
+                (math.ceil(in_features / self.group_size), out_features),
+                dtype=self.float_type,
+            ),
+        )
+
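+        # g_idx maps each input channel to its quantization group.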
+        if g_idx:
+            self.register_buffer(
+                "g_idx",
+                torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32),
+            )
+        else:
+            self.g_idx = None
+
+        self.half_indim = self.in_features // 2
+
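+        # Bit offsets (0, 4, ..., 28) used to shift individual int4 values out of an int32.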
+        self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
+
+    def forward(self, input):
+        input_dtype = input.dtype
+        output_shape = input.shape[:-1] + (self.out_features,)
+        scales = self.scales
+        qweight = self.qweight
+        zeros = self.qzeros
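+        # convert_from_uint4 dequantizes the packed int4 weight on the HPU.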
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        output = torch.matmul(input, weight)
+        output = output.to(dtype=input_dtype).reshape(
+            output_shape
+        )  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocates a float32 output.
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+    def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
+        logger.debug("Packing for HPU")
+
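+        # Incoming tensors are (out_features, in_features); the HPU buffers
+        # registered in __init__ use the transposed (in_features, out_features) layout.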
+        scales = scales.T.contiguous()
+        qzeros = zp.T.contiguous()
+        qweight = int_weight.T.contiguous()
+
+        self.scales = scales.to(dtype=torch.bfloat16)
+
+        # weights and zp come back from unpack on the device; move them to CPU for packing
+        self.qweight = qweight.cpu()
+        new_qweight = self.pack_tensor(self.qweight)
+        self.qweight = new_qweight.to("hpu")
+
+        self.qzeros = qzeros.cpu()
+        new_qzeros = self.pack_tensor(self.qzeros)
+        self.qzeros = new_qzeros.to("hpu")
+
+        if bias is not None:
+            self.bias = bias.to("hpu").to(torch.bfloat16)
+
+    def unpack(self):
+        logger.debug("Unpacking from HPU")
+        self.qweight = self.qweight.cpu()
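+        # Expand each int32 into 32 // bits values by right-shifting with the
+        # offsets in self.wf, then mask off everything above the low `bits` bits.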
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
+            self.wf.unsqueeze(-1),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2**self.bits) - 1)
+        weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
+        self.qweight = self.qweight.to(self.device)
+
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
+            self.wf.unsqueeze(0),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+
+        zeros = torch.bitwise_and(
+            zeros, (2**self.bits) - 1
+        ).to(self.scales.dtype)  # NOTE: It appears that casting here, before the `zeros = zeros + 1`, is important.
+        zeros = zeros + 1
+        zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+        return weight, zeros
+
+    def pack_tensor(self, input, bits=4):
+        normal = input.to(torch.int32)
+        q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
+        i = 0
+        col = 0
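+        # Each output column accumulates 32 // bits consecutive input values,
+        # each shifted into its own bit field.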
+        while col < q.shape[1]:
+            for j in range(i, i + (32 // bits)):
+                q[:, col] |= normal[:, j] << (bits * (j - i))
+            i += 32 // bits
+            col += 1
+        q = q.to(torch.int32)
+        return q

 class FakeAffineTensorQuantFunction(Function):
     """Fake version of affine quantization."""