213 changes: 152 additions & 61 deletions gptqmodel/quantization/gptq.py
@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: 2024-2025 [email protected]
# SPDX-License-Identifier: Apache-2.0
# Contact: [email protected], x.com/qubitium

# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq)
#
# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda),
# which itself is based on [gptq](https://github.com/IST-DASLab/gptq)

import math
import os
import sys
import threading
import time
from typing import Optional
from typing import Optional, Tuple

import numpy as np
import torch
@@ -27,7 +28,6 @@

log = setup_logger()

# TODO: move this to a locking class
# --------------------------------------------------------------------------------------
# Per-device lock registry to guard device-specific critical sections (like tensor moves)
# --------------------------------------------------------------------------------------
@@ -77,6 +77,128 @@ def _get_device_lock(dev) -> threading.Lock:
torch.linalg.cholesky(tmp_eye)
del tmp_eye

# ==============================================================================
# FP8 → FP16 dequantization helpers (supports *_inv scales and group shapes)
# ==============================================================================

def _available_fp8_dtypes():
"""Collect FP8 dtypes present in the current torch build."""
names = [
"float8_e4m3fn", "float8_e4m3fnuz",
"float8_e5m2", "float8_e5m2fnuz",
# Some builds expose *_fast variants:
"float8_e4m3fn_fast", "float8_e5m2_fast",
# Add vendor aliases if your build exposes them:
# "float8_e8m0fnu",
]
dts = []
for n in names:
dt = getattr(torch, n, None)
if dt is not None:
dts.append(dt)
return tuple(dts)


_FP8_DTYPES = _available_fp8_dtypes()


def _is_fp8_dtype(dtype: torch.dtype) -> bool:
return any(dtype is dt for dt in _FP8_DTYPES)
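

# A minimal probe sketch (illustrative only): _FP8_DTYPES is built once at import time and
# _is_fp8_dtype answers membership by identity. Nothing here assumes FP8 support is actually
# present in the running torch build.
def _example_fp8_dtype_probe():
    present = [str(dt) for dt in _FP8_DTYPES]      # e.g. ["torch.float8_e4m3fn", ...] if available
    return present, _is_fp8_dtype(torch.float16)   # float16 is not an FP8 dtype -> False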


def _safe_reciprocal(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
# scales should be >0; clamp for numerical safety
return 1.0 / x.clamp_min(eps)


def _broadcast_scale_like_weight(x_oi: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
"""
Broadcast scale s to match x (assumed [out, in] or [out, ...]).
Supports:
- 0D per-tensor
- 1D per-out-channel (len == out)
- 2D per-group where s.shape == [out, G] and G divides in
- exact shape match
"""
if s.ndim == 0:
return s
if s.ndim == 1 and s.numel() == x_oi.shape[0]:
# [out] -> [out, 1, ..., 1]
view = (x_oi.shape[0],) + (1,) * (x_oi.ndim - 1)
return s.view(*view)
if s.ndim == 2 and x_oi.ndim == 2 and s.shape[0] == x_oi.shape[0]:
# [out, G] -> expand across in
out, in_ = x_oi.shape
G = s.shape[1]
if in_ % G == 0:
reps = in_ // G
return s.repeat_interleave(reps, dim=1)
if s.shape == x_oi.shape:
return s
# As a last resort, return as-is (may broadcast later or be ignored)
return s
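

# A minimal shape sketch (illustrative only) for the broadcasting rules above; the tensors
# and sizes are hypothetical and only torch is assumed.
def _example_broadcast_scale():
    w = torch.empty(4, 8)                                    # weight in [out=4, in=8] layout
    per_group = torch.full((4, 2), 0.5)                      # [out, G] with G=2 dividing in=8
    assert _broadcast_scale_like_weight(w, per_group).shape == (4, 8)   # each group repeated in // G = 4 times
    per_out = torch.full((4,), 2.0)                          # per-out-channel scale
    assert _broadcast_scale_like_weight(w, per_out).shape == (4, 1)     # broadcastable against w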


def _find_fp8_scale_from_module(mod: nn.Module) -> Tuple[Optional[torch.Tensor], bool]:
"""
Try common attribute names for FP8 dequant scale.
Returns (scale_tensor_or_None, is_inverse_bool).
Prefers *_inv names if both forms exist.
"""
inv_names = (
"weight_scale_inv", "fp8_weight_scale_inv", "scale_inv",
"dequant_scale_inv",
)
for n in inv_names:
if hasattr(mod, n):
return getattr(mod, n), True

names = (
"weight_scale", "fp8_weight_scale", "dequant_scale", "scale",
)
for n in names:
if hasattr(mod, n):
return getattr(mod, n), False

# Example extension if you store (amax, scale) pairs:
# if hasattr(mod, "amax") and hasattr(mod, "scale"):
# return (mod.scale / mod.amax), False

return None, False
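

# A minimal lookup sketch (illustrative only): a module storing an inverse scale under one
# of the recognized attribute names. The attribute value is hypothetical.
def _example_find_fp8_scale():
    lin = nn.Linear(8, 4, bias=False)
    lin.weight_scale_inv = torch.tensor(0.02)    # per-tensor inverse scale
    s, is_inv = _find_fp8_scale_from_module(lin)
    # is_inv is True here, so the caller multiplies by 1/s rather than s
    return s, is_inv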


def _dequantize_fp8_to_fp16(
w_oi: torch.Tensor,
*,
scale: Optional[torch.Tensor],
is_inverse: bool,
prefer_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""
Convert FP8 -> FP16 and apply scale (or its inverse) in a layout-safe way.
`w_oi` is expected in [out, in] (code normalizes to that before call).
"""
x = w_oi.to(prefer_dtype)

if scale is not None:
s = scale.to(device=w_oi.device, dtype=prefer_dtype)
if is_inverse:
s = _safe_reciprocal(s)
s = _broadcast_scale_like_weight(x, s)
try:
x = x * s
except Exception:
# If shape is odd, skip scaling rather than crash
pass

return x
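

# A minimal end-to-end sketch (illustrative only) of the FP8 path, guarded on the current
# torch build exposing at least one FP8 dtype. The weight values and scale are hypothetical.
def _example_dequantize_fp8():
    if not _FP8_DTYPES:
        return None                              # this build has no FP8 dtypes
    w8 = torch.randn(4, 8).to(_FP8_DTYPES[0])    # simulate an FP8-stored [out, in] weight
    scale = torch.full((4,), 2.0)                # hypothetical per-out-channel scale
    # is_inverse=False: the scale is multiplied in directly after the FP16 cast
    return _dequantize_fp8_to_fp16(w8, scale=scale, is_inverse=False)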


# ==============================================================================

def get_number_of_rows_and_cols(layer: nn.Module):
# return layer.weight.shape[0], np.prod(layer.weight.shape[1:])
@@ -93,18 +215,6 @@ def get_number_of_rows_and_cols(layer: nn.Module):

class GPTQ:
def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
# self.lock = threading.Lock()

# self.num_tied_handles = 0
# if qcfg.tied_gptq_handle is not None:
# qcfg.tied_gptq_handle.num_tied_handles += 1

# Flags indicating issues
# self.issue_zero_samples = False
# self.issue_nan_hessian = False
# self.issue_non_invertible = False

# self.W = module.weight
self.rows, self.columns = get_number_of_rows_and_cols(module)
if isinstance(module, NamedModule):
self.module = module.module
@@ -136,16 +246,13 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
self.fail_safe = False

self.H = torch.zeros((self.columns, self.columns),
dtype=torch.float32)
dtype=torch.float32)

@staticmethod
def _validate_module(module):
assert isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d,
transformers.Conv1D)), f"We supports only linear and convolutional layers. actual = `{module}`"

# def has_hessian_issues(self) -> bool:
# return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible])

def create_quantizer(self, name: str) -> Quantizer:
return Quantizer(qcfg=self.qcfg, name=name)

@@ -163,35 +270,48 @@ def _mock_hessian_inverse(self, H: torch.Tensor):
return identity, damp

def _clone_module(self, copy=True, device: torch.device = None):
"""
Clone module weight to target device and normalize to [out, in].
If the weight is FP8, dequantize it to FP16 (optionally using scale or inverse scale).
Otherwise preserve float16/bfloat16; for other dtypes cast to float32 as in original code.
"""
if not device:
device = self.module.weight.data.device

# Bring raw parameter to device first (no dtype change yet)
clone = self.module.weight.data.to(copy=copy, device=device)
orig_dtype = clone.dtype

# Normalize layout to [out, in]
if isinstance(self.module, _ConvNd):
clone = clone.flatten(1)

if isinstance(self.module, transformers.pytorch_utils.Conv1D):
clone = clone.t()

return clone.float()
# FP8 path: dequantize to FP16 with (inverse) scale if present
if _is_fp8_dtype(orig_dtype):
fp8_scale, is_inv = _find_fp8_scale_from_module(self.module)
clone = _dequantize_fp8_to_fp16(
clone, scale=fp8_scale, is_inverse=is_inv, prefer_dtype=torch.float16
)
else:
# Non-FP8: keep half/bfloat16; otherwise cast to float32 like original
if clone.dtype not in (torch.float16, torch.bfloat16):
clone = clone.float()

return clone
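

# A minimal layout sketch (illustrative only) of the normalization done in _clone_module:
# conv weights [out, in, kh, kw] flatten to [out, in*kh*kw]; transformers Conv1D stores its
# weight as [in, out] and is transposed to [out, in]. The layer sizes are hypothetical.
def _example_clone_layouts():
    conv = nn.Conv2d(3, 16, kernel_size=3)
    flat = conv.weight.data.flatten(1)                       # -> [16, 27] == [out, in]
    c1d = transformers.pytorch_utils.Conv1D(nf=16, nx=3)     # weight stored as [nx=3, nf=16]
    wt = c1d.weight.data.t()                                 # -> [16, 3] == [out, in]
    return flat.shape, wt.shape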

def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
self.fwd_counter += 1

# print(f"self.module.target_device = {self.module.target_device}")

if self.fwd_inputs_buffered:
# with torch_streamCtx(self.module.target_device_stream):
# self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=True))

self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=False))
self.fwd_inputs_buffered_data.append(
inp.to(device=self.module.target_device, non_blocking=False)
)
else:
self.process_batch(inp)

def process_batch(self, inp: torch.Tensor):
# print(f"inp = {inp}")
# print(f"self.module = {self.module} device = {self.module.target_device}")
inp = inp.to(device=self.module.target_device, dtype=torch.float32)

# input reshaping
@@ -247,8 +367,6 @@ def process_batch(self, inp: torch.Tensor):
# update number of collected samples
self.nsamples += batch_token_size

# inp returned here is flattened/reshaped original inp
# return batch_token_size, reshaped_inp, alpha, beta
del batch_token_size, reshaped_inp, alpha, beta

# FIXME, optimum needs fasterquant, we need to remove it
@@ -294,7 +412,6 @@ def hessian_inverse(self, H: torch.Tensor):
try:
H2 = H.clone()
H2[diag, diag] += damp * mean
# TODO call to torch.linalg is not threadsafe? Why not? This is very bad.
H2 = torch.linalg.cholesky(H2)
Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
del H, H2
@@ -312,7 +429,6 @@
if not (0 < damp < 1):
log.error(
f"Quantization: Module `{self.name}` -> `damp_percent` must between 0 and 1. current is {damp}. Module cannot be correctly processed.")
# raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp}")
return None, 1.0

return Hinv, damp
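

# A minimal numeric sketch (illustrative only) of the damped inverse computed above: add
# damp * mean(diag(H)) to the diagonal, then take the upper Cholesky factor of the inverse,
# so that Hinv.T @ Hinv == (H + damp * mean * I)^-1. The 3x3 matrix is hypothetical.
def _example_damped_hessian_inverse(damp: float = 0.01):
    H = torch.eye(3, dtype=torch.float32) * 2.0
    diag = torch.arange(3)
    H[diag, diag] += damp * torch.mean(torch.diag(H))
    L = torch.linalg.cholesky(H)
    return torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)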
@@ -322,39 +438,25 @@ def quantize(
self,
blocksize=128,
):
# self.H = self.H.to(device=CUDA_0)
# log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`")
start = time.time()

# Temporarily disable torch.compile due to compatibility issues with torch 2.8
# Will re-enable once the issue is fixed
# if not TORCH_GTE_28 and not self.qcfg.mock_quantization:
# self.hessian_inverse = torch_compile(self.hessian_inverse)

if self.qcfg.mock_quantization:
# Use simplified hessian inverse (identity matrix)
self.hessian_inverse = self._mock_hessian_inverse

# process buffered inputs
if len(self.fwd_inputs_buffered_data) > 0:
torch_sync(device=self.module.target_device)

for inp in self.fwd_inputs_buffered_data:
self.process_batch(inp)

# release buffer
del self.fwd_inputs_buffered_data

# if self.device.type not in ["mps", "cpu"]:
# self.module.weight.data = self.module.weight.data.cpu()

# TODO: waiting for pytorch implementation of ops for MPS
if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
raise RuntimeError(
"For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.")

if self.module_copy is None:
# log.info("copy W to cuda_1")
W = self._clone_module(device=self.module.target_device)
else:
W = self.module_copy.to(device=self.module.target_device)
@@ -368,19 +470,16 @@
H[dead, dead] = 1
W[:, dead] = 0

# g_idx = []
scale = []
zero = []
now_idx = 1

if self.qcfg.static_groups:
import copy

groups = []
for i in range(0, self.columns, self.qcfg.group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True)

scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
@@ -475,7 +574,6 @@ def quantize(
if hasattr(self.quantizer, 'scale') and hasattr(self.quantizer, 'zero'):
latest_scale = self.quantizer.scale
latest_zero = self.quantizer.zero

if latest_scale.dim() == 1:
latest_scale = latest_scale.view(-1, 1)
if latest_zero.dim() == 1:
Expand Down Expand Up @@ -524,7 +622,7 @@ def quantize(
if self.qcfg.group_size != -1:
if not self.qcfg.static_groups:
if (i1 + i) % self.qcfg.group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True)
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + self.qcfg.group_size)], weight=True)

if ((i1 + i) // self.qcfg.group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
@@ -534,13 +632,12 @@ def quantize(
idx = i1 + i
if self.qcfg.desc_act:
idx = perm[idx]

self.quantizer = groups[idx // self.qcfg.group_size]

q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
if Hinv is not None:
Losses1[:, i] = (w - q) ** 2 / d**2
Losses1[:, i] = (w - q) ** 2 / d ** 2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
@@ -550,14 +647,10 @@ def quantize(
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

# TODO: why is there a torch_sync here? There are no streaming ops here?
# torch_sync(device=self.module.target_device)

if Hinv is not None:
del Hinv
if self.nsamples != 0:
avg_loss = torch.sum(Losses).item() / self.nsamples

if math.isnan(avg_loss):
print("Losses sum item:", torch.sum(Losses).item())
if self.fail_safe:
Expand Down Expand Up @@ -636,7 +729,5 @@ def free(self):
del self.module_copy
del self.module

# torch_empty_cache(self.device)


__all__ = ["GPTQ"]