From 36e31de17605d0a2f10b54a7becbc9221ff60b7c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 10 Apr 2025 10:56:00 +0000 Subject: [PATCH 1/3] init tpu --- README.md | 21 +++++++++-------- gptqmodel/looper/gptq_processor.py | 8 +++---- gptqmodel/looper/module_looper.py | 2 +- gptqmodel/models/_const.py | 10 ++++++-- gptqmodel/models/auto.py | 5 ++++ gptqmodel/models/base.py | 4 ++-- gptqmodel/quantization/gptq.py | 22 +++++++++--------- gptqmodel/utils/importer.py | 6 +++-- gptqmodel/utils/model.py | 3 ++- gptqmodel/utils/torch.py | 37 ++++++++++++++++++++++++++++++ tests/models/model_test.py | 6 +++-- tests/models/test_llama3_2.py | 2 +- tests/models/test_qwen2_5.py | 2 +- 13 files changed, 91 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 8b653f900..7d49cd4a5 100644 --- a/README.md +++ b/README.md @@ -109,10 +109,10 @@ GPTQModel not only supports GPTQ but also QQQ with more quantization methods sup GPTQModel is an expandable/modular design supporting multiple quantization methods. -| Quantization | GPTQModel | Transformers | vLLM | SGLang | Lora Training | -|-------------------|---|---|---|---|---| -| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ | -| QQQ + Rotation | ✅ | x | ✅ | ✅ | x | +| Quantization | GPTQModel | Transformers | vLLM | SGLang | Lora Training | +|-------------------|---|---|------|--------|---| +| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ | +| QQQ + Rotation | ✅ | x | x | x | x | ## Multi-Modal @@ -163,12 +163,13 @@ Native support support some of the most popular multi-modal models: GPTQModel is validated for Linux, MacOS, and Windows 11: -| Platform | Device | | Optimized Arch | Kernels | -|-----------------|---------------| --- | -------------- |-------------------------------------------------------------| -| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch | -| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | IPEX, Torch | -| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch | -| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | IPEX, Torch | +| Platform | Device | | Optimized Arch | Kernels | +|-----------------|-----------------------| --- |-------------------------|----------------------------------------| +| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch | +| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | IPEX, Torch | +| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch | +| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | IPEX, Torch | +| 🐧 Linux | Google TPU | ✅ | `v5+` | Torch | | 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion | | 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch | diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 46b8a8b52..a1769e1a7 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -27,7 +27,7 @@ PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) from ..quantization import GPTQ from ..quantization.config import QUANT_METHOD, QuantizeConfig -from ..quantization.gptq import CPU, CUDA_0, CUDA_1 +from ..quantization.gptq import CPU, DEVICE_0, DEVICE_1 from ..utils.logger import setup_logger from ..utils.model import move_to, pack_model from ..utils.torch import torch_empty_cache, torch_sync @@ -151,12 +151,12 @@ def process(self, module: NamedModule, auto_gc: bool = True): self.avg_losses.append(avg_loss) 
self.module_names.append(f"layer-{module.layer_index}-{module.name}") - stats_0 = torch.cuda.memory_stats(CUDA_0) + stats_0 = torch.cuda.memory_stats(DEVICE_0) active_0 = stats_0.get("active_bytes.all.current", 0) / 1024 ** 2 peak_active_0 = stats_0.get("active_bytes.all.peak", 0) / 1024 ** 2 if torch.cuda.device_count() > 1: - stats_1 = torch.cuda.memory_stats(CUDA_1) + stats_1 = torch.cuda.memory_stats(DEVICE_1) active_1 = stats_1.get("active_bytes.all.current", 0) / 1024 ** 2 peak_active_1 = stats_1.get("active_bytes.all.peak", 0) / 1024 ** 2 @@ -207,7 +207,7 @@ def process(self, module: NamedModule, auto_gc: bool = True): # module.weight.data = torch.empty(1,1) # hack to remove weight.data # if auto_gc: # torch_empty_cache() - wq = wq.to(device=CUDA_0) + wq = wq.to(device=DEVICE_0) # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") module.state.update({ diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 5617a3d6d..2ab99ffc9 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -28,7 +28,7 @@ from ..models import BaseGPTQModel from ..models._const import SUPPORTS_MODULE_TYPES from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear -from ..quantization.gptq import CPU, CUDA_0, CUDA_1 +from ..quantization.gptq import CPU, DEVICE_0, DEVICE_1 from ..utils.logger import setup_logger from ..utils.model import (find_modules, get_device, get_module, get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to) diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index 2114e47d7..c28a418c5 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -24,7 +24,7 @@ from ..utils import BACKEND from ..utils.rocm import IS_ROCM -from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU +from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU, HAS_XLA, auto_select_torch_device CPU = device("cpu") CUDA = device("cuda") @@ -45,6 +45,7 @@ class DEVICE(str, Enum): XPU = "xpu" # Intel GPU: Datacenter Max + Arc MPS = "mps" # MacOS GPU: Apple Silion/Metal) ROCM = "rocm" # AMD GPU: ROCm maps to fake cuda + XLA = "xla" # Google TPU: v5+ @classmethod # conversion method called for init when string is passed, i.e. 
Device("CUDA") @@ -54,7 +55,10 @@ def _missing_(cls, value): return super()._missing_(value) def to_device_map(self): - return {"": DEVICE.CUDA if self == DEVICE.ROCM else self} + if self == DEVICE.XLA: + return {"": auto_select_torch_device()} + else: + return {"": DEVICE.CUDA if self == DEVICE.ROCM else self} class PLATFORM(str, Enum): @@ -86,6 +90,8 @@ def normalize_device(type_value: str | DEVICE | int | torch.device) -> DEVICE: return DEVICE.CUDA elif HAS_XPU: return DEVICE.XPU + elif HAS_XLA: + return DEVICE.XLA elif HAS_MPS: return DEVICE.MPS else: diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index dd32fc48b..a83a2ffa9 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -37,6 +37,11 @@ if 'CUDA_VISIBLE_DEVICES' in os.environ and 'ROCR_VISIBLE_DEVICES' in os.environ: del os.environ['ROCR_VISIBLE_DEVICES'] +# Auto-FIX TPU requires special env var +if not os.environ.get("PJRT_DEVICE", None): + os.environ["PJRT_DEVICE"] = 'TPU' + #os.environ["XLA_SYNC_WAIT"] = '1' + import sys # noqa: E402 # TODO: waiting for pytorch implementgation of aten ops for MPS diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index dfd913644..ff5296447 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -368,7 +368,7 @@ def quantize( desc_act=self.quantize_config.desc_act, sym=self.quantize_config.sym, backend=backend, - device=DEVICE(self.quantize_config.device), + device=DEVICE(self.quantize_config.device) if isinstance(self.quantize_config.device, DEVICE) else self.quantize_config.device, pack=True, format=self.quantize_config.format, pack_dtype=self.quantize_config.pack_dtype, @@ -589,7 +589,7 @@ def quantize_old( desc_act=self.quantize_config.desc_act, sym=self.quantize_config.sym, backend=backend, - device=DEVICE(self.quantize_config.device), + device=DEVICE(self.quantize_config.device) if isinstance(self.quantize_config.device, DEVICE) else self.quantize_config.device, pack=True, format=self.quantize_config.format, pack_dtype=self.quantize_config.pack_dtype, diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index dc84e9a4f..21ef5cd44 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -34,7 +34,7 @@ from ..looper.named_module import NamedModule from ..quantization import QuantizeConfig from ..utils.logger import setup_logger -from ..utils.torch import torch_compile, torch_sync +from ..utils.torch import auto_select_torch_device, torch_compile, torch_sync from .quantizer import HF_OPTIMUM, Quantizer log = setup_logger() @@ -43,8 +43,8 @@ torch.backends.cudnn.allow_tf32 = False CPU = torch.device("cpu") -CUDA_0 = torch.device("cuda:0") -CUDA_1 = torch.device("cuda:1") if torch.cuda.device_count() > 1 else CUDA_0 +DEVICE_0 = auto_select_torch_device(index=0) +DEVICE_1 = auto_select_torch_device(index=1) lock = threading.Lock() @@ -232,15 +232,15 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor): self.fwd_counter += 1 if self.fwd_inputs_buffered: - if CUDA_0.index != CUDA_1.index: - self.fwd_inputs_buffered_data.append(inp.to(device=CUDA_1, non_blocking=True)) + if DEVICE_0.index != DEVICE_1.index: + self.fwd_inputs_buffered_data.append(inp.to(device=DEVICE_1, non_blocking=True)) else: self.fwd_inputs_buffered_data.append(inp.to(device=CPU)) else: self.process_batch(inp) def process_batch(self, inp: torch.Tensor): - inp = inp.to(device=CUDA_1, dtype=torch.float32) + inp = inp.to(device=DEVICE_1, dtype=torch.float32) # input reshaping if isinstance(self.module, 
(nn.Linear, transformers.Conv1D)): @@ -260,8 +260,8 @@ def process_batch(self, inp: torch.Tensor): if self.H is None: self.H = torch.zeros((self.columns, self.columns), - dtype=torch.float32, - device=CUDA_1) + dtype=torch.float32, + device=DEVICE_1) beta = self.nsamples / (self.nsamples + batch_token_size) alpha = 2.0 / (self.nsamples + batch_token_size) @@ -306,7 +306,7 @@ def hessian_inverse(self, H: torch.Tensor): damp = self.qcfg.damp_percent while 1 > damp > 0: try: - diag = torch.arange(self.columns, device=CUDA_1) + diag = torch.arange(self.columns, device=DEVICE_1) H[diag, diag] += damp * torch.mean(torch.diag(H)) with lock: @@ -371,7 +371,7 @@ def quantize( if self.module_copy is None: # log.info("copy W to cuda_1") - W = self._clone_module(device=CUDA_1) + W = self._clone_module(device=DEVICE_1) else: W = self.module_copy self.module_copy = None @@ -485,7 +485,7 @@ def quantize( else: Q = Q.type_as(self.module.weight.data) - Q = Q.to(device=CUDA_1) + Q = Q.to(device=DEVICE_1) if scale == []: scale.append(self.quantizer.scale) diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 4bee62999..0d7b22d1a 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -36,7 +36,7 @@ from ..utils.logger import setup_logger from . import BACKEND from .rocm import IS_ROCM -from .torch import HAS_CUDA, HAS_MPS, HAS_XPU +from .torch import HAS_CUDA, HAS_MPS, HAS_XPU, HAS_XLA, auto_select_torch_device message_logged = False log = setup_logger() @@ -95,7 +95,6 @@ def normalize_device_device_map(device: Optional[Union[str, torch.device]], devi normalized_device = DEVICE.ROCM return normalized_device - def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) -> DEVICE: assert device is None or isinstance(device, DEVICE) assert backend is None or isinstance(backend, BACKEND) @@ -107,6 +106,9 @@ def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) -> device = DEVICE.CUDA elif HAS_XPU: device = DEVICE.XPU + elif HAS_XLA: + # TODO: xla device is not part of torch but external + device = auto_select_torch_device() elif HAS_MPS: device = DEVICE.MPS else: diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 929273d47..1b84679a8 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -899,7 +899,8 @@ def auto_dtype(config: PretrainedConfig, device: DEVICE, quant_inference: bool = False) -> torch.dtype: - assert isinstance(device, DEVICE) + # TODO: XLA device is created externally by torch_xla, not torch + assert isinstance(device, (DEVICE, torch.device)) # TODO: both MPS and XPU are locked to float16 # XPU stack is missing bfloat16 (hardware supports it) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 102e28bde..be95745e9 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -24,6 +24,7 @@ HAS_CUDA = False HAS_XPU = False +HAS_XLA = False HAS_MPS = False HAS_MLX = False @@ -47,6 +48,13 @@ if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available(): HAS_MPS = True +try: + import torch_xla + import torch_xla.core.xla_model as xm + HAS_XLA = True +except Exception: + pass + # mlx check try: import mlx.core.metal @@ -92,6 +100,9 @@ def torch_sync(device: torch.device = None): torch.cuda.synchronize() if HAS_XPU: torch.xpu.synchronize() + if HAS_XLA: + import torch_xla.core.xla_model as xm + xm.mark_step() if HAS_MPS: torch.mps.synchronize() return @@ -130,3 +141,29 @@ def torch_empty_cache(device: torch.device = None, 
gc: bool = True): # mlx is detached from pytorch if HAS_MLX: mlx.core.clear_cache() + +def auto_select_torch_device(index: int = 0): + assert index >= 0, f"device index should be a positive integer: actual = `{index}`" + + if HAS_CUDA: + # defensive check + if index > 0 and torch.cuda.device_count() <= index : + index = 0 + device = torch.device(f"cuda:{index}") + elif HAS_XPU: + # defensive check + if index > 0 and torch.xpu.device_count() <= index: + index = 0 + device = torch.device(f"xpu:{index}") + elif HAS_XLA: + import torch_xla.core.xla_model as xm + + # For TPUs, we don't need to specify an index like with CUDA + # The XLA runtime handles device assignment + device = xm.xla_device(devkind="tpu") + elif HAS_MPS: + device = torch.device("mps") # mps has no index + else: + device = torch.device("cpu") # cpu has no index + + return device diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 29f27618d..1ccc96dd3 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -48,7 +48,7 @@ from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 -from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache, auto_select_torch_device # noqa: E402 from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 from packaging.version import Version # noqa: E402 from transformers import AutoProcessor, AutoTokenizer # noqa: E402 @@ -142,7 +142,8 @@ def load_tokenizer(self, model_id_or_path, trust_remote_code=False): @classmethod def load_dataset(self, tokenizer=None, rows: int = DATASET_SIZE): - traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train") + traindata = load_dataset("allenai/c4", data_files="en/c4-train.00000-of-01024.json.gz", split="train") + if not tokenizer: return traindata.select(range(rows)) @@ -186,6 +187,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut torch_dtype=torch_dtype, backend=self.LOAD_BACKEND, device_map={"": "cpu"} if self.LOAD_BACKEND == BACKEND.IPEX else "auto", + #device_map={"": auto_select_torch_device}, **args, ) diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 827be24b6..e3ed5181f 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -18,7 +18,7 @@ class TestLlama3_2(ModelTest): - NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + NATIVE_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "/monster/data/model/Llama-3.2-1B-Instruct" # " NATIVE_ARC_CHALLENGE_ACC = 0.3567 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805 QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index 51eeba94d..52b72af84 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -18,7 +18,7 @@ class TestQwen2_5(ModelTest): - NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" + NATIVE_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 NATIVE_ARC_CHALLENGE_ACC = 0.2739 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055 From eb26cf090cc8b4cbec8cd832063074b89b21c1a2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 10 Apr 2025 10:56:12 +0000 Subject: [PATCH 2/3] init tpu Signed-off-by: Qubitium --- examples/tpu_smi.py | 47 
+++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 examples/tpu_smi.py diff --git a/examples/tpu_smi.py b/examples/tpu_smi.py new file mode 100644 index 000000000..399a9e5fb --- /dev/null +++ b/examples/tpu_smi.py @@ -0,0 +1,47 @@ +import time +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met +from tabulate import tabulate + +def get_memory_table(): + mem_info = xm.get_memory_info(xm.xla_device()) + print(f"{mem_info}") + try: + used = int(mem_info['bytes_used']) + limit = int(mem_info['bytes_limit']) + peak = int(mem_info['peak_bytes_used']) + + data = [ + ["Used (MB)", used // (1024 * 1024)], + ["Peak (MB)", peak // (1024 * 1024)], + ["Limit (MB)", limit // (1024 * 1024)], + ] + return data + except Exception as e: + return [["Error", f"Failed to parse memory info: {e}"]] + +def get_metrics_summary(): + report = met.metrics_report().split('\n') + summary = [] + for line in report: + if ':' in line and any(key in line for key in ["Time", "Rate", "Size", "Count"]): + parts = line.strip().split(':') + summary.append([parts[0].strip(), parts[1].strip()]) + return summary + +def monitor_tpu(interval=1): + print("🧠 Starting TPU Monitor (updates every {}s)\n".format(interval)) + while True: + mem_table = get_memory_table() + metrics_table = get_metrics_summary() + + print("\n=== TPU Memory Usage ===") + print(tabulate(mem_table, headers=["Metric", "Value"], tablefmt="grid")) + + print("\n=== TPU Execution Metrics ===") + print(tabulate(metrics_table, headers=["Metric", "Value"], tablefmt="grid")) + + time.sleep(interval) + +if __name__ == "__main__": + monitor_tpu(interval=1) From 289806d3dd2f9a016f6af39800745f7db2ab83e3 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 10 Apr 2025 14:15:14 +0000 Subject: [PATCH 3/3] xla does not allow module.weight.data set. 
Need to use copy_ Signed-off-by: Qubitium --- gptqmodel/looper/gptq_processor.py | 7 +++++-- gptqmodel/looper/module_looper.py | 16 +++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index a1769e1a7..d725b4d58 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -30,7 +30,7 @@ from ..quantization.gptq import CPU, DEVICE_0, DEVICE_1 from ..utils.logger import setup_logger from ..utils.model import move_to, pack_model -from ..utils.torch import torch_empty_cache, torch_sync +from ..utils.torch import torch_empty_cache, torch_sync, HAS_XLA log = setup_logger() @@ -214,7 +214,10 @@ def process(self, module: NamedModule, auto_gc: bool = True): "wq": wq, # fp16, quantized weight but not int4 (packed qweight) }) - module.weight.data = wq + if HAS_XLA: + module.weight.data.copy_(wq) + else: + module.weight.data = wq if auto_gc: torch_empty_cache() diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 2ab99ffc9..dc9791020 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -32,7 +32,7 @@ from ..utils.logger import setup_logger from ..utils.model import (find_modules, get_device, get_module, get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to) -from ..utils.torch import torch_empty_cache +from ..utils.torch import torch_empty_cache, HAS_XLA log = setup_logger() @@ -82,8 +82,18 @@ def store_input_hook(_, args, kwargs): raise ValueError + def get_hw_device(): + if HAS_XLA: + import torch_xla.core.xla_model as xm + + # For TPUs, we don't need to specify an index like with CUDA + # The XLA runtime handles device assignment + return xm.xla_device(devkind="tpu") + else: + return self.gptq_model.quantize_config.device + # move layer to target device - layers[0] = layers[0].to(self.gptq_model.quantize_config.device) + layers[0] = layers[0].to(get_hw_device()) ori_outside_layer_module_devices = {} for module_name in self.gptq_model.base_modules: module = get_module_by_name_prefix(self.gptq_model.model, module_name) @@ -100,7 +110,7 @@ def store_input_hook(_, args, kwargs): self.gptq_model.pre_quantize_generate_hook_start() for example in calibration_data: for k, v in example.items(): - data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device + data_device = get_hw_device() if k == "pixel_values" else cur_layer_device if isinstance(v, list): for index in range(len(v)): if len(v[index].shape) == 1:
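
Note on patch 3: torch_xla tensors reject rebinding a parameter's storage via `module.weight.data = wq`, so the quantized weight has to be written in place with `copy_()`. Below is a minimal, self-contained sketch of that pattern, assuming a plain `nn.Linear` stand-in for the hooked module; the helper name `write_quantized_weight` and the `fake_wq` tensor are illustrative and do not exist in the repository.

import torch
import torch.nn as nn

try:
    import torch_xla.core.xla_model as xm
    HAS_XLA = True
except Exception:
    HAS_XLA = False

def write_quantized_weight(module: nn.Linear, wq: torch.Tensor) -> None:
    # On XLA devices, replacing .data with a new tensor is rejected,
    # so copy the quantized values into the existing tensor in place.
    if HAS_XLA:
        module.weight.data.copy_(wq)
    else:
        module.weight.data = wq

if __name__ == "__main__":
    device = xm.xla_device() if HAS_XLA else torch.device("cpu")
    linear = nn.Linear(8, 8).to(device)
    fake_wq = torch.zeros_like(linear.weight)  # stand-in for the GPTQ-quantized weight
    write_quantized_weight(linear, fake_wq)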