@@ -402,42 +402,107 @@ def __init__(self):
402402 self .device_stats = {device : {'memory_used' : 0 , 'active_models' : 0 } for device in self .available_devices }
403403
404404 def _detect_devices (self ) -> List [str ]:
405+ """Detect available compute devices including AMD GPUs via ROCm"""
405406 devices = ['cpu' ]
407+
408+ # Check for CUDA (NVIDIA) GPUs
406409 if torch .cuda .is_available ():
407- devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
410+ backend = torch .cuda .get_device_properties (0 ).platform
411+ if backend == 'ROCm' :
412+ # AMD GPUs via ROCm
413+ devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
414+ logging .info ("Detected AMD GPU(s) using ROCm backend" )
415+ else :
416+ # NVIDIA GPUs
417+ devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
418+ logging .info ("Detected NVIDIA GPU(s)" )
419+
420+ # Check for Apple M-series GPU
408421 if torch .backends .mps .is_available ():
409422 devices .append ('mps' )
423+ logging .info ("Detected Apple M-series GPU" )
424+
410425 return devices
411426
def get_optimal_device(self, model_size: int = 0) -> str:
    """Select the optimal device for placing a model.

    Preference order: the CUDA device (NVIDIA or AMD-via-ROCm) with the
    most free memory, then Apple MPS, then CPU.

    Args:
        model_size: Estimated model footprint in bytes. Not used by the
            selection itself (kept for interface compatibility — TODO:
            confirm whether callers expect it to gate device choice).

    Returns:
        Device string such as 'cuda:0', 'mps', or 'cpu'.
    """
    if not self.available_devices:
        return 'cpu'

    cuda_devices = [d for d in self.available_devices if 'cuda' in d]
    if cuda_devices:
        # Pick the CUDA device with the most free memory; devices whose
        # memory query fails are skipped rather than aborting selection.
        # (The previous version wrapped this loop in a second, broad
        # try/except whose fallback was effectively unreachable because
        # the per-device handler already continues on failure.)
        max_free_memory = 0
        optimal_device = cuda_devices[0]
        for device in cuda_devices:
            idx = int(device.split(':')[1])
            try:
                total_memory = torch.cuda.get_device_properties(idx).total_memory
                free_memory = total_memory - torch.cuda.memory_allocated(idx)
            except (RuntimeError, AssertionError) as e:
                # torch.cuda raises RuntimeError/AssertionError on
                # invalid or uninitialized devices.
                logging.warning("Error getting memory info for device %s: %s", device, e)
                continue
            if free_memory > max_free_memory:
                max_free_memory = free_memory
                optimal_device = device
        logging.info(
            "Selected optimal CUDA device: %s with %.2f GB free memory",
            optimal_device, max_free_memory / 1e9,
        )
        return optimal_device

    # Fall back to MPS if available.
    if 'mps' in self.available_devices:
        return 'mps'

    # Final fallback to CPU.
    logging.info("No GPU detected, using CPU")
    return 'cpu'
437471
def track_device_usage(self, device: str, memory_delta: int) -> None:
    """Accumulate *memory_delta* (bytes) into the stats for *device*.

    Deltas may be negative (memory released). Devices not present in
    ``self.device_stats`` are ignored silently.
    """
    stats = self.device_stats.get(device)
    if stats is not None:
        stats['memory_used'] += memory_delta
476+
def get_device_info(self, device: str) -> Dict[str, Any]:
    """Get detailed information about a device.

    Args:
        device: Device string such as 'cpu', 'mps', or 'cuda:0'.

    Returns:
        Dict with at least 'type', 'memory_total', 'memory_used',
        'memory_free' (memory fields are None for non-CUDA devices or
        when the CUDA query fails). For CUDA devices, also 'name',
        'backend' ('ROCm' or 'CUDA'), and 'compute_capability'.
    """
    info = {
        'type': 'cpu',
        'memory_total': None,
        'memory_used': None,
        'memory_free': None,
    }

    if 'cuda' in device:
        try:
            idx = int(device.split(':')[1])
            props = torch.cuda.get_device_properties(idx)
            # Query allocated memory once (was queried twice before).
            used = torch.cuda.memory_allocated(idx)
            info.update({
                'type': 'gpu',
                'name': props.name,
                # NOTE(review): 'platform' is not a documented property
                # attribute on stock builds; getattr keeps this safe.
                'backend': 'ROCm' if getattr(props, 'platform', None) == 'ROCm' else 'CUDA',
                'compute_capability': f"{props.major}.{props.minor}",
                'memory_total': props.total_memory,
                'memory_used': used,
                'memory_free': props.total_memory - used,
            })
        except Exception as e:
            # Best-effort: fall through with the CPU-shaped defaults.
            logging.warning("Error getting device info for %s: %s", device, e)

    elif device == 'mps':
        info['type'] = 'mps'

    return info
441506
442507class ModelManager :
443508 def __init__ (self , cache_manager : CacheManager , device_manager : DeviceManager ):
0 commit comments