diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 37ef9d3e8..7d536c270 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -2,15 +2,16 @@
 # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
-
-# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq)
+#
+# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda),
+# which itself is based on [gptq](https://github.com/IST-DASLab/gptq)

 import math
 import os
 import sys
 import threading
 import time
-from typing import Optional
+from typing import Optional, Tuple

 import numpy as np
 import torch
@@ -27,7 +28,6 @@

 log = setup_logger()

-# TODO: move this to a locking class
 # --------------------------------------------------------------------------------------
 # Per-device lock registry to guard device-specific critical sections (like tensor moves)
 # --------------------------------------------------------------------------------------
@@ -77,6 +77,128 @@ def _get_device_lock(dev) -> threading.Lock:
     torch.linalg.cholesky(tmp_eye)
     del tmp_eye

+# ==============================================================================
+# FP8 → FP16 dequantization helpers (supports *_inv scales and group shapes)
+# ==============================================================================
+
+def _available_fp8_dtypes():
+    """Collect FP8 dtypes present in the current torch build."""
+    names = [
+        "float8_e4m3fn", "float8_e4m3fnuz",
+        "float8_e5m2", "float8_e5m2fnuz",
+        # Some builds expose *_fast variants:
+        "float8_e4m3fn_fast", "float8_e5m2_fast",
+        # Add vendor aliases if your build exposes them:
+        # "float8_e8m0fnu",
+    ]
+    dts = []
+    for n in names:
+        dt = getattr(torch, n, None)
+        if dt is not None:
+            dts.append(dt)
+    return tuple(dts)
+
+
+_FP8_DTYPES = _available_fp8_dtypes()
+
+
+def _is_fp8_dtype(dtype: torch.dtype) -> bool:
+    return any(dtype is dt for dt in _FP8_DTYPES)
+
+
+def _safe_reciprocal(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+    # scales should be >0; clamp for numerical safety
+    return 1.0 / x.clamp_min(eps)
+
+
+def _broadcast_scale_like_weight(x_oi: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+    """
+    Broadcast scale s to match x (assumed [out, in] or [out, ...]).
+    Supports:
+      - 0D per-tensor
+      - 1D per-out-channel (len == out)
+      - 2D per-group where s.shape == [out, G] and G divides in
+      - exact shape match
+    """
+    if s.ndim == 0:
+        return s
+    if s.ndim == 1 and s.numel() == x_oi.shape[0]:
+        # [out] -> [out, 1, ..., 1]
+        view = (x_oi.shape[0],) + (1,) * (x_oi.ndim - 1)
+        return s.view(*view)
+    if s.ndim == 2 and x_oi.ndim == 2 and s.shape[0] == x_oi.shape[0]:
+        # [out, G] -> expand across in
+        out, in_ = x_oi.shape
+        G = s.shape[1]
+        if in_ % G == 0:
+            reps = in_ // G
+            return s.repeat_interleave(reps, dim=1)
+    if s.shape == x_oi.shape:
+        return s
+    # Fallback: try to expand assuming per-out shape
+    if s.ndim == 1 and s.numel() == x_oi.shape[0]:
+        view = (x_oi.shape[0],) + (1,) * (x_oi.ndim - 1)
+        return s.view(*view)
+    # As a last resort, return as-is (may broadcast later or be ignored)
+    return s
+
+
+def _find_fp8_scale_from_module(mod: nn.Module) -> Tuple[Optional[torch.Tensor], bool]:
+    """
+    Try common attribute names for FP8 dequant scale.
+    Returns (scale_tensor_or_None, is_inverse_bool).
+    Prefers *_inv names if both forms exist.
+    """
+    inv_names = (
+        "weight_scale_inv", "fp8_weight_scale_inv", "scale_inv",
+        "dequant_scale_inv",
+    )
+    for n in inv_names:
+        if hasattr(mod, n):
+            return getattr(mod, n), True
+
+    names = (
+        "weight_scale", "fp8_weight_scale", "dequant_scale", "scale",
+    )
+    for n in names:
+        if hasattr(mod, n):
+            return getattr(mod, n), False
+
+    # Example extension if you store (amax, scale) pairs:
+    # if hasattr(mod, "amax") and hasattr(mod, "scale"):
+    #     return (mod.scale / mod.amax), False
+
+    return None, False
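+
+# Worked example of the resolution above (attribute names are the ones probed in
+# _find_fp8_scale_from_module; weights assumed packed as w_fp8 = (w / s).to(fp8)):
+#   module.weight_scale_inv == 1 / s  ->  (1 / s, True):  dequant multiplies the FP16 cast by s
+#   module.weight_scale     == s      ->  (s, False):     dequant multiplies the FP16 cast by s
+#   neither attribute present         ->  (None, False):  plain FP8 -> FP16 cast, no scaling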
+
+
+def _dequantize_fp8_to_fp16(
+    w_oi: torch.Tensor,
+    *,
+    scale: Optional[torch.Tensor],
+    is_inverse: bool,
+    prefer_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    """
+    Convert FP8 -> FP16 and apply the scale (or its inverse) in a layout-safe way.
+    `w_oi` is expected in [out, in]; the caller normalizes to that layout before calling.
+    """
+    x = w_oi.to(prefer_dtype)
+
+    if scale is not None:
+        s = scale.to(device=w_oi.device, dtype=prefer_dtype)
+        if is_inverse:
+            s = _safe_reciprocal(s)
+        s = _broadcast_scale_like_weight(x, s)
+        try:
+            x = x * s
+        except Exception:
+            # If the shape is unexpected, skip scaling rather than crash
+            pass
+
+    return x
+
+
+# ==============================================================================

 def get_number_of_rows_and_cols(layer: nn.Module):
     # return layer.weight.shape[0], np.prod(layer.weight.shape[1:])
@@ -93,18 +215,6 @@
 class GPTQ:

     def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
-        # self.lock = threading.Lock()
-
-        # self.num_tied_handles = 0
-        # if qcfg.tied_gptq_handle is not None:
-        #     qcfg.tied_gptq_handle.num_tied_handles += 1
-
-        # Flags indicating issues
-        # self.issue_zero_samples = False
-        # self.issue_nan_hessian = False
-        # self.issue_non_invertible = False
-
-        # self.W = module.weight
         self.rows, self.columns = get_number_of_rows_and_cols(module)
         if isinstance(module, NamedModule):
             self.module = module.module
@@ -136,16 +246,13 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
         self.fail_safe = False

         self.H = torch.zeros((self.columns, self.columns),
-                              dtype=torch.float32)
+                             dtype=torch.float32)

     @staticmethod
     def _validate_module(module):
         assert isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, transformers.Conv1D)), f"We supports only linear and convolutional layers. actual = `{module}`"

-    # def has_hessian_issues(self) -> bool:
-    #     return any([self.issue_zero_samples, self.issue_nan_hessian, self.issue_non_invertible])
-
     def create_quantizer(self, name: str) -> Quantizer:
         return Quantizer(qcfg=self.qcfg, name=name)

@@ -163,35 +270,48 @@ def _mock_hessian_inverse(self, H: torch.Tensor):
         return identity, damp

     def _clone_module(self, copy=True, device: torch.device = None):
+        """
+        Clone the module weight to the target device and normalize it to [out, in].
+        If the weight is FP8, dequantize it to FP16 (optionally using the scale or inverse scale).
+        Otherwise preserve float16/bfloat16; for other dtypes cast to float32 as in the original code.
+        """
         if not device:
             device = self.module.weight.data.device

+        # Bring raw parameter to device first (no dtype change yet)
         clone = self.module.weight.data.to(copy=copy, device=device)
+        orig_dtype = clone.dtype

+        # Normalize layout to [out, in]
         if isinstance(self.module, _ConvNd):
             clone = clone.flatten(1)

         if isinstance(self.module, transformers.pytorch_utils.Conv1D):
             clone = clone.t()

-        return clone.float()
+        # FP8 path: dequantize to FP16 with (inverse) scale if present
+        if _is_fp8_dtype(orig_dtype):
+            fp8_scale, is_inv = _find_fp8_scale_from_module(self.module)
+            clone = _dequantize_fp8_to_fp16(
+                clone, scale=fp8_scale, is_inverse=is_inv, prefer_dtype=torch.float16
+            )
+        else:
+            # Non-FP8: keep half/bfloat16; otherwise cast to float32 like original
+            if clone.dtype not in (torch.float16, torch.bfloat16):
+                clone = clone.float()
+
+        return clone

     def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
         self.fwd_counter += 1
-
-        # print(f"self.module.target_device = {self.module.target_device}")
-
         if self.fwd_inputs_buffered:
-            # with torch_streamCtx(self.module.target_device_stream):
-            #     self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=True))
-
-            self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=False))
+            self.fwd_inputs_buffered_data.append(
+                inp.to(device=self.module.target_device, non_blocking=False)
+            )
         else:
             self.process_batch(inp)

     def process_batch(self, inp: torch.Tensor):
-        # print(f"inp = {inp}")
-        # print(f"self.module = {self.module} device = {self.module.target_device}")
         inp = inp.to(device=self.module.target_device, dtype=torch.float32)

         # input reshaping
@@ -247,8 +367,6 @@ def process_batch(self, inp: torch.Tensor):
         # update number of collected samples
         self.nsamples += batch_token_size

-        # inp returned here is flattened/reshaped original inp
-        # return batch_token_size, reshaped_inp, alpha, beta
         del batch_token_size, reshaped_inp, alpha, beta

     # FIXME, optimum needs fasterquant, we need to remove it
@@ -294,7 +412,6 @@ def hessian_inverse(self, H: torch.Tensor):
         try:
             H2 = H.clone()
             H2[diag, diag] += damp * mean
-            # TODO call to torch.linalg is not threadsafe? Porque no? Esta muy mal.
             H2 = torch.linalg.cholesky(H2)
             Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
             del H, H2
@@ -312,7 +429,6 @@ def hessian_inverse(self, H: torch.Tensor):
             if not (0 < damp < 1):
                 log.error(
                     f"Quantization: Module `{self.name}` -> `damp_percent` must between 0 and 1. current is {damp}. Module cannot be correctly processed.")
-                # raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp}")
                return None, 1.0

        return Hinv, damp

@@ -322,15 +438,8 @@
     def quantize(
         self,
         blocksize=128,
     ):
-        # self.H = self.H.to(device=CUDA_0)
-        # log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`")
         start = time.time()
-        # Temporarily disable torch.compile due to compatibility issues with torch 2.8
-        # Will re-enable once the issue is fixed
-        # if not TORCH_GTE_28 and not self.qcfg.mock_quantization:
-        #     self.hessian_inverse = torch_compile(self.hessian_inverse)
-
         if self.qcfg.mock_quantization:
             # Use simplified hessian inverse (identity matrix)
             self.hessian_inverse = self._mock_hessian_inverse
@@ -338,23 +447,16 @@
         # process buffered inputs
         if len(self.fwd_inputs_buffered_data) > 0:
             torch_sync(device=self.module.target_device)
-
             for inp in self.fwd_inputs_buffered_data:
                 self.process_batch(inp)
-
             # release buffer
             del self.fwd_inputs_buffered_data

-        # if self.device.type not in ["mps", "cpu"]:
-        #     self.module.weight.data = self.module.weight.data.cpu()
-
-        # TODO: waiting for pytorch implementation of ops for MPS
         if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
             raise RuntimeError(
                 "For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.")

         if self.module_copy is None:
-            # log.info("copy W to cuda_1")
             W = self._clone_module(device=self.module.target_device)
         else:
             W = self.module_copy.to(device=self.module.target_device)
@@ -368,19 +470,16 @@
         H[dead, dead] = 1
         W[:, dead] = 0

-        # g_idx = []
         scale = []
         zero = []
         now_idx = 1

         if self.qcfg.static_groups:
             import copy
-
             groups = []
             for i in range(0, self.columns, self.qcfg.group_size):
                 quantizer = copy.deepcopy(self.quantizer)
                 quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True)
-
                 scale.append(quantizer.scale)
                 zero.append(quantizer.zero)
                 groups.append(quantizer)
@@ -475,7 +574,6 @@
             if hasattr(self.quantizer, 'scale') and hasattr(self.quantizer, 'zero'):
                 latest_scale = self.quantizer.scale
                 latest_zero = self.quantizer.zero
-
                 if latest_scale.dim() == 1:
                     latest_scale = latest_scale.view(-1, 1)
                 if latest_zero.dim() == 1:
                     latest_zero = latest_zero.view(-1, 1)
@@ -524,7 +622,7 @@
                 if self.qcfg.group_size != -1:
                     if not self.qcfg.static_groups:
                         if (i1 + i) % self.qcfg.group_size == 0:
-                            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True)
+                            self.quantizer.find_params(W[:, (i1 + i):(i1 + i + self.qcfg.group_size)], weight=True)

                         if ((i1 + i) // self.qcfg.group_size) - now_idx == -1:
                             scale.append(self.quantizer.scale)
@@ -534,13 +632,12 @@
                         idx = i1 + i
                         if self.qcfg.desc_act:
                             idx = perm[idx]
-
                         self.quantizer = groups[idx // self.qcfg.group_size]

                 q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                 Q1[:, i] = q
                 if Hinv is not None:
-                    Losses1[:, i] = (w - q) ** 2 / d**2
+                    Losses1[:, i] = (w - q) ** 2 / d ** 2
                     err1 = (w - q) / d
                     W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                     Err1[:, i] = err1
@@ -550,14 +647,10 @@
             Losses[:, i1:i2] = Losses1 / 2
             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

-            # TODO: why is there a torch_sync here? There are no streaming ops here?
-            # torch_sync(device=self.module.target_device)
-
         if Hinv is not None:
             del Hinv

         if self.nsamples != 0:
             avg_loss = torch.sum(Losses).item() / self.nsamples
-
             if math.isnan(avg_loss):
                 print("Losses sum item:", torch.sum(Losses).item())
                 if self.fail_safe:
@@ -636,7 +729,5 @@ def free(self):
         del self.module_copy
         del self.module

-        # torch_empty_cache(self.device)
-

 __all__ = ["GPTQ"]
diff --git a/tests/test_fp8.py b/tests/test_fp8.py
new file mode 100644
index 000000000..9815b5ed2
--- /dev/null
+++ b/tests/test_fp8.py
@@ -0,0 +1,274 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+import pytest
+import torch
+import torch.nn as nn
+
+from gptqmodel.quantization.gptq import GPTQ
+
+
+# ------------------------ Backend / HW detection ------------------------
+
+def _is_cuda_build() -> bool:
+    return torch.cuda.is_available() and torch.version.cuda is not None
+
+def _is_rocm_build() -> bool:
+    return torch.cuda.is_available() and torch.version.hip is not None
+
+def _nvidia_sm() -> int | None:
+    if not _is_cuda_build():
+        return None
+    try:
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    except Exception:
+        return None
+
+def _is_hopper_or_newer() -> bool:
+    sm = _nvidia_sm()
+    return sm is not None and sm >= 90  # SM90 = Hopper/Blackwell era
+
+def _is_mi300_or_newer() -> bool:
+    if not _is_rocm_build():
+        return False
+    try:
+        name = torch.cuda.get_device_name(None)
+    except Exception:
+        return False
+    return bool(re.search(r"MI3\d{2}", name.upper()))
+
+BACKEND_SUPPORTED = (_is_hopper_or_newer() or _is_mi300_or_newer())
+
+# ------------------------ FP8 dtype inventory ------------------------
+
+def _available_fp8_dtypes():
+    names = [
+        "float8_e4m3fn", "float8_e4m3fnuz",
+        "float8_e5m2", "float8_e5m2fnuz",
+        "float8_e4m3fn_fast", "float8_e5m2_fast",
+    ]
+    dts = []
+    for n in names:
+        dt = getattr(torch, n, None)
+        if dt is not None:
+            dts.append(dt)
+    return tuple(dts)
+
+FP8_DTYPES = _available_fp8_dtypes()
+
+pytestmark = [
+    pytest.mark.skipif(not BACKEND_SUPPORTED, reason="Need SM90+ or MI300-class for HW FP8 path."),
+    pytest.mark.skipif(len(FP8_DTYPES) == 0, reason="This PyTorch build exposes no FP8 dtypes."),
+]
+
+def _pick_fp8_dtype_prefer_e4m3():
+    for name in ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e4m3fn_fast"]:
+        dt = getattr(torch, name, None)
+        if dt is not None:
+            return dt
+    return FP8_DTYPES[0]
+
+# ------------------------ Utilities ------------------------
+
+def _device() -> torch.device:
+    idx = torch.cuda.current_device()
+    return torch.device("cuda", idx)
+
+def _mk_linear(out_features: int, in_features: int, device: torch.device) -> nn.Linear:
+    # We'll replace .weight param anyway.
+    return nn.Linear(in_features, out_features, bias=False, device=device, dtype=torch.float16)
+
+def _expand_scale_like_out_in(base_oi: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    """Broadcast scale to [out, in] for reference building, matching gptq logic."""
+    s = scale.to(device=base_oi.device, dtype=base_oi.dtype)
+    if s.ndim == 0:
+        return s
+    if s.ndim == 1 and s.numel() == base_oi.shape[0]:
+        return s.view(base_oi.shape[0], *([1] * (base_oi.ndim - 1)))
+    if s.ndim == 2 and s.shape[0] == base_oi.shape[0]:
+        out, in_ = base_oi.shape
+        G = s.shape[1]
+        assert in_ % G == 0, "in_features must be divisible by group count"
+        reps = in_ // G
+        return s.repeat_interleave(reps, dim=1)
+    if s.shape == base_oi.shape:
+        return s
+    raise AssertionError(f"Unsupported scale shape {tuple(s.shape)} for base {tuple(base_oi.shape)}")
+
+def _mk_fp8_weight_from_base(base_fp16: torch.Tensor, scale: torch.Tensor | float, fp8_dtype: torch.dtype) -> torch.Tensor:
+    """Simulate packing: divide by scale, cast to FP8."""
+    if isinstance(scale, torch.Tensor):
+        s = _expand_scale_like_out_in(base_fp16, scale)
+        scaled = base_fp16 / s
+    else:
+        scaled = base_fp16 / float(scale)
+    return scaled.to(fp8_dtype)
+
+def _reference_dequant_from_fp8_param(
+    w_fp8: torch.Tensor,
+    *,
+    fp8_dtype: torch.dtype,
+    scale: torch.Tensor | float | None,
+    is_inverse: bool,
+) -> torch.Tensor:
+    """
+    Build the *exact* expected dequant result:
+        expected = (w_fp8.to(fp16)) * (scale or 1/scale_inv)   [with proper broadcasting]
+    """
+    ref = w_fp8.to(torch.float16)
+    if scale is not None:
+        if isinstance(scale, torch.Tensor):
+            s = scale.to(device=w_fp8.device, dtype=torch.float16)
+        else:
+            s = torch.tensor(scale, device=w_fp8.device, dtype=torch.float16)
+        if is_inverse:
+            # avoid divide-by-zero surprises
+            s = 1.0 / s.clamp_min(1e-8)
+        s = _expand_scale_like_out_in(ref, s) if s.ndim > 0 else s
+        ref = ref * s
+    return ref
+
+def _assert_eq_close(a: torch.Tensor, b: torch.Tensor, msg=""):
+    # very tight atol because both sides are the *same* FP8 round-trip graph
+    a32 = a.detach().cpu().to(torch.float32)
+    b32 = b.detach().cpu().to(torch.float32)
+    assert torch.allclose(a32, b32, rtol=0.0, atol=1e-6), (
+        f"{msg} | max|diff|={float((a32-b32).abs().max())}"
+    )
+
+# ------------------------ Tests ------------------------
+
+@pytest.mark.parametrize("fp8_dtype", [_pick_fp8_dtype_prefer_e4m3(), *_available_fp8_dtypes()])
+@pytest.mark.parametrize("out_features,in_features", [(6, 8), (8, 16)])
+def test_fp8_per_tensor_scale_inv_hw(fp8_dtype, out_features, in_features):
+    dev = _device()
+    torch.cuda.set_device(dev.index or 0)
+    lin = _mk_linear(out_features, in_features, device=dev)
+
+    torch.manual_seed(0)
+    base = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
+
+    scale = torch.tensor(0.25, device=dev, dtype=torch.float16)
+    w_fp8 = _mk_fp8_weight_from_base(base, scale, fp8_dtype)
+
+    with torch.no_grad():
+        lin.weight = nn.Parameter(w_fp8)  # keep FP8 dtype
+        lin.weight_scale_inv = (1.0 / scale).to(torch.float16)
+
+    g = GPTQ(lin)
+    W = g._clone_module(device=dev)
+
+    expected = _reference_dequant_from_fp8_param(
+        w_fp8, fp8_dtype=fp8_dtype, scale=lin.weight_scale_inv, is_inverse=True
+    )
+    _assert_eq_close(W, expected, "per-tensor inverse scale dequant mismatch")
+
+
+@pytest.mark.parametrize("fp8_dtype", [_pick_fp8_dtype_prefer_e4m3(), *_available_fp8_dtypes()])
+@pytest.mark.parametrize("out_features,in_features", [(6, 8), (10, 12)])
+def test_fp8_per_channel_scale_inv_hw(fp8_dtype, out_features, in_features):
+    dev = _device()
+    torch.cuda.set_device(dev.index or 0)
+    lin = _mk_linear(out_features, in_features, device=dev)
+
+    torch.manual_seed(1)
+    base = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
+
+    scale = torch.linspace(0.2, 0.6, out_features, device=dev, dtype=torch.float16)
+    w_fp8 = _mk_fp8_weight_from_base(base, scale, fp8_dtype)
+
+    with torch.no_grad():
+        lin.weight = nn.Parameter(w_fp8)
+        lin.weight_scale_inv = (1.0 / scale).to(torch.float16)
+
+    g = GPTQ(lin)
+    W = g._clone_module(device=dev)
+
+    expected = _reference_dequant_from_fp8_param(
+        w_fp8, fp8_dtype=fp8_dtype, scale=lin.weight_scale_inv, is_inverse=True
+    )
+    _assert_eq_close(W, expected, "per-channel inverse scale dequant mismatch")
+
+
+@pytest.mark.parametrize("fp8_dtype", [_pick_fp8_dtype_prefer_e4m3(), *_available_fp8_dtypes()])
+@pytest.mark.parametrize("out_features,in_features,G", [(6, 8, 2), (8, 16, 4)])
+def test_fp8_per_group_scale_inv_hw(fp8_dtype, out_features, in_features, G):
+    assert in_features % G == 0
+    dev = _device()
+    torch.cuda.set_device(dev.index or 0)
+    lin = _mk_linear(out_features, in_features, device=dev)
+
+    torch.manual_seed(2)
+    base = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
+
+    scale = (0.2 + 0.05 * torch.arange(G, device=dev, dtype=torch.float16)).expand(out_features, G).clone()
+    for o in range(out_features):
+        scale[o] += (o % 3) * 0.03
+
+    w_fp8 = _mk_fp8_weight_from_base(base, scale, fp8_dtype)
+
+    with torch.no_grad():
+        lin.weight = nn.Parameter(w_fp8)
+        lin.weight_scale_inv = (1.0 / scale).to(torch.float16)
+
+    g = GPTQ(lin)
+    W = g._clone_module(device=dev)
+
+    expected = _reference_dequant_from_fp8_param(
+        w_fp8, fp8_dtype=fp8_dtype, scale=lin.weight_scale_inv, is_inverse=True
+    )
+    _assert_eq_close(W, expected, "per-group inverse scale dequant mismatch")
+
+
+@pytest.mark.parametrize("fp8_dtype", [_pick_fp8_dtype_prefer_e4m3(), *_available_fp8_dtypes()])
+@pytest.mark.parametrize("out_features,in_features", [(6, 8), (8, 16)])
+def test_fp8_scale_non_inverse_path_hw(fp8_dtype, out_features, in_features):
+    dev = _device()
+    torch.cuda.set_device(dev.index or 0)
+    lin = _mk_linear(out_features, in_features, device=dev)
+
+    torch.manual_seed(3)
+    base = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
+
+    scale = torch.linspace(0.3, 0.7, out_features, device=dev, dtype=torch.float16)
+    w_fp8 = _mk_fp8_weight_from_base(base, scale, fp8_dtype)
+
+    with torch.no_grad():
+        lin.weight = nn.Parameter(w_fp8)
+        lin.weight_scale = scale  # direct scale
+
+    g = GPTQ(lin)
+    W = g._clone_module(device=dev)
+
+    expected = _reference_dequant_from_fp8_param(
+        w_fp8, fp8_dtype=fp8_dtype, scale=lin.weight_scale, is_inverse=False
+    )
+    _assert_eq_close(W, expected, "non-inverse scale dequant mismatch")
+
+
+@pytest.mark.parametrize("fp8_dtype", [_pick_fp8_dtype_prefer_e4m3(), *_available_fp8_dtypes()])
+@pytest.mark.parametrize("out_features,in_features", [(6, 8), (8, 16)])
+def test_fp8_no_scale_fallback_hw(fp8_dtype, out_features, in_features):
+    dev = _device()
+    torch.cuda.set_device(dev.index or 0)
+    lin = _mk_linear(out_features, in_features, device=dev)
+
+    torch.manual_seed(4)
+    base = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
+
+    # No scale attrs -> expected is just FP8->FP16 cast
+    w_fp8 = base.to(fp8_dtype)
+
+    with torch.no_grad():
+        lin.weight = nn.Parameter(w_fp8)
+
+    g = GPTQ(lin)
+    W = g._clone_module(device=dev)
+
+    expected = _reference_dequant_from_fp8_param(
+        w_fp8, fp8_dtype=fp8_dtype, scale=None, is_inverse=False
+    )
+    _assert_eq_close(W, expected, "no-scale FP8 cast fallback mismatch")
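+
+
+# Summary of the dequant contract these tests pin down (see _reference_dequant_from_fp8_param):
+#   weight_scale = s          ->  W == w_fp8.to(float16) * s
+#   weight_scale_inv = 1 / s  ->  W == w_fp8.to(float16) * s
+#   no scale attribute        ->  W == w_fp8.to(float16)
+# with s broadcast per-tensor, per-out-channel, or per-group ([out, G], G dividing in_features).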