From 49fe28eef03dc988433ed5bc4127427f1f32f2df Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 1 Feb 2025 14:54:34 +0000
Subject: [PATCH 1/2] reduce memory

---
 gptqmodel/quantization/gptq.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 4606f0dba..ca41bb9f2 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -34,6 +34,8 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
+CPU = torch.device('cpu')
+COLUMN_THRESHOLD = 7168
 
 class GPTQ:
     def __init__(self, layer):
@@ -42,7 +44,10 @@ def __init__(self, layer):
         self.layer_copy = self._clone_layer()
 
         self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
-        self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        if self.columns >= COLUMN_THRESHOLD:
+            self.H = torch.zeros((self.columns, self.columns), device=CPU)
+        else:
+            self.H = torch.zeros((self.columns, self.columns), device=self.device)
         self.nsamples = 0
 
         self.quantizer = Quantizer()
@@ -66,6 +71,9 @@ def add_batch(self, inp, out):
             inp = inp.unsqueeze(0)
 
         tmp = inp.shape[0]
+        if self.columns >= COLUMN_THRESHOLD:
+            inp = inp.to(CPU)
+
         if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
@@ -141,7 +149,10 @@ def quantize(
         if not self.quantizer.ready():
             self.quantizer.find_params(W, weight=True)
 
-        H = self.H
+        if self.columns >= COLUMN_THRESHOLD:
+            H = self.H.to(CPU)
+        else:
+            H = self.H
         del self.H
         dead = torch.diag(H) == 0
         H[dead, dead] = 1

From 9bec90da3b786b13957c61ea485a6aa884bcdb20 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 1 Feb 2025 15:12:59 +0000
Subject: [PATCH 2/2] wrong device

---
 gptqmodel/quantization/gptq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index ca41bb9f2..ceaf3fca5 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -187,7 +187,7 @@ def quantize(
         while 1 > percdamp > 0:
             try:
                 damp = percdamp * torch.mean(torch.diag(H))
-                diag = torch.arange(self.columns, device=self.device)
+                diag = torch.arange(self.columns, device=H.device)
                 H[diag, diag] += damp
                 H = torch.linalg.cholesky(H)
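
For reference, a minimal standalone sketch of the pattern these two commits
implement: the Hessian accumulator H is allocated on CPU once a layer's
column count reaches COLUMN_THRESHOLD, calibration batches are moved to H's
device before the outer-product update, and the damping step indexes on
H.device rather than the layer's device (the bug fixed in PATCH 2/2). The
HessianAccumulator class and its method names are illustrative, not part of
gptqmodel's API; the running-average accumulation follows the standard GPTQ
formulation rather than copying this file's exact code.

    import math

    import torch

    CPU = torch.device('cpu')
    COLUMN_THRESHOLD = 7168  # same cutoff as the patch

    class HessianAccumulator:
        """Illustrative sketch (not gptqmodel API): keep the
        (columns x columns) Hessian on CPU for wide layers so it
        does not consume GPU memory during calibration."""

        def __init__(self, columns: int, device: torch.device):
            self.columns = columns
            # Wide layers get a CPU-resident H, as in PATCH 1/2.
            h_device = CPU if columns >= COLUMN_THRESHOLD else device
            self.H = torch.zeros((columns, columns), device=h_device)
            self.nsamples = 0

        def add_batch(self, inp: torch.Tensor) -> None:
            # Move the batch to wherever H lives before the outer-product
            # update, mirroring `inp = inp.to(CPU)` in add_batch.
            inp = inp.to(self.H.device)
            if inp.dim() == 3:                     # (batch, seq_len, columns)
                inp = inp.reshape(-1, inp.shape[-1])
            inp = inp.t().float()                  # (columns, tokens)
            tokens = inp.shape[1]
            # Standard GPTQ running average:
            # H <- H * n/(n+t), then add 2/(n+t) * X X^T.
            self.H *= self.nsamples / (self.nsamples + tokens)
            self.nsamples += tokens
            inp = math.sqrt(2 / self.nsamples) * inp
            self.H += inp.matmul(inp.t())

        def dampened_cholesky(self, percdamp: float = 0.01) -> torch.Tensor:
            H = self.H
            damp = percdamp * torch.mean(torch.diag(H))
            # PATCH 2/2's fix: build the index on H's device, not the
            # layer's, because H may now live on CPU.
            diag = torch.arange(self.columns, device=H.device)
            H[diag, diag] += damp
            return torch.linalg.cholesky(H)

Usage under the same assumptions (requires a CUDA device for the input):

    acc = HessianAccumulator(columns=8192, device=torch.device('cuda'))
    x = torch.randn(4, 512, 8192, device='cuda')  # calibration activations
    acc.add_batch(x)   # offloaded to CPU because 8192 >= 7168
    L = acc.dampened_cholesky()

The trade-off is bandwidth for capacity: each batch crosses PCIe and the
rank-k update runs on CPU, but the (columns x columns) float32 buffer (about
1.3 GiB at 18432 columns, for example) no longer occupies GPU memory while
the layer is being calibrated.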