Commit c0f0893

Update inference.py
fix inference on amd gpu
1 parent: 8e0adfc

File tree: 1 file changed (+76, -11 lines)

optillm/inference.py

Lines changed: 76 additions & 11 deletions
@@ -402,42 +402,107 @@ def __init__(self):
         self.device_stats = {device: {'memory_used': 0, 'active_models': 0} for device in self.available_devices}

     def _detect_devices(self) -> List[str]:
+        """Detect available compute devices including AMD GPUs via ROCm"""
         devices = ['cpu']
+
+        # Check for CUDA (NVIDIA) GPUs
         if torch.cuda.is_available():
-            devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
+            backend = torch.cuda.get_device_properties(0).platform
+            if backend == 'ROCm':
+                # AMD GPUs via ROCm
+                devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
+                logging.info("Detected AMD GPU(s) using ROCm backend")
+            else:
+                # NVIDIA GPUs
+                devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
+                logging.info("Detected NVIDIA GPU(s)")
+
+        # Check for Apple M-series GPU
         if torch.backends.mps.is_available():
             devices.append('mps')
+            logging.info("Detected Apple M-series GPU")
+
         return devices

     def get_optimal_device(self, model_size: int = 0) -> str:
+        """Select the optimal device considering AMD GPU support"""
         if not self.available_devices:
             return 'cpu'

-        # Prefer CUDA devices if available
+        # Get CUDA devices (both NVIDIA and AMD via ROCm)
         cuda_devices = [d for d in self.available_devices if 'cuda' in d]
+
         if cuda_devices:
-            # Find CUDA device with most free memory
+            # Find device with most free memory
             max_free_memory = 0
             optimal_device = cuda_devices[0]

-            for device in cuda_devices:
-                idx = int(device.split(':')[1])
-                free_memory = torch.cuda.get_device_properties(idx).total_memory - torch.cuda.memory_allocated(idx)
-                if free_memory > max_free_memory:
-                    max_free_memory = free_memory
-                    optimal_device = device
-
-            return optimal_device
+            try:
+                for device in cuda_devices:
+                    idx = int(device.split(':')[1])
+                    # Get memory info safely handling both NVIDIA and AMD
+                    try:
+                        total_memory = torch.cuda.get_device_properties(idx).total_memory
+                        used_memory = torch.cuda.memory_allocated(idx)
+                        free_memory = total_memory - used_memory
+                    except Exception as e:
+                        logging.warning(f"Error getting memory info for device {device}: {e}")
+                        continue
+
+                    if free_memory > max_free_memory:
+                        max_free_memory = free_memory
+                        optimal_device = device
+
+                logging.info(f"Selected optimal CUDA device: {optimal_device} with {max_free_memory/1e9:.2f}GB free memory")
+                return optimal_device
+
+            except Exception as e:
+                logging.error(f"Error selecting optimal CUDA device: {e}")
+                # Fall back to first CUDA device if memory query fails
+                return cuda_devices[0]

         # Fall back to MPS if available
         if 'mps' in self.available_devices:
             return 'mps'

+        # Final fallback to CPU
+        logging.info("No GPU detected, using CPU")
         return 'cpu'

     def track_device_usage(self, device: str, memory_delta: int):
+        """Track memory usage for the device"""
         if device in self.device_stats:
             self.device_stats[device]['memory_used'] += memory_delta
+
+    def get_device_info(self, device: str) -> Dict[str, Any]:
+        """Get detailed information about a device"""
+        info = {
+            'type': 'cpu',
+            'memory_total': None,
+            'memory_used': None,
+            'memory_free': None
+        }
+
+        if 'cuda' in device:
+            try:
+                idx = int(device.split(':')[1])
+                props = torch.cuda.get_device_properties(idx)
+                info.update({
+                    'type': 'gpu',
+                    'name': props.name,
+                    'backend': 'ROCm' if hasattr(props, 'platform') and props.platform == 'ROCm' else 'CUDA',
+                    'compute_capability': f"{props.major}.{props.minor}",
+                    'memory_total': props.total_memory,
+                    'memory_used': torch.cuda.memory_allocated(idx),
+                    'memory_free': props.total_memory - torch.cuda.memory_allocated(idx)
+                })
+            except Exception as e:
+                logging.warning(f"Error getting device info for {device}: {e}")
+
+        elif device == 'mps':
+            info['type'] = 'mps'
+
+        return info

 class ModelManager:
     def __init__(self, cache_manager: CacheManager, device_manager: DeviceManager):
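For context on the ROCm branch above: the commit keys the check off torch.cuda.get_device_properties(0).platform, while another widely used signal for a ROCm build of PyTorch is torch.version.hip, which holds a HIP version string on ROCm wheels and is None on CUDA wheels. The sketch below is not part of the commit; it only illustrates that alternative backend check on a stock PyTorch install.

import torch

def detect_gpu_backend() -> str:
    # Sketch only (not from the commit): ROCm builds of PyTorch report a HIP
    # version string in torch.version.hip, while CUDA builds leave it as None.
    if not torch.cuda.is_available():
        return 'none'
    return 'ROCm' if getattr(torch.version, 'hip', None) else 'CUDA'

if __name__ == '__main__':
    backend = detect_gpu_backend()
    for i in range(torch.cuda.device_count()):
        # ROCm devices are still exposed under the cuda:* namespace,
        # which is why the commit reuses the f'cuda:{i}' device strings.
        print(f"{backend} device cuda:{i}: {torch.cuda.get_device_name(i)}")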

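The free-memory arithmetic in get_optimal_device subtracts torch.cuda.memory_allocated from total_memory, which only counts tensors allocated by the current process through PyTorch's caching allocator. When device-wide numbers matter (for example on a GPU shared with other processes), torch.cuda.mem_get_info asks the driver for the free and total bytes directly. The sketch below is an illustrative alternative, not part of the commit, and assumes at least one visible cuda:* device.

import torch

def freest_cuda_device() -> str:
    # Illustrative sketch (not from the commit): pick the cuda:* device with
    # the most driver-reported free memory. mem_get_info returns
    # (free_bytes, total_bytes) for the whole device, so it also reflects
    # memory held by other processes.
    best_device, best_free = 'cuda:0', -1
    for i in range(torch.cuda.device_count()):
        free_bytes, _total_bytes = torch.cuda.mem_get_info(i)
        if free_bytes > best_free:
            best_device, best_free = f'cuda:{i}', free_bytes
    return best_device

On ROCm builds the same torch.cuda call is expected to go through HIP, since PyTorch maps its cuda namespace onto HIP there, so the selection logic should not need to branch per vendor.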