 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import math
 import random
 import re
@@ -175,6 +176,7 @@ def __init__(
         use_max_length=True,
         pad_max_length=2048,
         device=None,
+        layer_wise=False,
     ):
         """
         Args:
@@ -215,9 +217,13 @@ def __init__(
         self.check_layer_config()
 
         # device
-        self.device = model.device
+        self.device = device
+        if str(self.model.device).startswith("cuda"):
+            self.device = self.model.device
         self.is_ready = False
 
+        self.layer_wise = layer_wise
+
         # dataloader
         self.use_max_length = use_max_length
         self.pad_max_length = pad_max_length
@@ -438,11 +444,13 @@ def forward(layer, *args, **kwargs):
             raise ValueError
 
         # Step1: fetch the embeddings and other layers before the transformer stack.
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer = embedding_layer.to(self.device)
+        if not self.layer_wise:
+            for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
+                embedding_layer = embedding_layer.to(self.device)
 
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
+        if not self.layer_wise:
+            self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
         forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
@@ -451,7 +459,8 @@ def forward(layer, *args, **kwargs):
         # Step3: run forward to obtain calibration datasets
         logger.info("Collecting calibration inputs...")
         for batch in tqdm(self.dataloader):
-            batch = move_input_to_device(batch, self.device)
+            if not self.layer_wise:
+                batch = move_input_to_device(batch, self.device)
             try:
                 if isinstance(batch, tuple) or isinstance(batch, list):
                     self.model(batch[0])
 
         # Step 4: restore original forward function, relocate layers back to cpu.
         self.gptq_related_blocks["transformers"][0].forward = forward_cache
-        self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
-        for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
-            embedding_layer.to(self.device)
+        if not self.layer_wise:
+            self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
+            for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
+                embedding_layer.to(self.device)
         torch.cuda.empty_cache()
         # end
         logger.info("GPTQ quantization prepared.")
@@ -501,7 +511,7 @@ def update_blockwise_hidden_states(self, outs):
             self.cache_positional_arguments[0] = outs[:]
 
     @torch.no_grad()
-    def execute_quantization(self, means=None, stds=None):
+    def execute_quantization(self, means=None, stds=None, model_path=None):
         """Run quantization."""
         # Step1: prepare quantization (calibration datasets)
 
@@ -513,7 +523,11 @@ def execute_quantization(self, means=None, stds=None):
         tblock_length = len(self.gptq_related_blocks["transformers"])
         for block_idx in range(tblock_length):
             logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..")
-            transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
+            if not self.layer_wise:
+                # if the layer-wise feature is not applied, we still place the entire block on the GPU
+                transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
+            else:
+                transformer_block = self.gptq_related_blocks["transformers"][block_idx]  # .to(self.device)
             # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized.
             sub_layers = find_layers(transformer_block)
             sub_layers_to_quant = {}
@@ -534,8 +548,16 @@ def execute_quantization(self, means=None, stds=None):
                 # weight_config_this_layer = self.weight_config.get(
                 #     self.get_full_layer_name(layer_name, block_idx), None
                 # )
-                weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
-                gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name])
+                full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+                weight_config_this_layer = self.get_layer_config(full_layer_name)
+                if self.layer_wise:
+                    from ..torch_utils.layer_wise_quant.utils import load_value
+
+                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                else:
+                    W = sub_layers[layer_name].weight.data.clone()
+
+                gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
                 # gptq_for_this_block[layer_name].quantizer = Quantizer()
                 gptq_for_this_block[layer_name].quantizer.configure(
                     weight_config_this_layer["wbits"],
@@ -555,7 +577,6 @@ def tmp(_, inp, out):
             for layer_name in sub_layers:
                 handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
             idx = self.cache_key_arguments.pop("i")
-            # import pdb;pdb.set_trace()
             for j in range(len(self.dataloader)):
                 cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                 cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
@@ -570,12 +591,44 @@ def tmp(_, inp, out):
                 # )
                 weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
                 logger.info(f"Quantizing layer {layer_name}")
-                scale, zp = gptq_for_this_block[layer_name].fasterquant(
+                if self.layer_wise:
+                    from ..torch_utils.layer_wise_quant.utils import load_value
+
+                    full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                else:
+                    W = sub_layers[layer_name].weight.data.clone()
+                scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
+                    W,
                     blocksize=weight_config_this_layer["block_size"],
                     percdamp=weight_config_this_layer["percdamp"],
                     groupsize=weight_config_this_layer["group_size"],
                     act_order=weight_config_this_layer["act_order"],
                 )
+                if self.layer_wise:
+                    from ..torch_utils.layer_wise_quant.utils import (
+                        LWQ_WORKSPACE,
+                        clean_module_weight,
+                        load_value,
+                        set_module_tensor_to_device,
+                    )
+
+                    sub_layer = sub_layers[layer_name]
+                    full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+                    for n, p in sub_layer.named_parameters():
+                        param_name = full_layer_name + "." + n
+                        if n == "weight":
+                            set_module_tensor_to_device(self.model, param_name, self.device, Q)
+                        else:
+                            value = load_value(self.model, param_name, model_path)
+                            set_module_tensor_to_device(self.model, param_name, self.device, value)
+                    # sub_layer.weight.data = Q
+                    torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
+                    clean_module_weight(sub_layer)
+                    del Q
+                    gc.collect()
+                else:
+                    sub_layers[layer_name].weight.data = Q
                 gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale}
                 if not weight_config_this_layer["sym"]:
                     gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp
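
Note on the layer-wise branch above: each quantized sub-layer is written to LWQ_WORKSPACE as a per-layer checkpoint and its in-memory weight is then released by clean_module_weight, so the dense weights of the whole model never need to be resident at once. A minimal, hypothetical sketch of how such a checkpoint could be consumed afterwards; the reload side is not part of this diff, only the file naming follows the torch.save call above, and LWQ_WORKSPACE / full_layer_name / sub_layer are assumed to be in scope:

    import torch

    # Load the per-layer state saved by the layer-wise path and restore it into the cleaned module.
    state = torch.load(LWQ_WORKSPACE + f"/{full_layer_name}.pt", map_location="cpu")
    sub_layer.load_state_dict(state)
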
@@ -594,7 +647,10 @@ def tmp(_, inp, out):
                 out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
                 outs.append(out)
             self.cache_key_arguments["i"] = idx
-            self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
+            if self.layer_wise:
+                self.gptq_related_blocks["transformers"][block_idx] = transformer_block
+            else:
+                self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
             del gptq_for_this_block
             torch.cuda.empty_cache()
             # iteratively replace the input with output, thus layerwise quantization can continue.
@@ -617,10 +673,10 @@ class GPTQ:
     GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers (https://arxiv.org/abs/2210.17323)
     """
 
-    def __init__(self, layer):
+    def __init__(self, layer, W, device="cpu"):
         self.layer = layer
-        self.device = self.layer.weight.device
-        W = layer.weight.data.clone()
+        self.device = device
+        # W = layer.weight.data.clone()
         if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
             W = W.flatten(1)
         if isinstance(self.layer, transformers.Conv1D):
@@ -661,8 +717,9 @@ def add_batch(self, inp, out):
         # self.H += 2 / self.nsamples * inp.matmul(inp.t())
         self.H += inp.matmul(inp.t())  # H = X*X^T, which should be a symmetric matrix
 
-    def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
-        W = self.layer.weight.data.clone()
+    def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
+        # W = self.layer.weight.data.clone()
+        weight_shape, weight_dtype = W.shape, W.data.dtype
         if isinstance(self.layer, nn.Conv2d):
             W = W.flatten(1)
         if isinstance(self.layer, transformers.Conv1D):
@@ -740,7 +797,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
         # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
         # logger.info(f"{torch.sum(Losses)}")
 
-        if self.device != torch.device("cpu"):
+        if str(self.device).startswith("cuda"):
             torch.cuda.synchronize()
         logger.info(f"time {(time.time() - tick)}")
         logger.info(f"error {torch.sum(Losses).item()}")
@@ -751,7 +808,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
 
         if isinstance(self.layer, transformers.Conv1D):
             Q = Q.t()
-        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
+        # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
+        Q = Q.reshape(weight_shape).to(weight_dtype)
         if DEBUG:
             logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
 
@@ -760,7 +818,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
             zero.append(self.quantizer.zero)
         scale = torch.cat(scale, dim=1)
         zero = torch.cat(zero, dim=1)
-        return scale, zero
+        return scale, zero, Q
 
     def free(self):
         if DEBUG:
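
Taken together, the GPTQ changes above move weight ownership to the caller: the weight tensor and target device are passed into __init__, fasterquant() takes W as its first argument, and the quantized tensor Q is returned alongside scale and zero instead of being written back to self.layer. A minimal, hypothetical usage sketch of the new calling convention; the import path and the elided calibration step are assumptions, not part of this diff:

    import torch
    import torch.nn as nn

    from neural_compressor.adaptor.torch_utils.gptq import GPTQ  # module path assumed

    layer = nn.Linear(768, 768)
    W = layer.weight.data.clone()        # non layer-wise path: the caller clones the weight
    gptq = GPTQ(layer, W, device="cpu")  # weight and device are now passed in explicitly

    # ... configure gptq.quantizer and feed calibration batches through gptq.add_batch()
    #     so the Hessian estimate H is accumulated before quantizing ...

    scale, zero, Q = gptq.fasterquant(W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False)
    layer.weight.data = Q                # the caller writes Q back (or saves it, in the layer-wise path)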