diff --git a/benchmark/benchmark_diffusers.py b/benchmark/benchmark_diffusers.py
index 6017939..742b4c3 100644
--- a/benchmark/benchmark_diffusers.py
+++ b/benchmark/benchmark_diffusers.py
@@ -162,7 +162,8 @@ def benchmark_time(pipeline, prompt, number=3, repeat=3, **kwargs):
 
 
 def benchmark(pipeline, prompt, number=1, repeat=5, include_memory=True, **kwargs):
-    with GPUMemoryMonitor() as gpu_monitor:
+    device = pipeline.device
+    with GPUMemoryMonitor(device) as gpu_monitor:
         if not include_memory:
             gpu_monitor.stop()
         t = benchmark_time(pipeline, prompt, number=number, repeat=repeat, **kwargs)
diff --git a/benchmark/benchmark_llm.py b/benchmark/benchmark_llm.py
index eaaf35e..091f23b 100644
--- a/benchmark/benchmark_llm.py
+++ b/benchmark/benchmark_llm.py
@@ -186,7 +186,8 @@ def benchmark(
     if isinstance(prompt, list) and "batch_size" not in generate_kwargs:
         generate_kwargs["batch_size"] = len(prompt)
 
-    with GPUMemoryMonitor() as gpu_monitor:
+    device = generator.model.device
+    with GPUMemoryMonitor(device) as gpu_monitor:
         if not include_memory:
             gpu_monitor.stop()
         t, f = benchmark_time(
diff --git a/benchmark/benchmark_musicgen.py b/benchmark/benchmark_musicgen.py
index ecf1ca3..e496bbb 100644
--- a/benchmark/benchmark_musicgen.py
+++ b/benchmark/benchmark_musicgen.py
@@ -155,7 +155,8 @@ def benchmark(
     include_memory=True,
     **generate_kwargs,
 ):
-    with GPUMemoryMonitor() as gpu_monitor:
+    device = generator.model.device
+    with GPUMemoryMonitor(device) as gpu_monitor:
         if not include_memory:
             gpu_monitor.stop()
         t, f = benchmark_time(
diff --git a/benchmark/benchmark_whisper.py b/benchmark/benchmark_whisper.py
index 2745e23..8f3eff5 100644
--- a/benchmark/benchmark_whisper.py
+++ b/benchmark/benchmark_whisper.py
@@ -142,7 +142,8 @@ def benchmark(
     include_memory=True,
     **generate_kwargs,
 ):
-    with GPUMemoryMonitor() as gpu_monitor:
+    device = generator.model.device
+    with GPUMemoryMonitor(device) as gpu_monitor:
         if not include_memory:
             gpu_monitor.stop()
         t, f = benchmark_time(
diff --git a/benchmark/gpu_monitor.py b/benchmark/gpu_monitor.py
index 6726f94..cc3f951 100644
--- a/benchmark/gpu_monitor.py
+++ b/benchmark/gpu_monitor.py
@@ -1,13 +1,74 @@
 import multiprocessing
+from typing import Optional
 
 import pynvml
+import torch
+from logger import _LOGGER_MAIN
 
+logger = _LOGGER_MAIN
 
-def monitor_gpu_memory(queue, running_flag):
+
+def _device_to_uuid_cuda(cuda_idx: Optional[int] = None) -> str:
+    import uuid as _uuid
+    from cuda.bindings import driver as cuda  # Use the low-level driver API
+
+    def CUASSERT(cuda_ret):
+        err = cuda_ret[0]
+        if err != cuda.CUresult.CUDA_SUCCESS:
+            raise RuntimeError(f"CUDA ERROR: {err}")
+        if len(cuda_ret) > 1:
+            return cuda_ret[1]
+        return None
+
+    CUASSERT(cuda.cuInit(0))
+    cuda_idx = cuda_idx if cuda_idx is not None else torch.cuda.current_device()
+    dev = CUASSERT(cuda.cuDeviceGet(cuda_idx))
+    uuid_struct = CUASSERT(cuda.cuDeviceGetUuid(dev))
+    uuid_str = str(_uuid.UUID(bytes=bytes(uuid_struct.bytes)))
+    return uuid_str
+
+
+def _device_to_uuid_torch(cuda_idx: Optional[int] = None):
+    device_properties = torch.cuda.get_device_properties(cuda_idx)
+    uuid_str = str(device_properties.uuid)
+    return uuid_str
+
+
+def get_nvml_id_by_cuda_uuid(cuda_id):
+    errors = []
+    func_list = [
+        (_device_to_uuid_torch, "torch.cuda"),
+        (_device_to_uuid_cuda, "cuda-python"),
+    ]
+    for get_uuid_func, name in func_list:
+        try:
+            uuid_str = get_uuid_func(cuda_id)
+            break
+        except Exception as e:
+            errors.append(f"Using {name}: {e}")
+    else:
+        raise RuntimeError(f"Failed to get UUID for device {cuda_id}: {errors}")
+
+    pynvml.nvmlInit()
+    encoded = uuid_str.replace("-", "").encode("utf-8")
+    try:
+        # Get the NVML device handle using the UUID
+        handle = pynvml.nvmlDeviceGetHandleByUUID(encoded)
+        id = pynvml.nvmlDeviceGetIndex(handle)
+        return id
+
+    except pynvml.NVMLError as e:
+        raise RuntimeError(f"PyNVML Error: {e}")
+
+    finally:
+        pynvml.nvmlShutdown()
+
+
+def monitor_gpu_memory(queue, running_flag, nvml_id):
     pynvml.nvmlInit()
-    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # Assuming GPU 0
     max_memory_usage = 0
     try:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_id)
         while running_flag.value:
             info = pynvml.nvmlDeviceGetMemoryInfo(handle)
             memory_usage = info.used / (1024 * 1024)  # Convert to MiB
@@ -18,31 +79,50 @@ def monitor_gpu_memory(queue, running_flag):
 
 
 class GPUMemoryMonitor:
-    def __init__(self):
+    def __init__(self, device: Optional[int] = None):
         self.process = None
         self.queue = multiprocessing.Queue()
         self.running_flag = multiprocessing.Value("b", False)
+        self.device = torch.device(
+            device if device is not None else torch.cuda.current_device()
+        )
+
+        pynvml.nvmlInit()
+        try:
+            self.nvml_id = get_nvml_id_by_cuda_uuid(self.device.index)
+            logger.debug(f"Initialized GPU monitor for device {self.device} (NVML ID: {self.nvml_id})")
+        except Exception as e:
+            logger.warning(f"Error getting NVML device index: {e}. Using NVML ID 0.")
+            self.nvml_id = 0
+        finally:
+            pynvml.nvmlShutdown()
 
     def start(self):
         if self.running_flag.value:
-            print("GPU monitor is already running.")
+            logger.warning("Cannot start GPU monitor: it is already running.")
             return
         self.running_flag.value = True
         self.process = multiprocessing.Process(
-            target=monitor_gpu_memory, args=(self.queue, self.running_flag)
+            target=monitor_gpu_memory,
+            args=(self.queue, self.running_flag, self.nvml_id),
         )
         self.process.start()
+        logger.info(f"GPU monitor started for device {self.device} (NVML ID: {self.nvml_id})")
 
     def stop(self):
         if not self.running_flag.value:
-            print("GPU monitor is not running.")
+            logger.warning("Cannot stop GPU monitor: it is not running.")
             return
         self.running_flag.value = False
         self.process.join()
+        logger.info("GPU monitor stopped")
 
     def get_max_memory_usage(self):
         if not self.queue.empty():
-            return self.queue.get()
+            max_usage = self.queue.get()
+            logger.debug(f"Retrieved max GPU memory usage: {max_usage:.2f} MiB")
+            return max_usage
+        logger.warning("No memory usage data available")
         return 0
 
     def __enter__(self):
@@ -61,4 +141,4 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
     monitor.stop()
-    print("Maximum GPU memory usage (MB):", monitor.get_max_memory_usage())
+    logger.info(f"Maximum GPU memory usage (MB): {monitor.get_max_memory_usage()}")
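
Usage sketch: a minimal, hedged example of how the reworked monitor is driven, mirroring the with-statement pattern the benchmark scripts now use. It assumes a CUDA-enabled torch build, pynvml installed, benchmark/gpu_monitor.py on the import path, and that __enter__/__exit__ start and stop the monitor as the scripts' usage implies; the matmul workload and device index 0 are illustrative only.

    import torch
    from gpu_monitor import GPUMemoryMonitor, get_nvml_id_by_cuda_uuid

    device = torch.device("cuda:0")

    # The monitor maps the CUDA ordinal to an NVML index via the device UUID, so it
    # keeps polling the right card even when CUDA and NVML enumerate GPUs differently.
    print("NVML index for CUDA device 0:", get_nvml_id_by_cuda_uuid(device.index))

    x = torch.randn(4096, 4096, device=device)
    with GPUMemoryMonitor(device.index) as monitor:
        for _ in range(50):
            x = x @ x         # keep the GPU busy while the child process samples memory
            x = x / x.norm()  # renormalize so values stay finite
        torch.cuda.synchronize(device)

    # Assumes __exit__ stopped the monitor; the peak reading is then fetched from the queue.
    print("Max GPU memory usage (MiB):", monitor.get_max_memory_usage())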