@@ -315,6 +315,7 @@ def pack(self, int_weight, scales, zp, bias=None, g_idx=None, **kwargs):
315315 self .qzeros = self .qzeros .T .contiguous ()
316316
317317 def unpack (self ):
318+ """Unpack weight and zero point."""
318319 scales = self .scales .T .contiguous () if self .use_optimum_format else self .scales
319320 qweight = self .qweight .T .contiguous () if self .use_optimum_format else self .qweight
320321
@@ -354,6 +355,7 @@ def unpack(self):
354355 return UnpackedWeightOnlyLinearParams (weight , scales , zp , g_idx = self .g_idx , bias = self .bias )
355356
356357 def recover (self ):
358+ """Recover fp32 weight from packed weight."""
357359 logger .debug (f"Recovering { self } weight" )
358360 unpack_params_dict = self .unpack ()
359361 weight = unpack_params_dict .get ("int_weight" )
@@ -379,6 +381,7 @@ def recover(self):
379381 return fp32_weight .to (scales .device )
380382
381383 def forward (self , input ):
384+ """Forward function."""
382385 if not hasattr (self , "weight" ):
383386 weight = self .recover ()
384387 device = self .scales .device
@@ -396,12 +399,14 @@ def forward(self, input):
396399 return F .linear (input , weight , self .bias )
397400
def pack_tensor(self, raw_tensor):
    """Pack a tensor with the backend selected by this module's device.

    Args:
        raw_tensor: The tensor to pack; forwarded unchanged to the backend.

    Returns:
        The result of ``self.pack_tensor_with_torch(raw_tensor)`` when
        ``"cuda" in self.device`` holds, otherwise the result of
        ``self.pack_tensor_with_numpy(raw_tensor)``.
    """
    # Only the chosen backend attribute is looked up, as in a plain if/else.
    on_cuda = "cuda" in self.device
    packer = self.pack_tensor_with_torch if on_cuda else self.pack_tensor_with_numpy
    return packer(raw_tensor)
403407
404408 def unpack_tensor (self , packed_tensor ):
409+ """Unpack tensor."""
405410 if "cuda" in self .device :
406411 return self .unpack_tensor_with_torch (packed_tensor )
407412 else :
0 commit comments