diff --git a/tensorrt_llm/profiler.py b/tensorrt_llm/profiler.py
index 5d411caaaea..313c30a37e6 100644
--- a/tensorrt_llm/profiler.py
+++ b/tensorrt_llm/profiler.py
@@ -44,15 +44,44 @@
         "A required package 'pynvml' is not installed. Will not "
         "monitor the device memory usages. Please install the package "
         "first, e.g, 'pip install pynvml>=11.5.0'.")
-elif pynvml.__version__ < '11.5.0':
-    logger.warning(f'Found pynvml=={pynvml.__version__}. Please use '
-                   f'pynvml>=11.5.0 to get accurate memory usage')
-    # Support legacy pynvml. Note that an old API could return
-    # wrong GPU memory usage.
-    _device_get_memory_info_fn = pynvml.nvmlDeviceGetMemoryInfo
 else:
-    _device_get_memory_info_fn = partial(pynvml.nvmlDeviceGetMemoryInfo,
-                                         version=pynvml.nvmlMemory_v2)
+    def _version_tuple(version):
+        """Parse the leading numeric dot-separated fields of ``version``.
+
+        Comparing the resulting tuples orders versions numerically; plain
+        string comparison would rank e.g. '8.0.4' above '11.5.0'.
+        """
+        fields = []
+        for token in version.split('.'):
+            digits = ''
+            for char in token:
+                if not char.isdigit():
+                    break
+                digits += char
+            if not digits:
+                break
+            fields.append(int(digits))
+        return tuple(fields)
+
+    pynvml.nvmlInit()
+    try:
+        driver_version = pynvml.nvmlSystemGetDriverVersion()
+    finally:
+        # Always pair nvmlInit() with nvmlShutdown(), even if the query fails.
+        pynvml.nvmlShutdown()
+    if isinstance(driver_version, bytes):
+        # NOTE(review): some pynvml releases appear to return bytes - confirm.
+        driver_version = driver_version.decode()
+    if (_version_tuple(pynvml.__version__) < (11, 5)
+            or _version_tuple(driver_version) < (526, )):
+        logger.warning(f'Found pynvml=={pynvml.__version__} and cuda driver version {driver_version}.'
+                       ' Please use pynvml>=11.5.0 and cuda driver>=526 to get accurate memory usage')
+        # Support legacy pynvml. Note that an old API could return
+        # wrong GPU memory usage.
+        _device_get_memory_info_fn = pynvml.nvmlDeviceGetMemoryInfo
+    else:
+        _device_get_memory_info_fn = partial(pynvml.nvmlDeviceGetMemoryInfo,
+                                             version=pynvml.nvmlMemory_v2)
 
 
 class Timer: