Skip to content

Commit 336bcea

Browse files
committed
Post-refactor fixes vol. 2
* Reload the model before generation if it is offloaded to CPU
* Load the model if boost is selected
* Do not try to offload pix2pix
* Net dimensions are multiples of 32 regardless of match size
* Change the default net size to the default net size of the default model
1 parent 88aa86f commit 336bcea

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

scripts/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
165165
else:
166166
# override net size (size may be different for different images)
167167
if match_size:
168-
net_width, net_height = inputimages[count].width, inputimages[count].height
168+
# Round up to a multiple of 32 to avoid potential issues
169+
net_width = (inputimages[count].width + 31) // 32 * 32
170+
net_height = (inputimages[count].height + 31) // 32 * 32
169171
raw_prediction, raw_prediction_invert = \
170172
model_holder.get_raw_prediction(inputimages[count], net_width, net_height)
171173

@@ -304,14 +306,14 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
304306
else:
305307
raise e
306308
finally:
307-
if not (hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels):
309+
if hasattr(opts, 'depthmap_script_keepmodels') and opts.depthmap_script_keepmodels:
310+
model_holder.offload() # Swap to CPU memory
311+
else:
308312
if 'model' in locals():
309313
del model
310314
if 'pix2pixmodel' in locals():
311315
del pix2pix_model
312316
model_holder.unload_models()
313-
else:
314-
model_holder.swap_to_cpu_memory()
315317

316318
gc.collect()
317319
devices.torch_gc()

scripts/depthmap_generation.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self):
4242
self.pix2pix_model = None
4343
self.depth_model_type = None
4444
self.device = None # Target device, the model may be swapped from VRAM into RAM.
45+
self.offloaded = False # True means current device is not the target device
4546

4647
# Extra stuff
4748
self.resize_mode = None
@@ -53,9 +54,10 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
5354
self.unload_models()
5455
return
5556
# Certain optimisations are irreversible and not device-agnostic, thus changing device requires reloading
56-
if model_type != self.depth_model_type or boost != self.pix2pix_model is not None or device != self.device:
57+
if model_type != self.depth_model_type or boost != (self.pix2pix_model is not None) or device != self.device:
5758
self.unload_models()
5859
self.load_models(model_type, device, boost)
60+
self.reload()
5961

6062
def load_models(self, model_type, device: torch.device, boost: bool):
6163
"""Ensure that the depth model is loaded"""
@@ -236,11 +238,24 @@ def get_default_net_size(model_type):
236238
return sizes[model_type]
237239
return [512, 512]
238240

239-
def swap_to_cpu_memory(self):
241+
def offload(self):
242+
"""Move to RAM to conserve VRAM"""
243+
if self.device != torch.device('cpu') and not self.offloaded:
244+
self.move_models_to(torch.device('cpu'))
245+
self.offloaded = True
246+
247+
def reload(self):
248+
"""Undoes offload"""
249+
if self.offloaded:
250+
self.move_models_to(self.device)
251+
self.offloaded = True
252+
253+
def move_models_to(self, device):
240254
if self.depth_model is not None:
241-
self.depth_model.to(torch.device('cpu'))
255+
self.depth_model.to(device)
242256
if self.pix2pix_model is not None:
243-
self.pix2pix_model.to(torch.device('cpu'))
257+
pass
258+
# TODO: pix2pix offloading not implemented
244259

245260
def unload_models(self):
246261
if self.depth_model is not None or self.pix2pix_model is not None:

scripts/interface_webui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def main_ui_panel(is_depth_tab):
4949
with gr.Group(visible=False) as options_depend_on_boost:
5050
inp += 'match_size', gr.Checkbox(label="Match net size to input size", value=False)
5151
with gr.Row(visible=False) as options_depend_on_match_size:
52-
inp += 'net_width', gr.Slider(minimum=64, maximum=2048, step=64, label='Net width', value=512)
53-
inp += 'net_height', gr.Slider(minimum=64, maximum=2048, step=64, label='Net height', value=512)
52+
inp += 'net_width', gr.Slider(minimum=64, maximum=2048, step=64, label='Net width', value=448)
53+
inp += 'net_height', gr.Slider(minimum=64, maximum=2048, step=64, label='Net height', value=448)
5454

5555
with gr.Group():
5656
with gr.Row():

0 commit comments

Comments (0)