Commit b8dbdc5
Author: Harri Smått
Commit message: Convert from Depth Pro default 1536x1536 implementation to 1024x1024 tensor CoreML programs
Parent: b2cd0d5

4 files changed: 187 additions, 6 deletions


.gitignore (new file, 4 additions)

@@ -0,0 +1,4 @@
+
+checkpoints
+out
+

README.md (2 additions, 0 deletions)

@@ -1,3 +1,5 @@
+# This fork adds a `convert_to_coreml.py` script that converts the original Depth Pro model to CoreML programs
+
 ## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
 
 This software project accompanies the research paper:

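For orientation, the new script is driven exactly like its __main__ block. A minimal usage sketch, assuming the Depth Pro checkpoint has already been downloaded to checkpoints/ as in the upstream setup:

# Hypothetical driver mirroring the script's own __main__ flow; the function
# names come from convert_to_coreml.py shown below.
from convert_to_coreml import create_scaled_model, save_coreml_packages

transform, model, depth = create_scaled_model()   # resizes the patch/image ViT encoders for 1024x1024 input
save_coreml_packages(transform, model, depth)     # writes out/DepthPro_*.mlpackage (out/ is git-ignored above)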
convert_to_coreml.py (new file, 175 additions) — full contents:

import coremltools as ct
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from matplotlib import pyplot as plt
from typing import Tuple

from src.depth_pro.depth_pro import (
    create_model_and_transforms,
    create_backbone_model,
    DEFAULT_MONODEPTH_CONFIG_DICT
)
from src.depth_pro.network.fov import FOVNetwork
from src.depth_pro.network.vit import resize_vit, resize_patch_embed
from src.depth_pro.utils import load_rgb

from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor
)

class DepthProRun(nn.Module):
    def __init__(self, transform: nn.Module, encoder: nn.Module, decoder: nn.Module, depth: nn.Module):
        super().__init__()
        self.transform = transform
        self.encoder = encoder
        self.decoder = decoder
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 3:
            x = x.unsqueeze(0)
        image = self.transform(x)
        encodings = self.encoder(image)
        features, features_0 = self.decoder(encodings)
        depth = self.depth([image, features, features_0])
        return depth

class Depth(nn.Module):
    def __init__(self, head: nn.Module, fov: nn.Module):
        super(Depth, self).__init__()
        self.head = head
        self.fov = fov

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs[0]
        features = inputs[1]
        features_0 = inputs[2]
        _, _, H, W = x.shape
        # using default size 1536 until fov_encoder resizing succeeds
        # 1024 is the expected size to compare against then
        if H != 1536 or W != 1536:
            x = nn.functional.interpolate(
                x,
                size=(1536, 1536),
                mode="bilinear",
                align_corners=False,
            )
        # this is needed until resizing fov_encoder succeeds
        # the current resized size (32, 32) is correct here then
        features_0 = nn.functional.interpolate(
            features_0,
            size=(48, 48),
            mode="bilinear",
            align_corners=False,
        )
        canonical_inverse_depth = self.head(features)
        fov_deg = self.fov.forward(x, features_0.detach())
        f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)
        inverse_depth = canonical_inverse_depth * f_px
        depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
        return depth

class Interpolate(nn.Module):
    def __init__(self, size, mode):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
        return x

def save_mlpackage(G, shapes, name):
    G.eval()
    G_inputs = []
    convert_inputs = []
    for shape in shapes:
        G_inputs.append(torch.randn(shape))
        convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16))
    G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
    G_model = ct.convert(
        G_trace,
        inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
        minimum_deployment_target=ct.target.macOS15,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_AND_NE
    )
    G_model.save("out/" + name + ".mlpackage")

def create_scaled_model() -> Tuple[nn.Module, nn.Module, nn.Module]:
    # from run.py
    model, _ = create_model_and_transforms(
        device=torch.device("cpu"),
        precision=torch.float32,
    )

    new_img_size = (256, 256)
    # resize to 4 x 256 = 1024x1024 input image
    model.encoder.patch_encoder = resize_patch_embed(model.encoder.patch_encoder)
    model.encoder.patch_encoder = resize_vit(model.encoder.patch_encoder, img_size=new_img_size)
    model.encoder.image_encoder = resize_patch_embed(model.encoder.image_encoder)
    model.encoder.image_encoder = resize_vit(model.encoder.image_encoder, img_size=new_img_size)
    model.encoder.out_size = int(
        model.encoder.patch_encoder.patch_embed.img_size[0] // model.encoder.patch_encoder.patch_embed.patch_size[0]
    )

    # resizing fov_encoder to the 1024x1024 size is still work in progress
    # fov_encoder, _ = create_backbone_model(preset = DEFAULT_MONODEPTH_CONFIG_DICT.fov_encoder_preset)
    # fov_encoder = resize_patch_embed(fov_encoder)
    # fov_encoder = resize_vit(fov_encoder, img_size=new_img_size)
    # model.fov = FOVNetwork(num_features=model.decoder.dim_decoder, fov_encoder=fov_encoder)

    # from depth_pro.py
    transform = nn.Sequential(
        #[
        #ToTensor(),
        #Lambda(lambda x: x.to(device)),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        Interpolate(
            size=(model.img_size, model.img_size),
            mode="bilinear"
        ),
        ConvertImageDtype(torch.float32),
        #]
    )

    depth = Depth(model.head, model.fov)
    return transform, model, depth

def load_and_show_example(transform: nn.Module, model: nn.Module, depth: nn.Module):
    image, _, _ = load_rgb("data/example.jpg")
    depth_pro_run = DepthProRun(transform, model.encoder, model.decoder, depth)

    depth_pro = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), depth_pro_run])
    depth_map = depth_pro(image).detach().cpu().numpy().squeeze()

    plt.ion()
    fig = plt.figure()
    ax_rgb = fig.add_subplot(121)
    ax_disp = fig.add_subplot(122)
    ax_rgb.imshow(image)
    ax_disp.imshow(depth_map, cmap="turbo")
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.show(block=True)

def save_coreml_packages(transform: nn.Module, model: nn.Module, depth: nn.Module):
    save_mlpackage(transform, [[1, 3, 1024, 1024]], "DepthPro_transform")
    save_mlpackage(model.encoder, [[1, 3, 1024, 1024]], "DepthPro_encoder")
    save_mlpackage(model.decoder, [[1, 256, 512, 512], [1, 256, 256, 256], [1, 512, 128, 128], [1, 1024, 64, 64], [1, 1024, 32, 32]], "DepthPro_decoder")
    save_mlpackage(depth, [[1, 3, 1024, 1024], [1, 256, 512, 512], [1, 256, 32, 32]], "DepthPro_depth")

if __name__ == "__main__":
    transform, model, depth = create_scaled_model()
    load_and_show_example(transform, model, depth)
    save_coreml_packages(transform, model, depth)
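For context, the four exported packages are meant to be chained in the same order as DepthProRun.forward (transform → encoder → decoder → depth). Below is a minimal sketch, not taken from this commit, of loading the first two back with coremltools on a macOS 15 host; feature names are auto-generated during tracing and conversion, so they are read from each package's spec rather than hard-coded:

import coremltools as ct
import numpy as np

# Load two of the packages written by save_mlpackage() above.
transform = ct.models.MLModel("out/DepthPro_transform.mlpackage")
encoder = ct.models.MLModel("out/DepthPro_encoder.mlpackage")

# Input/output names are assigned automatically by torch.jit.trace / ct.convert,
# so look them up instead of assuming them.
t_in = transform.get_spec().description.input[0].name
t_out = transform.get_spec().description.output[0].name

rgb = np.random.rand(1, 3, 1024, 1024).astype(np.float16)  # placeholder image tensor
image = transform.predict({t_in: rgb})[t_out]

e_in = encoder.get_spec().description.input[0].name
encodings = encoder.predict({e_in: image})  # dict of the encoder's output feature maps
# The decoder and depth packages are chained the same way, mirroring DepthProRun.forward.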

src/depth_pro/network/encoder.py (6 additions, 6 deletions)

@@ -169,7 +169,7 @@ def _create_pyramid(
 
     def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
         """Split the input into small patches with sliding window."""
-        patch_size = 384
+        patch_size = 256
         patch_stride = int(patch_size * (1 - overlap_ratio))
 
         image_size = x.shape[-1]
@@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
             self.out_size,
         )
         x_latent0_features = self.merge(
-            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
         )
 
         x_latent1_encodings = self.reshape_feature(
@@ -285,21 +285,21 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
             self.out_size,
         )
         x_latent1_features = self.merge(
-            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
         )
 
         # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
         x0_encodings, x1_encodings, x2_encodings = torch.split(
             x_pyramid_encodings,
-            [len(x0_patches), len(x1_patches), len(x2_patches)],
+            [x0_patches.shape[0], x1_patches.shape[0], x2_patches.shape[0]],
             dim=0,
         )
 
         # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
-        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
+        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=2)
 
         # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
-        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
+        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=4)
 
         # 24x24 feature maps.
         x2_features = x2_encodings
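The padding changes follow the patch-size change: assuming the backbone's 16-pixel patch embedding, the per-patch token grid shrinks from 384 / 16 = 24 to 256 / 16 = 16 tokens, so the merge paddings shrink by the same 2/3 factor. A quick sanity-check sketch of that scaling (illustrative arithmetic, not code from this commit):

# Rescale the merge paddings by the token-grid ratio (24 -> 16, i.e. 256/384 = 2/3).
old_patch, new_patch = 384, 256
for old_padding in (3, 6):
    new_padding = old_padding * new_patch // old_patch
    print(f"padding {old_padding} -> {new_padding}")  # prints 3 -> 2 and 6 -> 4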
