21 changes: 11 additions & 10 deletions README.md
@@ -109,10 +109,10 @@ GPTQModel not only supports GPTQ but also QQQ with more quantization methods sup

GPTQModel is an expandable/modular design supporting multiple quantization methods.

| Quantization | GPTQModel | Transformers | vLLM | SGLang | Lora Training |
|-------------------|---|---|---|---|---|
| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| QQQ + Rotation | ✅ | x | ✅ | ✅ | x |
| Quantization | GPTQModel | Transformers | vLLM | SGLang | Lora Training |
|-------------------|---|---|------|--------|---|
| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| QQQ + Rotation | ✅ | x | x | x | x |

## Multi-Modal

@@ -163,12 +163,13 @@ Native support support some of the most popular multi-modal models:

GPTQModel is validated for Linux, MacOS, and Windows 11:

| Platform | Device | | Optimized Arch | Kernels |
|-----------------|---------------| --- | -------------- |-------------------------------------------------------------|
| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch |
| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | IPEX, Torch |
| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch |
| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | IPEX, Torch |
| Platform | Device | | Optimized Arch | Kernels |
|-----------------|-----------------------| --- |-------------------------|----------------------------------------|
| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch |
| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | IPEX, Torch |
| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch |
| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | IPEX, Torch |
| 🐧 Linux | Google TPU | ✅ | `v5+` | Torch |
| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |

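As context for the new TPU row in the table above, here is a minimal sketch of a quantization run on a TPU VM with this PR applied. It assumes the standard `GPTQModel.load` / `quantize` / `save` flow from the project README; the model id, calibration text, and output path are placeholders, and device selection is left to the XLA auto-detection added in this PR.

```python
# Hedged sketch: quantizing on a TPU VM (assumes torch_xla is installed;
# PJRT_DEVICE=TPU is auto-set by gptqmodel/models/auto.py in this PR).
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-1B"                              # placeholder model
calibration = ["gptqmodel is an llm quantization toolkit"] * 256  # placeholder calibration data

quant_config = QuantizeConfig(bits=4, group_size=128)

model = GPTQModel.load(model_id, quant_config)  # device auto-selects; resolves to an xla device on TPU
model.quantize(calibration, batch_size=1)
model.save("Llama-3.2-1B-gptq-4bit")            # placeholder output path
```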
47 changes: 47 additions & 0 deletions examples/tpu_smi.py
@@ -0,0 +1,47 @@
import time
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from tabulate import tabulate

def get_memory_table():
    mem_info = xm.get_memory_info(xm.xla_device())
    print(f"{mem_info}")
    try:
        used = int(mem_info['bytes_used'])
        limit = int(mem_info['bytes_limit'])
        peak = int(mem_info['peak_bytes_used'])

        data = [
            ["Used (MB)", used // (1024 * 1024)],
            ["Peak (MB)", peak // (1024 * 1024)],
            ["Limit (MB)", limit // (1024 * 1024)],
        ]
        return data
    except Exception as e:
        return [["Error", f"Failed to parse memory info: {e}"]]

def get_metrics_summary():
    report = met.metrics_report().split('\n')
    summary = []
    for line in report:
        if ':' in line and any(key in line for key in ["Time", "Rate", "Size", "Count"]):
            parts = line.strip().split(':')
            summary.append([parts[0].strip(), parts[1].strip()])
    return summary

def monitor_tpu(interval=1):
    print("🧠 Starting TPU Monitor (updates every {}s)\n".format(interval))
    while True:
        mem_table = get_memory_table()
        metrics_table = get_metrics_summary()

        print("\n=== TPU Memory Usage ===")
        print(tabulate(mem_table, headers=["Metric", "Value"], tablefmt="grid"))

        print("\n=== TPU Execution Metrics ===")
        print(tabulate(metrics_table, headers=["Metric", "Value"], tablefmt="grid"))

        time.sleep(interval)

if __name__ == "__main__":
    monitor_tpu(interval=1)
15 changes: 9 additions & 6 deletions gptqmodel/looper/gptq_processor.py
@@ -27,10 +27,10 @@
PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
from ..quantization import GPTQ
from ..quantization.config import QUANT_METHOD, QuantizeConfig
from ..quantization.gptq import CPU, CUDA_0, CUDA_1
from ..quantization.gptq import CPU, DEVICE_0, DEVICE_1
from ..utils.logger import setup_logger
from ..utils.model import move_to, pack_model
from ..utils.torch import torch_empty_cache, torch_sync
from ..utils.torch import torch_empty_cache, torch_sync, HAS_XLA

log = setup_logger()

@@ -151,12 +151,12 @@ def process(self, module: NamedModule, auto_gc: bool = True):
self.avg_losses.append(avg_loss)
self.module_names.append(f"layer-{module.layer_index}-{module.name}")

stats_0 = torch.cuda.memory_stats(CUDA_0)
stats_0 = torch.cuda.memory_stats(DEVICE_0)
active_0 = stats_0.get("active_bytes.all.current", 0) / 1024 ** 2
peak_active_0 = stats_0.get("active_bytes.all.peak", 0) / 1024 ** 2

if torch.cuda.device_count() > 1:
stats_1 = torch.cuda.memory_stats(CUDA_1)
stats_1 = torch.cuda.memory_stats(DEVICE_1)
active_1 = stats_1.get("active_bytes.all.current", 0) / 1024 ** 2
peak_active_1 = stats_1.get("active_bytes.all.peak", 0) / 1024 ** 2

@@ -207,14 +207,17 @@ def process(self, module: NamedModule, auto_gc: bool = True):
# module.weight.data = torch.empty(1,1) # hack to remove weight.data
# if auto_gc:
# torch_empty_cache()
wq = wq.to(device=CUDA_0)
wq = wq.to(device=DEVICE_0)

# logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
module.state.update({
"wq": wq, # fp16, quantized weight but not int4 (packed qweight)
})

module.weight.data = wq
if HAS_XLA:
module.weight.data.copy_(wq)
else:
module.weight.data = wq

if auto_gc:
torch_empty_cache()
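The `HAS_XLA` flag imported above is defined in `gptqmodel/utils/torch.py`, which is not part of this diff; a plausible sketch of such a probe is shown below (the real definition may differ). The XLA branch uses an in-place `copy_` rather than rebinding `.data`, presumably so the XLA tensor that the lazy graph already references is updated in place instead of being replaced by a tensor on another device.

```python
# Hypothetical sketch of the HAS_XLA probe imported from ..utils.torch
# (that module is not shown in this diff, so details may differ).
try:
    import torch_xla.core.xla_model as xm
    HAS_XLA = len(xm.get_xla_supported_devices() or []) > 0
except ImportError:
    HAS_XLA = False
```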
18 changes: 14 additions & 4 deletions gptqmodel/looper/module_looper.py
@@ -28,11 +28,11 @@
from ..models import BaseGPTQModel
from ..models._const import SUPPORTS_MODULE_TYPES
from ..nn_modules.hooked_linear import replace_module_with_hooked_legacy, replace_module_with_hooked_tree
from ..quantization.gptq import CPU, CUDA_0, CUDA_1
from ..quantization.gptq import CPU, DEVICE_0, DEVICE_1
from ..utils.logger import setup_logger
from ..utils.model import (find_modules, get_device, get_module, get_module_by_name_prefix,
get_moe_layer_modules, move_to, nested_move_to)
from ..utils.torch import torch_empty_cache
from ..utils.torch import torch_empty_cache, HAS_XLA

log = setup_logger()

@@ -82,8 +82,18 @@ def store_input_hook(_, args, kwargs):

raise ValueError

def get_hw_device():
if HAS_XLA:
import torch_xla.core.xla_model as xm

# For TPUs, we don't need to specify an index like with CUDA
# The XLA runtime handles device assignment
return xm.xla_device(devkind="tpu")
else:
return self.gptq_model.quantize_config.device

# move layer to target device
layers[0] = layers[0].to(self.gptq_model.quantize_config.device)
layers[0] = layers[0].to(get_hw_device())
ori_outside_layer_module_devices = {}
for module_name in self.gptq_model.base_modules:
module, _ = get_module_by_name_prefix(self.gptq_model.model, [module_name])
@@ -100,7 +110,7 @@ def store_input_hook(_, args, kwargs):
self.gptq_model.pre_quantize_generate_hook_start()
for example in calibration_data:
for k, v in example.items():
data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device
data_device = get_hw_device() if k == "pixel_values" else cur_layer_device
if isinstance(v, list):
for index in range(len(v)):
if len(v[index].shape) == 1:
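For reference, the `xm.xla_device()` call used by the new `get_hw_device()` helper returns an ordinary `torch.device` of type `xla`, and the XLA runtime decides which TPU core backs it, which is why no index is passed the way it is for CUDA. A short, standalone illustration (not part of the change):

```python
# Illustration of the device handle get_hw_device() returns on TPU.
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()             # e.g. device(type='xla', index=0); no per-core index needed
x = torch.ones(2, 2, device=dev)  # the tensor is materialized lazily on the TPU
xm.mark_step()                    # flush the pending XLA graph so the op actually executes
print(x.device)                   # xla:0
```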
10 changes: 8 additions & 2 deletions gptqmodel/models/_const.py
@@ -23,7 +23,7 @@

from ..utils import BACKEND
from ..utils.rocm import IS_ROCM
from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU
from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU, HAS_XLA, auto_select_torch_device

CPU = device("cpu")
CUDA = device("cuda")
@@ -44,6 +44,7 @@ class DEVICE(str, Enum):
XPU = "xpu" # Intel GPU: Datacenter Max + Arc
MPS = "mps" # MacOS GPU: Apple Silion/Metal)
ROCM = "rocm" # AMD GPU: ROCm maps to fake cuda
XLA = "xla" # Google TPU: v5+

@classmethod
# conversion method called for init when string is passed, i.e. Device("CUDA")
@@ -53,7 +54,10 @@ def _missing_(cls, value):
return super()._missing_(value)

def to_device_map(self):
return {"": DEVICE.CUDA if self == DEVICE.ROCM else self}
if self == DEVICE.XLA:
return {"": auto_select_torch_device()}
else:
return {"": DEVICE.CUDA if self == DEVICE.ROCM else self}


class PLATFORM(str, Enum):
@@ -85,6 +89,8 @@ def normalize_device(type_value: str | DEVICE | int | torch.device) -> DEVICE:
return DEVICE.CUDA
elif HAS_XPU:
return DEVICE.XPU
elif HAS_XLA:
return DEVICE.XLA
elif HAS_MPS:
return DEVICE.MPS
else:
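A small illustration of how the new `DEVICE.XLA` member interacts with the helpers changed above; the exact device returned by `to_device_map` depends on `auto_select_torch_device`, whose implementation lives outside this diff.

```python
# Hedged illustration of the new DEVICE.XLA enum member.
from gptqmodel.models._const import DEVICE

d = DEVICE("xla")          # str-valued enum, so the lowercase string maps straight to the member
assert d is DEVICE.XLA
print(d.to_device_map())   # {"": auto_select_torch_device()} -> an xla torch.device on TPU hosts
```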
5 changes: 5 additions & 0 deletions gptqmodel/models/auto.py
@@ -37,6 +37,11 @@
if 'CUDA_VISIBLE_DEVICES' in os.environ and 'ROCR_VISIBLE_DEVICES' in os.environ:
del os.environ['ROCR_VISIBLE_DEVICES']

# Auto-FIX TPU requires special env var
if not os.environ.get("PJRT_DEVICE", None):
os.environ["PJRT_DEVICE"] = 'TPU'
#os.environ["XLA_SYNC_WAIT"] = '1'

import sys # noqa: E402

# TODO: waiting for pytorch implementgation of aten ops for MPS
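The env-var auto-fix above works because the PJRT runtime reads `PJRT_DEVICE` when torch_xla initializes its client, so the variable must be set before the first XLA device is requested. A minimal standalone illustration:

```python
# Minimal illustration of the PJRT_DEVICE requirement the auto-fix addresses.
import os
os.environ.setdefault("PJRT_DEVICE", "TPU")  # must be set before torch_xla spins up the PJRT client

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # xla:0, backed by the TPU PJRT runtime
```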
4 changes: 2 additions & 2 deletions gptqmodel/models/base.py
@@ -370,7 +370,7 @@ def quantize(
desc_act=self.quantize_config.desc_act,
sym=self.quantize_config.sym,
backend=backend,
device=DEVICE(self.quantize_config.device),
device=DEVICE(self.quantize_config.device) if isinstance(self.quantize_config.device, DEVICE) else self.quantize_config.device,
pack=True,
format=self.quantize_config.format,
pack_dtype=self.quantize_config.pack_dtype,
@@ -591,7 +591,7 @@ def quantize_old(
desc_act=self.quantize_config.desc_act,
sym=self.quantize_config.sym,
backend=backend,
device=DEVICE(self.quantize_config.device),
device=DEVICE(self.quantize_config.device) if isinstance(self.quantize_config.device, DEVICE) else self.quantize_config.device,
pack=True,
format=self.quantize_config.format,
pack_dtype=self.quantize_config.pack_dtype,
22 changes: 11 additions & 11 deletions gptqmodel/quantization/gptq.py
@@ -33,7 +33,7 @@
from ..looper.named_module import NamedModule
from ..quantization import QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.torch import torch_compile, torch_sync
from ..utils.torch import auto_select_torch_device, torch_compile, torch_sync
from .quantizer import HF_OPTIMUM, Quantizer

log = setup_logger()
@@ -42,8 +42,8 @@
torch.backends.cudnn.allow_tf32 = False

CPU = torch.device("cpu")
CUDA_0 = torch.device("cuda:0")
CUDA_1 = torch.device("cuda:1") if torch.cuda.device_count() > 1 else CUDA_0
DEVICE_0 = auto_select_torch_device(index=0)
DEVICE_1 = auto_select_torch_device(index=1)

lock = threading.Lock()

Expand Down Expand Up @@ -231,15 +231,15 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
self.fwd_counter += 1

if self.fwd_inputs_buffered:
if CUDA_0.index != CUDA_1.index:
self.fwd_inputs_buffered_data.append(inp.to(device=CUDA_1, non_blocking=True))
if DEVICE_0.index != DEVICE_1.index:
self.fwd_inputs_buffered_data.append(inp.to(device=DEVICE_1, non_blocking=True))
else:
self.fwd_inputs_buffered_data.append(inp.to(device=CPU))
else:
self.process_batch(inp)

def process_batch(self, inp: torch.Tensor):
inp = inp.to(device=CUDA_1, dtype=torch.float32)
inp = inp.to(device=DEVICE_1, dtype=torch.float32)

# input reshaping
if isinstance(self.module, (nn.Linear, transformers.Conv1D)):
@@ -259,8 +259,8 @@ def process_batch(self, inp: torch.Tensor):

if self.H is None:
self.H = torch.zeros((self.columns, self.columns),
dtype=torch.float32,
device=CUDA_1)
dtype=torch.float32,
device=DEVICE_1)

beta = self.nsamples / (self.nsamples + batch_token_size)
alpha = 2.0 / (self.nsamples + batch_token_size)
@@ -305,7 +305,7 @@ def hessian_inverse(self, H: torch.Tensor):
damp = self.qcfg.damp_percent
while 1 > damp > 0:
try:
diag = torch.arange(self.columns, device=CUDA_1)
diag = torch.arange(self.columns, device=DEVICE_1)
H[diag, diag] += damp * torch.mean(torch.diag(H))

with lock:
@@ -370,7 +370,7 @@ def quantize(

if self.module_copy is None:
# log.info("copy W to cuda_1")
W = self._clone_module(device=CUDA_1)
W = self._clone_module(device=DEVICE_1)
else:
W = self.module_copy
self.module_copy = None
@@ -484,7 +484,7 @@ def quantize(
else:
Q = Q.type_as(self.module.weight.data)

Q = Q.to(device=CUDA_1)
Q = Q.to(device=DEVICE_1)

if scale == []:
scale.append(self.quantizer.scale)
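`DEVICE_0` and `DEVICE_1` now come from `auto_select_torch_device`, which is defined in `gptqmodel/utils/torch.py` and not shown in this diff. To make the intent of the rename readable, here is a hypothetical sketch of what such a helper could look like; the real implementation may differ.

```python
# Hypothetical sketch of auto_select_torch_device(index=...) as used above.
import torch

def auto_select_torch_device(index: int = 0) -> torch.device:
    if torch.cuda.is_available():
        # fall back to cuda:0 when the requested index does not exist
        index = index if index < torch.cuda.device_count() else 0
        return torch.device(f"cuda:{index}")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device(f"xpu:{index}")
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device()  # TPU: the XLA runtime handles placement, no index needed
    except ImportError:
        pass
    return torch.device("cpu")
```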
6 changes: 4 additions & 2 deletions gptqmodel/utils/importer.py
@@ -36,7 +36,7 @@
from ..utils.logger import setup_logger
from . import BACKEND
from .rocm import IS_ROCM
from .torch import HAS_CUDA, HAS_MPS, HAS_XPU
from .torch import HAS_CUDA, HAS_MPS, HAS_XPU, HAS_XLA, auto_select_torch_device

message_logged = False
log = setup_logger()
@@ -95,7 +95,6 @@ def normalize_device_device_map(device: Optional[Union[str, torch.device]], devi
normalized_device = DEVICE.ROCM
return normalized_device


def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) -> DEVICE:
assert device is None or isinstance(device, DEVICE)
assert backend is None or isinstance(backend, BACKEND)
Expand All @@ -107,6 +106,9 @@ def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) ->
device = DEVICE.CUDA
elif HAS_XPU:
device = DEVICE.XPU
elif HAS_XLA:
# TODO: xla device is not part of torch but external
device = auto_select_torch_device()
elif HAS_MPS:
device = DEVICE.MPS
else:
3 changes: 2 additions & 1 deletion gptqmodel/utils/model.py
@@ -899,7 +899,8 @@ def auto_dtype(config: PretrainedConfig,
device: DEVICE,
quant_inference: bool = False) -> torch.dtype:

assert isinstance(device, DEVICE)
# TODO: XLA device is created externally by torch_xla, not torch
assert isinstance(device, (DEVICE, torch.device))

# TODO: both MPS and XPU are locked to float16
# XPU stack is missing bfloat16 (hardware supports it)