Skip to content

Commit 065fa89

Browse files
xin3hepre-commit-ci[bot]
authored andcommitted
auto detect available device when exporting (#1743)
Description auto detect available packing device when exporting, priority: ["xpu", "cuda", "cpu"] Expected Behavior & Potential Risk When XPU format is required and XPU device is not available, INC will select packing device from ["xpu", "cuda", "cpu"] automatically. How has this PR been tested? test on cuda device when using xpu argument. --------- Signed-off-by: He, Xin3 <[email protected]> Signed-off-by: y <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7ac0ebe commit 065fa89

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

neural_compressor/model/torch_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,18 @@ def export_compressed_model(
498498
gptq_config = self.gptq_config if hasattr(self, "gptq_config") else {}
499499

500500
autoround_config = self.autoround_config if hasattr(self, "autoround_config") else {}
501+
# check available device, priority: ["xpu", "cuda", "cpu"]
502+
availiable_device = []
503+
if hasattr(torch, "xpu") and torch.xpu.is_available():
504+
availiable_device.append("xpu")
505+
if torch.cuda.is_available():
506+
availiable_device.append("cuda")
507+
availiable_device.append("cpu")
508+
orig_device = device
509+
if device not in availiable_device and "cuda" not in device: # cuda in cuda:0
510+
logger.info(f"{device} is not detected in current environment, please check.")
511+
device = availiable_device[0]
512+
logger.info(f"The compression device has been changed to {device}.")
501513
if gptq_config:
502514
for k, v in weight_config.items():
503515
logger.debug(f"Compressing {k} on device {device}")
@@ -558,7 +570,7 @@ def export_compressed_model(
558570
new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
559571
set_module(self.model, k, new_module)
560572
elif autoround_config:
561-
if device == "xpu":
573+
if orig_device == "xpu":
562574
for k, v in weight_config.items():
563575
logger.debug(f"Compressing {k} on device {device}")
564576
if v["dtype"] == "fp32":

0 commit comments

Comments
 (0)