From f461083fb7cf7c3ca1193c6eadc528ff575a60f6 Mon Sep 17 00:00:00 2001
From: xin3he
Date: Thu, 25 Apr 2024 14:59:18 +0800
Subject: [PATCH] add xpu accelerator in 3x

Signed-off-by: xin3he
---
 .../torch/utils/auto_accelerator.py | 45 +++++++++++++++++--
 1 file changed, 41 insertions(+), 4 deletions(-)

diff --git a/neural_compressor/torch/utils/auto_accelerator.py b/neural_compressor/torch/utils/auto_accelerator.py
index 7e59f00e180..cbf7db2573c 100644
--- a/neural_compressor/torch/utils/auto_accelerator.py
+++ b/neural_compressor/torch/utils/auto_accelerator.py
@@ -35,8 +35,9 @@
 htcore = LazyImport("habana_frameworks.torch.core")
 
 PRIORITY_HPU = 100
-PRIORITY_CUDA = 95
-PRIORITY_CPU = 90
+PRIORITY_XPU = 95
+PRIORITY_CUDA = 90
+PRIORITY_CPU = 80
 
 
 class AcceleratorRegistry:
@@ -213,8 +214,44 @@ def device(self, device_index=None):
     def empty_cache(self):
         return torch.cuda.empty_cache()
 
-    def mark_step(self):
-        pass
+
+@register_accelerator(name="xpu", priority=PRIORITY_XPU)
+class XPU_Accelerator(Auto_Accelerator):
+    def __init__(self) -> None:
+        self._name = "xpu"
+
+    def name(self) -> str:
+        return self._name
+
+    @classmethod
+    def is_available(cls) -> bool:
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            return True
+        else:
+            return False
+
+    def device_name(self, device_indx) -> str:
+        if device_indx is None:
+            return "xpu"
+        return f"xpu:{device_indx}"
+
+    def synchronize(self):
+        return torch.xpu.synchronize()
+
+    def set_device(self, device_index):
+        return torch.xpu.set_device(device_index)
+
+    def current_device(self):
+        return torch.xpu.current_device()
+
+    def current_device_name(self):
+        return "xpu:{}".format(torch.xpu.current_device())
+
+    def device(self, device_index=None):
+        return torch.xpu.device(device_index)
+
+    def empty_cache(self):
+        return torch.xpu.empty_cache()
 
 
 @register_accelerator(name="hpu", priority=PRIORITY_HPU)
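
Reviewer note (not part of the patch): a minimal usage sketch of how the newly registered XPU accelerator is expected to be picked up on a machine with an Intel GPU. It assumes the module's existing auto_detect_accelerator() helper selects the highest-priority accelerator whose is_available() returns True; names outside this diff are assumptions based on the surrounding file.

    # Hypothetical usage sketch, assuming auto_detect_accelerator() is exposed
    # by neural_compressor/torch/utils/auto_accelerator.py as in the current file.
    import torch
    from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

    accelerator = auto_detect_accelerator()
    # With this patch, PRIORITY_XPU (95) outranks CUDA (90) and CPU (80),
    # so an available torch.xpu device should win the selection.
    print(accelerator.name())                 # expected: "xpu" when torch.xpu.is_available()
    print(accelerator.current_device_name())  # e.g. "xpu:0"

    # Tensors can then be placed on the detected device by name.
    x = torch.randn(2, 2).to(accelerator.current_device_name())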