Skip to content

Commit f35d5ae

Browse files
committed
Post-refactor fixes
1 parent 72517c4 commit f35d5ae

File tree

5 files changed

+28
-23
lines changed

5 files changed

+28
-23
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ video by [@graemeniedermayer](https://github.com/graemeniedermayer), more exampl
2121
images generated by [@semjon00](https://github.com/semjon00) from CC0 photos, more examples [here](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/56#issuecomment-1367596463).
2222

2323
## Changelog
24-
* v0.3.13
25-
* Large code refactor
26-
* Improved interface
27-
* Slightly changed the behaviour of various options
24+
* v0.4.0: large code refactor
25+
* UI improvements
26+
* slightly changed the behaviour of various options
27+
* the extension may still partially work even if some of its dependencies are unmet
2828
* v0.3.12
2929
* Fixed stereo image generation
3030
* Other bugfixes

scripts/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def core_generation_funnel(outpath, inputimages, inputdepthmaps, inputnames, inp
230230

231231
if show_heat:
232232
from dzoedepth.utils.misc import colorize
233-
heatmap = colorize(img_output, cmap='inferno')
233+
heatmap = Image.fromarray(colorize(img_output, cmap='inferno'))
234234
generated_images[count]['heatmap'] = heatmap
235235

236236
if gen_stereo:

scripts/depthmap_generation.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from PIL import Image
44
from torchvision.transforms import Compose, transforms
55

6+
# TODO: depthmap_generation should not depend on WebUI
67
from modules import shared, devices
78
from modules.shared import opts, cmd_opts
89

@@ -29,7 +30,6 @@
2930
from pix2pix.options.test_options import TestOptions
3031
from pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
3132

32-
3333
# zoedepth
3434
from dzoedepth.models.builder import build_model
3535
from dzoedepth.utils.config import get_config
@@ -59,9 +59,6 @@ def ensure_models(self, model_type, device: torch.device, boost: bool):
5959

6060
def load_models(self, model_type, device: torch.device, boost: bool):
6161
"""Ensure that the depth model is loaded"""
62-
# TODO: supply correct values for zoedepth
63-
net_width = 512
64-
net_height = 512
6562

6663
# model path and name
6764
model_dir = "./models/midas"
@@ -171,22 +168,21 @@ def load_models(self, model_type, device: torch.device, boost: bool):
171168
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
172169
)
173170

171+
# When loading, zoedepth models will report the default net size.
172+
# It will be overridden by the generation settings.
174173
elif model_type == 7: # zoedepth_n
175174
print("zoedepth_n\n")
176175
conf = get_config("zoedepth", "infer")
177-
conf.img_size = [net_width, net_height]
178176
model = build_model(conf)
179177

180178
elif model_type == 8: # zoedepth_k
181179
print("zoedepth_k\n")
182180
conf = get_config("zoedepth", "infer", config_version="kitti")
183-
conf.img_size = [net_width, net_height]
184181
model = build_model(conf)
185182

186183
elif model_type == 9: # zoedepth_nk
187184
print("zoedepth_nk\n")
188185
conf = get_config("zoedepth_nk", "infer")
189-
conf.img_size = [net_width, net_height]
190186
model = build_model(conf)
191187

192188
model.eval() # prepare for evaluation
@@ -224,12 +220,16 @@ def load_models(self, model_type, device: torch.device, boost: bool):
224220
def get_default_net_size(self, model_type):
225221
# TODO: fill in, use in the GUI
226222
sizes = {
223+
0: [448, 448],
227224
1: [512, 512],
228225
2: [384, 384],
229226
3: [384, 384],
230227
4: [384, 384],
231228
5: [384, 384],
232229
6: [256, 256],
230+
7: [384, 512],
231+
8: [384, 768],
232+
9: [384, 512]
233233
}
234234
if model_type in sizes:
235235
return sizes[model_type]
@@ -254,8 +254,9 @@ def unload_models(self):
254254
self.device = None
255255

256256
def get_raw_prediction(self, input, net_width, net_height):
257-
"""Get prediction from the model currently loaded by the class.
257+
"""Get prediction from the model currently loaded by the ModelHolder object.
258258
If boost is enabled, net_width and net_height will be ignored."""
259+
# TODO: supply net size for zoedepth
259260
global device
260261
device = self.device
261262
# input image
@@ -264,17 +265,14 @@ def get_raw_prediction(self, input, net_width, net_height):
264265
if self.pix2pix_model is None:
265266
if self.depth_model_type == 0:
266267
raw_prediction = estimateleres(img, self.depth_model, net_width, net_height)
267-
raw_prediction_invert = True
268268
elif self.depth_model_type in [7, 8, 9]:
269269
raw_prediction = estimatezoedepth(input, self.depth_model, net_width, net_height)
270-
raw_prediction_invert = True
271270
else:
272271
raw_prediction = estimatemidas(img, self.depth_model, net_width, net_height,
273272
self.resize_mode, self.normalization)
274-
raw_prediction_invert = False
275273
else:
276274
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model)
277-
raw_prediction_invert = False
275+
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9]
278276
return raw_prediction, raw_prediction_invert
279277

280278

scripts/interface_webui.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,19 @@ def main_ui_panel(is_depth_tab):
102102
with gr.Group():
103103
with gr.Row():
104104
inp += "gen_mesh", gr.Checkbox(
105-
label="Generate simple 3D mesh. "
106-
"(Fast, accurate only with ZoeDepth models and no boost, no custom maps)",
107-
value=False, visible=True)
105+
label="Generate simple 3D mesh", value=False, visible=True)
108106
with gr.Row(visible=False) as mesh_options_row_0:
107+
gr.Label(value="Generates fast, accurate only with ZoeDepth models and no boost, no custom maps")
109108
inp += "mesh_occlude", gr.Checkbox(label="Remove occluded edges", value=True, visible=True)
110109
inp += "mesh_spherical", gr.Checkbox(label="Equirectangular projection", value=False, visible=True)
111110

112111
if is_depth_tab:
113112
with gr.Group():
114113
with gr.Row():
115114
inp += "inpaint", gr.Checkbox(
116-
label="Generate 3D inpainted mesh. (Sloooow, required for generating videos)", value=False)
115+
label="Generate 3D inpainted mesh", value=False)
117116
with gr.Group(visible=False) as inpaint_options_row_0:
117+
gr.Label("Generation is sloooow, required for generating videos")
118118
inp += "inpaint_vids", gr.Checkbox(
119119
label="Generate 4 demo videos with 3D inpainted mesh.", value=False)
120120
gr.HTML("More options for generating video can be found in the Generate video tab")

scripts/stereoimage_generation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from numba import njit, prange
1+
try:
2+
from numba import njit, prange
3+
except Exception as e:
4+
print(f"WARINING! Numba failed to import! Stereoimage generation will be much slower! ({str(e)})")
5+
from builtins import range as prange
6+
def njit(parallel=False):
7+
def Inner(func): return lambda *args, **kwargs: func(*args, **kwargs)
8+
return Inner
29
import numpy as np
310
from PIL import Image
411

@@ -73,7 +80,7 @@ def apply_stereo_divergence(original_image, depth, divergence, separation, fill_
7380
)
7481

7582

76-
@njit
83+
@njit(parallel=False)
7784
def apply_stereo_divergence_naive(
7885
original_image, normalized_depth, divergence_px: float, separation_px: float, fill_technique):
7986
h, w, c = original_image.shape

0 commit comments

Comments
 (0)