diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 2d5b7e36c..9af8ab114 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -17,6 +17,7 @@
 import threading
 import time
+from concurrent.futures import as_completed
 from contextlib import nullcontext
 from typing import Dict, List, Optional, TYPE_CHECKING
@@ -1241,17 +1242,41 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
         finalize_futures_snapshot = list(finalize_futures)

-        def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, progress_bar):
+        if finalize_futures_snapshot:
+            finalize_pb.title(
+                f"Submodule finalize 0/{finalize_count}"
+            ).subtitle("Waiting for completions...").draw()
+
+        future_metadata = {
+            future: (module_label, process, layer_idx)
+            for future, _, module_label, process, layer_idx in finalize_futures_snapshot
+        }
+
+        def _drain_finalize_futures(
+            futures,
+            finalize_pb_local,
+            finalize_count_local,
+            future_metadata_local,
+        ):
+            completed_local = 0
             try:
-                for future, idx, module_label, process, layer_idx in futures:
+                for future in as_completed(futures):
+                    module_label, process, layer_idx = future_metadata_local.get(
+                        future, (None, None, None)
+                    )
+
                     future.result()
                     layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?"
                     display_module = module_label or ""
-                    subtitle = f"{process.name()}: {display_module}"
+                    processor_name = process.name() if process is not None else ""
+                    subtitle = f"{processor_name}: {display_module}"
+
+                    completed_local += 1
+                    finalize_pb_local.next()
                     finalize_pb_local.title(
-                        f"{layer_label} Finalize {idx}/{finalize_count_local}"
-                    ).subtitle(subtitle).next().draw()
+                        f"{layer_label} Finalize {completed_local}/{finalize_count_local}"
+                    ).subtitle(subtitle).draw()
             finally:
                 finalize_pb_local.close()
@@ -1259,7 +1284,12 @@ def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, pr
         # Drain finalize futures asynchronously so the main loop can continue scheduling work.
         threading.Thread(
             target=_drain_finalize_futures,
-            args=(finalize_futures_snapshot, finalize_pb, finalize_count, pb),
+            args=(
+                [future for future, *_ in finalize_futures_snapshot],
+                finalize_pb,
+                finalize_count,
+                future_metadata,
+            ),
             name="SubmoduleFinalizeWatcher",
             daemon=True,
         ).start()
diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py
index 3e9f2af77..bb2de2755 100644
--- a/gptqmodel/nn_modules/qlinear/torch.py
+++ b/gptqmodel/nn_modules/qlinear/torch.py
@@ -20,7 +20,7 @@ class TorchQuantLinear(PackableQuantLinear):
     SUPPORTS_BITS = [2, 3, 4, 8]
-    SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
+    SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True, False]
     SUPPORTS_SHARDS = True
diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py
index 0e2dbdb76..e92846de5 100644
--- a/gptqmodel/nn_modules/qlinear/tritonv2.py
+++ b/gptqmodel/nn_modules/qlinear/tritonv2.py
@@ -47,7 +47,7 @@ class TritonModuleMixin:
 class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin):
     SUPPORTS_BITS = [2, 4, 8]
-    SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
+    SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True, False]
     SUPPORTS_SHARDS = True
@@ -207,4 +207,3 @@ def triton_xpu_available():
     except Exception:
         return False
-
diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index 1f62de7e1..0b4bb4dfc 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -293,10 +293,10 @@ def __post_init__(self):
                 if key == "bits" and value not in fields_info[0].metadata["choices"]:
                     raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.")
                 elif key == "group_size" and value != -1 and value <= 0:
-                    raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
+                    raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

         if self.group_size != -1 and self.group_size <= 0:
-            raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
+            raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

         if not (0 < self.damp_percent < 1):
             raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.")
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index 47262c067..a0d665463 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -164,16 +164,35 @@ def maybe_clear(obj: torch.nn.Module):
     return cleared


+def _canonical_device(device: torch.device) -> torch.device:
+    """Return a canonical form so indexless accelerators collapse to device:0."""
+    if device.type in {"cuda", "xpu", "npu"}:
+        index = device.index if device.index is not None else 0
+        return torch.device(f"{device.type}:{index}")
+    return device
+
+
 def select_forward_devices(base_device: Optional[torch.device]) -> List[torch.device]:
     if base_device is None:
         return [CPU]
-    devices = [base_device]
-    base_type = base_device.type
-    if base_type in ("cuda", "xpu", "mps"):
+    devices: List[torch.device] = []
+    seen: set[tuple[str, int | None]] = set()
+
+    def _add(device: torch.device) -> None:
+        canonical = _canonical_device(device)
+        key = (canonical.type, canonical.index)
+        if key in seen:
+            return
+        seen.add(key)
+        devices.append(canonical)
+
+    _add(base_device)
+    base_type = devices[0].type
+    if base_type in {"cuda", "xpu", "mps", "npu"}:
         for dev in ALL_DEVICES:
-            if dev.type == base_type and dev not in devices:
-                devices.append(dev)
+            if dev.type == base_type:
+                _add(dev)
     return devices


@@ -181,10 +200,13 @@ def normalize_device_like(device_like) -> Optional[torch.device]:
     if device_like is None:
         return None
     if isinstance(device_like, torch.device):
-        return device_like
-    if hasattr(device_like, "to_torch_device"):
-        return device_like.to_torch_device()
-    return torch.device(str(device_like))
+        device = device_like
+    elif hasattr(device_like, "to_torch_device"):
+        device = device_like.to_torch_device()
+    else:
+        device = torch.device(str(device_like))
+
+    return _canonical_device(device)


 def clone_module_for_devices(
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index a3c544bd5..a81271547 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -26,6 +26,9 @@
 from .torch import CPU, META


+_SMALL_MODULE_OFFLOAD_BYTES = 4 * 1024  # Skip disk writes for <4KB payloads
+
+
 # Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache
 def _fake_clear_device_cache(garbage_collection=False):
     pass
@@ -74,6 +77,14 @@ def _prepare_offload_directory(target_dir: str) -> None:
     os.makedirs(target_dir, exist_ok=True)


+def _tensor_nbytes(tensor: torch.Tensor) -> int:
+    try:
+        itemsize = tensor.element_size()
+    except RuntimeError:
+        itemsize = torch.empty((), dtype=tensor.dtype).element_size()
+    return tensor.numel() * itemsize
+
+
 def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
     bundle_path = os.path.join(offload_dir, "module.safetensors")
     index: dict[str, dict] = {}
@@ -177,6 +188,18 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
     module_offload_dir = os.path.join(disk_path, name)

+    total_bytes = 0
+    try:
+        state_items = module.state_dict().values()
+    except Exception:
+        state_items = []
+
+    for tensor in state_items:
+        total_bytes += _tensor_nbytes(tensor)
+
+    if total_bytes <= _SMALL_MODULE_OFFLOAD_BYTES:
+        return
+
     _prepare_offload_directory(module_offload_dir)
     _bundle_module_state_dict(module, module_offload_dir)
diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py
index 7541ca609..531b93995 100644
--- a/tests/models/test_qwen3_moe.py
+++ b/tests/models/test_qwen3_moe.py
@@ -18,9 +18,9 @@ class TestQwen3Moe(ModelTest):
     DEBUG = True
     ACT_GROUP_AWARE = True
     DESC_ACT = False
-    DATASET_SIZE = 1024
+    DATASET_SIZE = 2048
     DATASET_SORT = "desc"
-    QUANT_BATCH_SIZE = 1
+    QUANT_BATCH_SIZE = 8
     CALIB_NOISE_MODE = "unseen"
     CALIB_NOISE_PERCENT = 0.025
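Below is a minimal, self-contained sketch of the drain pattern the module_looper.py hunk switches to: futures are consumed in completion order via concurrent.futures.as_completed, and per-future labels are looked up from a side map keyed by the future (like `future_metadata` above) instead of being carried in the submission tuples. The `finalize` worker, the label names, and the print-based progress reporting are invented stand-ins for the repo's processor and ProgressBar objects, not the actual looper API.

# Sketch only: names below are illustrative, not the repo's real classes.
from concurrent.futures import ThreadPoolExecutor, as_completed
import time


def finalize(label: str) -> str:
    time.sleep(0.01)  # stand-in for per-submodule finalize work
    return label


with ThreadPoolExecutor(max_workers=4) as pool:
    labels = [f"layer{i}.mlp" for i in range(8)]
    futures = [pool.submit(finalize, label) for label in labels]
    # Side table keyed by future, mirroring `future_metadata` in the diff.
    metadata = {fut: (label, idx) for idx, (fut, label) in enumerate(zip(futures, labels))}

    completed = 0
    for fut in as_completed(futures):   # completion order, not submission order
        module_label, layer_idx = metadata.get(fut, (None, None))
        fut.result()                    # re-raises any exception from the worker
        completed += 1
        print(f"Layer {layer_idx} Finalize {completed}/{len(futures)}: {module_label}")

The completion-order loop is what lets the watcher thread report progress as soon as any finalize finishes, rather than blocking on the earliest-submitted future.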