Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions gptqmodel/looper/module_looper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import threading
import time
from concurrent.futures import as_completed
from contextlib import nullcontext
from typing import Dict, List, Optional, TYPE_CHECKING

Expand Down Expand Up @@ -1241,25 +1242,54 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):

finalize_futures_snapshot = list(finalize_futures)

def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, progress_bar):
if finalize_futures_snapshot:
finalize_pb.title(
f"Submodule finalize 0/{finalize_count}"
).subtitle("Waiting for completions...").draw()

future_metadata = {
future: (module_label, process, layer_idx)
for future, _, module_label, process, layer_idx in finalize_futures_snapshot
}

def _drain_finalize_futures(
futures,
finalize_pb_local,
finalize_count_local,
future_metadata_local,
):
completed_local = 0
try:
for future, idx, module_label, process, layer_idx in futures:
for future in as_completed(futures):
module_label, process, layer_idx = future_metadata_local.get(
future, (None, None, None)
)

future.result()

layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?"
display_module = module_label or "<unnamed>"
subtitle = f"{process.name()}: {display_module}"
processor_name = process.name() if process is not None else "<processor>"
subtitle = f"{processor_name}: {display_module}"

completed_local += 1
finalize_pb_local.next()
finalize_pb_local.title(
f"{layer_label} Finalize {idx}/{finalize_count_local}"
).subtitle(subtitle).next().draw()
f"{layer_label} Finalize {completed_local}/{finalize_count_local}"
).subtitle(subtitle).draw()
finally:
finalize_pb_local.close()

if finalize_futures_snapshot:
# Drain finalize futures asynchronously so the main loop can continue scheduling work.
threading.Thread(
target=_drain_finalize_futures,
args=(finalize_futures_snapshot, finalize_pb, finalize_count, pb),
args=(
[future for future, *_ in finalize_futures_snapshot],
finalize_pb,
finalize_count,
future_metadata,
),
name="SubmoduleFinalizeWatcher",
daemon=True,
).start()
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class TorchQuantLinear(PackableQuantLinear):
SUPPORTS_BITS = [2, 3, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
Expand Down
3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/tritonv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TritonModuleMixin:

class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin):
SUPPORTS_BITS = [2, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
Expand Down Expand Up @@ -207,4 +207,3 @@ def triton_xpu_available():
except Exception:
return False


4 changes: 2 additions & 2 deletions gptqmodel/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ def __post_init__(self):
if key == "bits" and value not in fields_info[0].metadata["choices"]:
raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.")
elif key == "group_size" and value != -1 and value <= 0:
raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

if self.group_size != -1 and self.group_size <= 0:
raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

if not (0 < self.damp_percent < 1):
raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.")
Expand Down
40 changes: 31 additions & 9 deletions gptqmodel/utils/looper_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,27 +164,49 @@ def maybe_clear(obj: torch.nn.Module):
return cleared


def _canonical_device(device: torch.device) -> torch.device:
"""Return a canonical form so indexless accelerators collapse to device:0."""
if device.type in {"cuda", "xpu", "npu"}:
index = device.index if device.index is not None else 0
return torch.device(f"{device.type}:{index}")
return device


def select_forward_devices(base_device: Optional[torch.device]) -> List[torch.device]:
if base_device is None:
return [CPU]

devices = [base_device]
base_type = base_device.type
if base_type in ("cuda", "xpu", "mps"):
devices: List[torch.device] = []
seen: set[tuple[str, int | None]] = set()

def _add(device: torch.device) -> None:
canonical = _canonical_device(device)
key = (canonical.type, canonical.index)
if key in seen:
return
seen.add(key)
devices.append(canonical)

_add(base_device)
base_type = devices[0].type
if base_type in {"cuda", "xpu", "mps", "npu"}:
for dev in ALL_DEVICES:
if dev.type == base_type and dev not in devices:
devices.append(dev)
if dev.type == base_type:
_add(dev)
return devices


def normalize_device_like(device_like) -> Optional[torch.device]:
if device_like is None:
return None
if isinstance(device_like, torch.device):
return device_like
if hasattr(device_like, "to_torch_device"):
return device_like.to_torch_device()
return torch.device(str(device_like))
device = device_like
elif hasattr(device_like, "to_torch_device"):
device = device_like.to_torch_device()
else:
device = torch.device(str(device_like))

return _canonical_device(device)


def clone_module_for_devices(
Expand Down
23 changes: 23 additions & 0 deletions gptqmodel/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .torch import CPU, META


_SMALL_MODULE_OFFLOAD_BYTES = 4 * 1024 # Skip disk writes for <4KB payloads


# Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache
def _fake_clear_device_cache(garbage_collection=False):
pass
Expand Down Expand Up @@ -74,6 +77,14 @@ def _prepare_offload_directory(target_dir: str) -> None:
os.makedirs(target_dir, exist_ok=True)


def _tensor_nbytes(tensor: torch.Tensor) -> int:
try:
itemsize = tensor.element_size()
except RuntimeError:
itemsize = torch.empty((), dtype=tensor.dtype).element_size()
return tensor.numel() * itemsize


def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
bundle_path = os.path.join(offload_dir, "module.safetensors")
index: dict[str, dict] = {}
Expand Down Expand Up @@ -177,6 +188,18 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):

module_offload_dir = os.path.join(disk_path, name)

total_bytes = 0
try:
state_items = module.state_dict().values()
except Exception:
state_items = []

for tensor in state_items:
total_bytes += _tensor_nbytes(tensor)

if total_bytes <= _SMALL_MODULE_OFFLOAD_BYTES:
return

_prepare_offload_directory(module_offload_dir)
_bundle_module_state_dict(module, module_offload_dir)

Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class TestQwen3Moe(ModelTest):
DEBUG = True
ACT_GROUP_AWARE = True
DESC_ACT = False
DATASET_SIZE = 1024
DATASET_SIZE = 2048
DATASET_SORT = "desc"
QUANT_BATCH_SIZE = 1
QUANT_BATCH_SIZE = 8
CALIB_NOISE_MODE = "unseen"
CALIB_NOISE_PERCENT = 0.025

Expand Down