diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 4606f0dba..ceaf3fca5 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -34,6 +34,8 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
+CPU = torch.device('cpu')
+COLUMN_THRESHOLD = 7168
 
 class GPTQ:
     def __init__(self, layer):
@@ -42,7 +44,10 @@ def __init__(self, layer):
         self.layer_copy = self._clone_layer()
         self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
 
-        self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        if self.columns >= COLUMN_THRESHOLD:
+            self.H = torch.zeros((self.columns, self.columns), device=CPU)
+        else:
+            self.H = torch.zeros((self.columns, self.columns), device=self.device)
         self.nsamples = 0
         self.quantizer = Quantizer()
@@ -66,6 +71,9 @@ def add_batch(self, inp, out):
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
 
+        if self.columns >= COLUMN_THRESHOLD:
+            inp = inp.to(CPU)
+
         if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
@@ -141,7 +149,10 @@ def quantize(
         if not self.quantizer.ready():
             self.quantizer.find_params(W, weight=True)
 
-        H = self.H
+        if self.columns >= COLUMN_THRESHOLD:
+            H = self.H.to(CPU)
+        else:
+            H = self.H
         del self.H
         dead = torch.diag(H) == 0
         H[dead, dead] = 1
@@ -176,7 +187,7 @@ def quantize(
         while 1 > percdamp > 0:
             try:
                 damp = percdamp * torch.mean(torch.diag(H))
-                diag = torch.arange(self.columns, device=self.device)
+                diag = torch.arange(self.columns, device=H.device)
                 H[diag, diag] += damp
                 H = torch.linalg.cholesky(H)
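
Below is a minimal, self-contained sketch (not part of the patch) of the device-selection idea the diff applies: once a layer's column count reaches COLUMN_THRESHOLD, the columns x columns Hessian accumulator is allocated on the CPU instead of the GPU, and incoming activation batches are moved to the same device before accumulation. The hessian_device helper and the toy accumulation below are illustrative only and do not mirror gptqmodel's exact add_batch math.

```python
import torch

# Mirrors the constants introduced by the patch; 7168 x 7168 float32 is ~196 MiB per Hessian.
CPU = torch.device("cpu")
COLUMN_THRESHOLD = 7168


def hessian_device(columns: int, layer_device: torch.device) -> torch.device:
    """Pick where the (columns x columns) Hessian should live.

    Wide layers (columns >= COLUMN_THRESHOLD) keep the Hessian on CPU to
    spare GPU memory; narrower layers stay on the layer's own device.
    """
    return CPU if columns >= COLUMN_THRESHOLD else layer_device


# Example with a hypothetical 8192-column layer: the Hessian lands on CPU,
# and the calibration batch is moved to H.device before being accumulated.
columns = 8192
layer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

H = torch.zeros((columns, columns), device=hessian_device(columns, layer_device))
inp = torch.randn(16, columns, device=layer_device).to(H.device)
H += inp.t().matmul(inp)  # simplified X^T X update; the real add_batch also rescales by sample count
```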