3 changes: 2 additions & 1 deletion benchmark/benchmark_diffusers.py
@@ -162,7 +162,8 @@ def benchmark_time(pipeline, prompt, number=3, repeat=3, **kwargs):


def benchmark(pipeline, prompt, number=1, repeat=5, include_memory=True, **kwargs):
with GPUMemoryMonitor() as gpu_monitor:
device = pipeline.device
with GPUMemoryMonitor(device) as gpu_monitor:
if not include_memory:
gpu_monitor.stop()
t = benchmark_time(pipeline, prompt, number=number, repeat=repeat, **kwargs)
3 changes: 2 additions & 1 deletion benchmark/benchmark_llm.py
@@ -186,7 +186,8 @@ def benchmark(
if isinstance(prompt, list) and "batch_size" not in generate_kwargs:
generate_kwargs["batch_size"] = len(prompt)

with GPUMemoryMonitor() as gpu_monitor:
device = generator.model.device
with GPUMemoryMonitor(device) as gpu_monitor:
if not include_memory:
gpu_monitor.stop()
t, f = benchmark_time(
3 changes: 2 additions & 1 deletion benchmark/benchmark_musicgen.py
@@ -155,7 +155,8 @@ def benchmark(
include_memory=True,
**generate_kwargs,
):
with GPUMemoryMonitor() as gpu_monitor:
device = generator.model.device
with GPUMemoryMonitor(device) as gpu_monitor:
if not include_memory:
gpu_monitor.stop()
t, f = benchmark_time(
3 changes: 2 additions & 1 deletion benchmark/benchmark_whisper.py
@@ -142,7 +142,8 @@ def benchmark(
include_memory=True,
**generate_kwargs,
):
with GPUMemoryMonitor() as gpu_monitor:
device = generator.model.device
with GPUMemoryMonitor(device) as gpu_monitor:
if not include_memory:
gpu_monitor.stop()
t, f = benchmark_time(
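The change to all four benchmark scripts has the same shape: read the device from the loaded pipeline or model once and hand it to the monitor, so memory is sampled on the GPU that actually runs the workload rather than a hard-coded index 0. A minimal sketch of that shared pattern follows; `generator.generate`, `prompt`, and the return value are illustrative stand-ins rather than names taken from the diff (the diffusers script passes `pipeline.device`, the others pass `generator.model.device`), and it assumes, as in the scripts, that the peak is read back after the `with` block.

```python
# Sketch of the pattern shared by the four benchmark scripts.
# `generator.generate` and `prompt` are hypothetical; only the device lookup
# and the GPUMemoryMonitor usage mirror the diff above.
from gpu_monitor import GPUMemoryMonitor


def benchmark(generator, prompt, include_memory=True, **generate_kwargs):
    device = generator.model.device  # torch.device the model was loaded on
    with GPUMemoryMonitor(device) as gpu_monitor:
        if not include_memory:
            gpu_monitor.stop()  # timing still runs; memory sampling is skipped
        output = generator.generate(prompt, **generate_kwargs)
    max_mib = gpu_monitor.get_max_memory_usage() if include_memory else None
    return output, max_mib
```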
96 changes: 88 additions & 8 deletions benchmark/gpu_monitor.py
@@ -1,13 +1,74 @@
import multiprocessing
from typing import Optional
import pynvml
import torch
from logger import _LOGGER_MAIN

logger = _LOGGER_MAIN

def monitor_gpu_memory(queue, running_flag):

def _device_to_uuid_cuda(cuda_idx: Optional[int] = None) -> str:
import uuid as _uuid
from cuda.bindings import driver as cuda # Use the low-level driver API

def CUASSERT(cuda_ret):
err = cuda_ret[0]
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"CUDA ERROR: {err}")
if len(cuda_ret) > 1:
return cuda_ret[1]
return None

CUASSERT(cuda.cuInit(0))
cuda_idx = cuda_idx if cuda_idx is not None else torch.cuda.current_device()
dev = CUASSERT(cuda.cuDeviceGet(cuda_idx))
uuid_struct = CUASSERT(cuda.cuDeviceGetUuid(dev))
uuid_str = str(_uuid.UUID(bytes=bytes(uuid_struct.bytes)))
return uuid_str


def _device_to_uuid_torch(cuda_idx: Optional[int] = None):
device_properties = torch.cuda.get_device_properties(cuda_idx)
uuid_str = str(device_properties.uuid)
return uuid_str


def get_nvml_id_by_cuda_uuid(cuda_id):
errors = []
func_list = [
(_device_to_uuid_torch, "torch.cuda"),
(_device_to_uuid_cuda, "cuda-python"),
]
for get_uuid_func, name in func_list:
try:
uuid_str = get_uuid_func(cuda_id)
break
except Exception as e:
errors.append(f"Using {name}: {e}")
else:
raise RuntimeError(f"Failed to get UUID for device {cuda_id}: {errors}")

pynvml.nvmlInit()
encoded = uuid_str.replace("-", "").encode("utf-8")
try:
# Get the NVML device handle using the UUID
handle = pynvml.nvmlDeviceGetHandleByUUID(encoded)
id = pynvml.nvmlDeviceGetIndex(handle)
return id

except pynvml.NVMLError as e:
raise RuntimeError(f"PyNVML Error: {e}")

finally:
pynvml.nvmlShutdown()


def monitor_gpu_memory(queue, running_flag, nvml_id):
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming GPU 0
max_memory_usage = 0

try:
handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_id)
while running_flag.value:
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
memory_usage = info.used / (1024 * 1024) # Convert to MiB
@@ -18,31 +79,50 @@ def monitor_gpu_memory(queue, running_flag):


class GPUMemoryMonitor:
def __init__(self):
def __init__(self, device: Optional[int] = None):
self.process = None
self.queue = multiprocessing.Queue()
self.running_flag = multiprocessing.Value("b", False)
self.device = torch.device(
device if device is not None else torch.cuda.current_device()
)

pynvml.nvmlInit()
try:
self.nvml_id = get_nvml_id_by_cuda_uuid(self.device.index)
logger.debug(f"Initialized GPU monitor for device {self.device} (NVML ID: {self.nvml_id})")
except Exception as e:
logger.warning(f"Error getting NVML device index: {e}. Using NVML ID 0.")
self.nvml_id = 0
finally:
pynvml.nvmlShutdown()

def start(self):
if self.running_flag.value:
print("GPU monitor is already running.")
logger.warning("Cannot start GPU monitor: it is already running.")
return
self.running_flag.value = True
self.process = multiprocessing.Process(
target=monitor_gpu_memory, args=(self.queue, self.running_flag)
target=monitor_gpu_memory,
args=(self.queue, self.running_flag, self.nvml_id),
)
self.process.start()
logger.info(f"GPU monitor started for device {self.device} (NVML ID: {self.nvml_id})")

def stop(self):
if not self.running_flag.value:
print("GPU monitor is not running.")
logger.warning("Cannot stop GPU monitor: it is not running.")
return
self.running_flag.value = False
self.process.join()
logger.info("GPU monitor stopped")

def get_max_memory_usage(self):
if not self.queue.empty():
return self.queue.get()
max_usage = self.queue.get()
logger.debug(f"Retrieved max GPU memory usage: {max_usage:.2f} MiB")
return max_usage
logger.warning("No memory usage data available")
return 0

def __enter__(self):
@@ -61,4 +141,4 @@ def __exit__(self, exc_type, exc_value, traceback):

monitor.stop()

print("Maximum GPU memory usage (MB):", monitor.get_max_memory_usage())
logger.info(f"Maximum GPU memory usage (MB): {monitor.get_max_memory_usage()}")