@@ -42,6 +42,7 @@ def __init__(self):
         self.pix2pix_model = None
         self.depth_model_type = None
         self.device = None  # Target device, the model may be swapped from VRAM into RAM.
+        self.offloaded = False  # True means current device is not the target device

         # Extra stuff
         self.resize_mode = None
@@ -53,9 +54,10 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
             self.unload_models()
             return
         # Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
-        if model_type != self.depth_model_type or boost != self.pix2pix_model is not None or device != self.device:
+        if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
             self.unload_models()
             self.load_models(model_type, device, boost)
+        self.reload()

     def load_models(self, model_type, device: torch.device, boost: bool):
         """Ensure that the depth model is loaded"""
@@ -236,11 +238,23 @@ def get_default_net_size(model_type):
             return sizes[model_type]
         return [512, 512]

-    def swap_to_cpu_memory(self):
+    def offload(self):
+        """Move to RAM to conserve VRAM"""
+        if self.device != torch.device('cpu') and not self.offloaded:
+            self.move_models_to(torch.device('cpu'))
+            self.offloaded = True
+
+    def reload(self):
+        """Undoes offload"""
+        if self.offloaded:
+            self.move_models_to(self.device)
+            self.offloaded = False
+
+    def move_models_to(self, device):
         if self.depth_model is not None:
-            self.depth_model.to(torch.device('cpu'))
+            self.depth_model.to(device)
         if self.pix2pix_model is not None:
-            self.pix2pix_model.to(torch.device('cpu'))
+            self.pix2pix_model.to(device)

     def unload_models(self):
         if self.depth_model is not None or self.pix2pix_model is not None:
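For context, a minimal usage sketch of the offload()/reload() API introduced above. This is not part of the commit: the ModelHolder class name, the module path, and the get_raw_prediction() call are assumptions used only for illustration.

    import torch

    from depthmap_generation import ModelHolder  # assumed module/class name

    model_holder = ModelHolder()

    def compute_depth(image, model_type, device: torch.device):
        # ensure_models() now ends with self.reload(), so models that were
        # previously offloaded to RAM are moved back to the target device.
        model_holder.ensure_models(model_type, device, boost=False)
        try:
            return model_holder.get_raw_prediction(image)  # assumed inference entry point
        finally:
            # Park the models in RAM between requests to free VRAM for other
            # pipelines; the next ensure_models() call reloads them.
            model_holder.offload()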