Commit cc55be2

Post-refactor fixes vol. 2
* Reload the model before generation if it has been offloaded to CPU
* Load a different model if boost was selected
1 parent 88aa86f commit cc55be2

2 files changed: +21 -7 lines
scripts/core.py

Lines changed: 3 additions & 3 deletions
@@ -304,14 +304,14 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
         else:
             raise e
     finally:
-        if not (hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels):
+        if hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels:
+            model_holder.offload()  # Swap to CPU memory
+        else:
             if 'model' in locals():
                 del model
             if 'pix2pixmodel' in locals():
                 del pix2pix_model
             model_holder.unload_models()
-        else:
-            model_holder.swap_to_cpu_memory()
 
     gc.collect()
     devices.torch_gc()
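
The cleanup in the finally block above now branches on the depthmap_script_keepmodels option: when it is enabled the models are only offloaded to CPU memory, otherwise they are unloaded entirely. The sketch below shows how a caller could wrap a generation run around that policy; generate_with_cleanup and the run_generation callable are illustrative placeholders, not code from this repository.

def generate_with_cleanup(run_generation, opts, model_holder, model_type, device, boost):
    """Hypothetical wrapper illustrating the cleanup policy introduced above."""
    try:
        model_holder.ensure_models(model_type, device, boost)  # reloads the model if it was offloaded
        return run_generation(model_holder)                    # placeholder for the actual generation
    finally:
        if hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels:
            model_holder.offload()        # keep the weights in RAM for the next run
        else:
            model_holder.unload_models()  # release the models completely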

scripts/depthmap_generation.py

Lines changed: 18 additions & 4 deletions
@@ -42,6 +42,7 @@ def __init__(self):
         self.pix2pix_model = None
         self.depth_model_type = None
         self.device = None  # Target device, the model may be swapped from VRAM into RAM.
+        self.offloaded = False  # True means current device is not the target device
 
         # Extra stuff
         self.resize_mode = None
@@ -53,9 +54,10 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
             self.unload_models()
             return
         # Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
-        if model_type != self.depth_model_type or boost != self.pix2pix_model is not None or device != self.device:
+        if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
             self.unload_models()
             self.load_models(model_type, device, boost)
+        self.reload()
 
     def load_models(self, model_type, device: torch.device, boost: bool):
         """Ensure that the depth model is loaded"""
@@ -236,11 +238,23 @@ def get_default_net_size(model_type):
             return sizes[model_type]
         return [512, 512]
 
-    def swap_to_cpu_memory(self):
+    def offload(self):
+        """Move to RAM to conserve VRAM"""
+        if self.device != torch.device('cpu') and not self.offloaded:
+            self.move_models_to(torch.device('cpu'))
+            self.offloaded = True
+
+    def reload(self):
+        """Undoes offload"""
+        if self.offloaded:
+            self.move_models_to(self.device)
+            self.offloaded = False
+
+    def move_models_to(self, device):
         if self.depth_model is not None:
-            self.depth_model.to(torch.device('cpu'))
+            self.depth_model.to(device)
         if self.pix2pix_model is not None:
-            self.pix2pix_model.to(torch.device('cpu'))
+            self.pix2pix_model.to(device)
 
     def unload_models(self):
         if self.depth_model is not None or self.pix2pix_model is not None:
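
The offloaded flag introduced here acts as a small two-state toggle: offload() moves the weights to the CPU at most once, and reload() has to clear the flag so that a later offload() can take effect again. Below is a self-contained sketch of just that flag handling; TinyHolder is an illustrative stand-in, not the repository's class.

import torch

class TinyHolder:
    """Minimal stand-in showing only the offload()/reload() flag logic."""
    def __init__(self, model, device):
        self.device = device           # target device
        self.model = model.to(device)
        self.offloaded = False         # True: weights currently live on the CPU

    def offload(self):
        if self.device != torch.device('cpu') and not self.offloaded:
            self.model.to(torch.device('cpu'))
            self.offloaded = True

    def reload(self):
        if self.offloaded:
            self.model.to(self.device)
            self.offloaded = False     # clearing the flag re-arms offload()

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
holder = TinyHolder(torch.nn.Linear(4, 4), device)
holder.offload()
holder.reload()
assert holder.offloaded is False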
