Commit aa18879

Author: Harri Smått (committed)
Convert from Depth Pro default 1536x1536 implementation to 768x768 tensor CoreML packages

1 parent b2cd0d5 commit aa18879

File tree (7 files changed: +329 -5 lines changed)

    .gitignore
    README.md
    convert_to_coreml.py
    src/depth_pro/network/decoder.py
    src/depth_pro/network/encoder.py
    src/depth_pro/network/vit.py
    src/depth_pro/network/vit_factory.py

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+
+checkpoints
+out
+

README.md

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# This fork adds a `convert_to_coreml.py` script that converts the original Depth Pro model to CoreML programs
+
 ## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second

 This software project accompanies the research paper:
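
A minimal usage sketch (not part of the commit), assuming the checkpoint sits at ./checkpoints/depth_pro.pt as the script's DepthProConfig expects; it mirrors the script's __main__ entry point:

    # Run from the repository root, e.g. `python convert_to_coreml.py`, or:
    from convert_to_coreml import create_scaled_model, save_coreml_packages

    model = create_scaled_model()    # builds the 768x768 pipeline from the checkpoint
    save_coreml_packages(model)      # writes DepthPro_*.mlpackage bundles under ./out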

convert_to_coreml.py

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
import logging
import math
import numpy as np

import coremltools as ct
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import upsample_bilinear2d
from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op
# Helpers referenced by the cat() override further below; import paths assume coremltools 8.x.
from coremltools.converters.mil.frontend.torch.ops import _get_inputs, _get_kwinputs
from coremltools.converters.mil.frontend.torch.utils import TorchFrontend
from coremltools.converters.mil.frontend._utils import promote_input_dtypes
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.var import Var

import torch
from torch import nn
from torch.nn import functional as F

from matplotlib import pyplot as plt
from typing import Tuple

from src.depth_pro.depth_pro import (
    create_model_and_transforms,
    create_backbone_model,
    DepthProConfig
)
from src.depth_pro.network.decoder import MultiresConvDecoder
from src.depth_pro.network.encoder import DepthProEncoder
from src.depth_pro.network.fov import FOVNetwork
from src.depth_pro.network.vit import resize_vit, resize_patch_embed
from src.depth_pro.utils import load_rgb

from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor
)

"""
example.jpg fov_deg =
default 1536x1536: 48.4297
scaled 1024x1024: 49.8382
"""

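# The scaled pipeline is rebuilt as separately traceable modules (transform,
# DepthProEncoder, MultiresConvDecoder, DepthDecoder) so that save_coreml_packages()
# below can export each stage as its own CoreML package.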
class DepthDecoder(nn.Module):
    def __init__(self, head: nn.Module, fov: FOVNetwork):
        super(DepthDecoder, self).__init__()
        self.head = head
        self.fov = fov

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs[0]
        features = inputs[1]
        features_0 = inputs[2]

        # execute fov.forward locally with a different scale_factor
        # fov_deg = self.fov.forward(x, features_0.detach())
        if hasattr(self.fov, "encoder"):
            x = F.interpolate(
                x,
                size=None,
                # result size needs to be 384
                scale_factor=0.25,
                mode="bilinear",
                align_corners=False,
            )
            x = self.fov.encoder(x)[:, 1:].permute(0, 2, 1)
            lowres_feature = self.fov.downsample(features_0.detach())
            x = x.reshape_as(lowres_feature) + lowres_feature
        else:
            x = features_0.detach()

        fov_deg = self.fov.head(x)
        f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)

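        # math.pi * fov_deg / 360 is half the predicted field of view in radians; its
        # tangent yields the FOV-dependent factor that rescales the canonical inverse
        # depth before it is inverted into depth below.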
        canonical_inverse_depth = self.head(features)
        inverse_depth = canonical_inverse_depth * f_px
        depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
        return depth

class DepthProScaled(nn.Module):
    def __init__(self, transform: nn.Module, encoder: DepthProEncoder, decoder: MultiresConvDecoder, depth: DepthDecoder):
        super().__init__()
        self.transform = transform
        self.encoder = encoder
        self.decoder = decoder
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 3:
            x = x.unsqueeze(0)
        image = self.transform(x)
        encodings = self.encoder(image)
        features, features_0 = self.decoder(encodings)
        depth = self.depth([image, features, features_0])
        return depth

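# Interpolate wraps F.interpolate in an nn.Module so the resize can be used inside the
# nn.Sequential transform built below.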
class Interpolate(nn.Module):
    def __init__(self, size, mode):
        super(Interpolate, self).__init__()
        self.size = size
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=self.size, mode=self.mode, align_corners=False)
        return x

def create_scaled_model() -> DepthProScaled:
    config = DepthProConfig(
        patch_encoder_preset="dinov2l16_192",
        image_encoder_preset="dinov2l16_192",
        checkpoint_uri="./checkpoints/depth_pro.pt",
        decoder_features=256,
        use_fov_head=True,
        fov_encoder_preset="dinov2l16_192",
    )

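    # "dinov2l16_192" is the preset added in vit_factory.py: DINOv2-L/16 resampled to
    # 192 px patches, so DepthProEncoder.img_size becomes 4 * 192 = 768 and the network
    # runs at the 768x768 resolution this commit targets.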
    patch_encoder, patch_encoder_config = create_backbone_model(preset = config.patch_encoder_preset)
    image_encoder, _ = create_backbone_model(preset = config.image_encoder_preset)
    fov_encoder, _ = create_backbone_model(preset = config.fov_encoder_preset)
    #fov_encoder = None

    dims_encoder = patch_encoder_config.encoder_feature_dims
    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
    encoder = DepthProEncoder(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        hook_block_ids=hook_block_ids,
        decoder_features=config.decoder_features,
    )

    decoder = MultiresConvDecoder(
        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
        dim_decoder=config.decoder_features,
    )

    num_features = config.decoder_features
    fov = FOVNetwork(num_features=num_features, fov_encoder=fov_encoder)
    # Create FOV head.
    fov_head0 = [
        nn.Conv2d(
            num_features, num_features // 2, kernel_size=3, stride=2, padding=3
        ),  # 128 x 24 x 24
        nn.ReLU(True),
    ]
    fov_head = [
        nn.Conv2d(
            num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=3
        ),  # 64 x 12 x 12
        nn.ReLU(True),
        nn.Conv2d(
            num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=3
        ),  # 32 x 6 x 6
        nn.ReLU(True),
        nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
    ]
    if fov_encoder is None:
        fov_head = fov_head0 + fov_head
    fov.head = nn.Sequential(*fov_head)
    #fov = None

    last_dims = (32, 1)
    dim_decoder = config.decoder_features
    head = nn.Sequential(
        nn.Conv2d(
            dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
        ),
        nn.ConvTranspose2d(
            in_channels=dim_decoder // 2,
            out_channels=dim_decoder // 2,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        ),
        nn.Conv2d(
            dim_decoder // 2,
            last_dims[0],
            kernel_size=3,
            stride=1,
            padding=1,
        ),
        nn.ReLU(True),
        nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
        nn.ReLU(),
    )

    # Set the final convolution layer's bias to be 0.
    head[4].bias.data.fill_(0)

    # from depth_pro.py
    transform = nn.Sequential(
        #[
        #ToTensor(),
        #Lambda(lambda x: x.to(device)),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        #Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),
        Interpolate(
            size=(encoder.img_size, encoder.img_size),
            mode="bilinear"
        ),
        ConvertImageDtype(torch.float32),
        #]
    )

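    # Keeping the preprocessing as an nn.Sequential (rather than the torchvision Compose
    # used upstream) lets it be traced on its own and exported with an RGB ImageType
    # input; save_coreml_packages converts it for 1080x1920 frames.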
    depth = DepthDecoder(head, fov)
    load_state_dict(depth, config)

    model = DepthProScaled(transform, encoder, decoder, depth)
    load_state_dict(model, config)

    return model

def load_state_dict(model: nn.Module, config: DepthProConfig):
    checkpoint_uri = config.checkpoint_uri
    state_dict = torch.load(checkpoint_uri, map_location="cpu")
    _, _ = model.load_state_dict(
        state_dict=state_dict, strict=False
    )

def load_and_show_example(model: DepthProScaled):
    image, _, _ = load_rgb("data/example.jpg")
    model_run = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), model])
    depth_map = model_run(image).detach().cpu().numpy().squeeze()

    plt.ion()
    fig = plt.figure()
    ax_rgb = fig.add_subplot(121)
    ax_disp = fig.add_subplot(122)
    ax_rgb.imshow(image)
    ax_disp.imshow(depth_map, cmap="turbo")
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.show(block=True)

def save_coreml_packages(model: DepthProScaled):
    save_mlpackage(model.transform, [[1, 3, 1080, 1920]], "DepthPro_transform", True)
    save_mlpackage(model.encoder, [[1, 3, 768, 768]], "DepthPro_encoder")
    save_mlpackage(model.decoder, [[1, 256, 288, 288], [1, 256, 144, 144], [1, 512, 72, 72], [1, 1024, 24, 24], [1, 1024, 24, 24]], "DepthPro_decoder")
    save_mlpackage(model.depth, [[1, 3, 768, 768], [1, 256, 288, 288], [1, 256, 24, 24]], "DepthPro_depth")
    save_mlpackage(model.depth.head, [[1, 256, 768, 768]], "DepthPro_head")

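# The two @register_torch_op handlers below patch gaps in the coremltools torch
# frontend that this model hits: _upsample_bicubic2d_aa (emitted e.g. by timm's
# antialiased position-embedding resampling) is approximated with the existing
# bilinear lowering, and torch cat is overridden with the fix from the coremltools
# pull request linked below.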
@register_torch_op()
def _upsample_bicubic2d_aa(context, node):
    upsample_bilinear2d(context, node)

# https://github.com/apple/coremltools/pull/2354 CoreMLTools 8.0 fix
@register_torch_op(torch_alias=["concat"], override=True)
def cat(context, node):
    def is_tensor_empty(var: Var) -> bool:
        return np.any([size == 0 for size in var.shape])

    def _parse_positional_args(context, node) -> Tuple[Var]:
        inputs = _get_inputs(context, node, min_expected=1)
        nargs = len(inputs)

        xs = inputs[0]
        # PyTorch can have empty tensor, which is then ignored
        # However, CoreML does not allow such empty tensor, so remove them now
        if np.any([is_tensor_empty(x) for x in xs]):
            filtered_xs = [x for x in xs if not is_tensor_empty(x)]
            xs = filtered_xs if len(filtered_xs) > 0 else [xs[0]]

        dim = inputs[1] if nargs > 1 else 0

        return xs, dim

    def _parse_keyword_args(context, node, dim) -> Var:
        # Only torch.export may have kwargs
        if context.frontend != TorchFrontend.TORCHEXPORT:
            return dim

        dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
        return dim

    xs, dim = _parse_positional_args(context, node)
    dim = _parse_keyword_args(context, node, dim)

    concat = mb.concat(values=promote_input_dtypes(xs), axis=dim, name=node.name)
    context.add(concat)

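# Multi-input modules (decoder, depth) take a single list argument, so their example
# tensors are wrapped in one list for tracing and conversion; everything is converted
# in FLOAT32 for the CPU + Neural Engine compute units, targeting macOS 15.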
def save_mlpackage(G, shapes, name, image_type = False):
    G.eval()
    G_inputs = []
    convert_inputs = []
    for shape in shapes:
        G_inputs.append(torch.randn(shape))
        convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float32) if image_type == False else ct.ImageType(shape=shape, color_layout=ct.colorlayout.RGB))
    G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
    G_model = ct.convert(
        G_trace,
        inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
        minimum_deployment_target=ct.target.macOS15,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_NE
    )
    G_model.save("out/" + name + ".mlpackage")

if __name__ == "__main__":
    model = create_scaled_model()
    load_and_show_example(model)
    save_coreml_packages(model)
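
A minimal sketch (not part of the commit) of loading one of the exported packages back from Python with coremltools; the input name depends on the traced graph, so it is read from the package spec rather than assumed:

    import coremltools as ct
    import numpy as np

    encoder = ct.models.MLModel("out/DepthPro_encoder.mlpackage")
    in_name = encoder.get_spec().description.input[0].name
    x = np.random.rand(1, 3, 768, 768).astype(np.float32)
    outputs = encoder.predict({in_name: x})
    print({k: v.shape for k, v in outputs.items()})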

src/depth_pro/network/decoder.py

Lines changed: 4 additions & 0 deletions
@@ -169,6 +169,10 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:

         if x1 is not None:
             res = self.resnet1(x1)
+            _, _, Wx, Hx = x.shape
+            _, _, Wres, Hres = res.shape
+            if Wx != Wres or Hx != Hres:
+                x = nn.functional.interpolate(x, size=(Wres, Hres), mode="bilinear", align_corners=False)
             x = self.skip_add.add(x, res)

         x = self.resnet2(x)
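
The guard only resizes the incoming skip tensor when it disagrees spatially with the residual, which can happen once the pyramid is driven at the new resolution. A standalone restatement of the pattern (illustrative, not part of the commit):

    import torch
    import torch.nn.functional as F

    def add_aligned(x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        # Match the skip tensor to the residual's spatial size before the elementwise add.
        if x.shape[-2:] != res.shape[-2:]:
            x = F.interpolate(x, size=res.shape[-2:], mode="bilinear", align_corners=False)
        return x + res

    print(add_aligned(torch.rand(1, 8, 48, 48), torch.rand(1, 8, 47, 47)).shape)  # torch.Size([1, 8, 47, 47])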

src/depth_pro/network/encoder.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def _create_pyramid(

     def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
         """Split the input into small patches with sliding window."""
-        patch_size = 384
+        patch_size = self.patch_encoder.patch_embed.img_size[0]
         patch_stride = int(patch_size * (1 - overlap_ratio))

         image_size = x.shape[-1]
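
With the hard-coded 384 replaced by the patch encoder's own input size, the sliding-window split scales with the backbone. A quick check of the geometry, assuming the 192 px preset and the 768 px input used elsewhere in this commit:

    patch_size = 192                                   # patch_embed.img_size[0] for dinov2l16_192
    patch_stride = int(patch_size * (1 - 0.25))        # 144, i.e. the same 25% overlap as upstream
    windows_per_axis = (768 - patch_size) // patch_stride + 1
    print(patch_stride, windows_per_axis)              # 144 5, matching the stock 1536/384 layout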

src/depth_pro/network/vit.py

Lines changed: 3 additions & 3 deletions
@@ -48,12 +48,12 @@ def forward_features_eva_fixed(self, x):
     return x


-def resize_vit(model: nn.Module, img_size) -> nn.Module:
+def resize_vit(model: nn.Module, img_size, grid_size) -> nn.Module:
     """Resample the ViT module to the given size."""
     patch_size = model.patch_embed.patch_size
     model.patch_embed.img_size = img_size
-    grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
-    model.patch_embed.grid_size = grid_size
+    # grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
+    # model.patch_embed.grid_size = grid_size

     pos_embed = resample_abs_pos_embed(
         model.pos_embed,

src/depth_pro/network/vit_factory.py

Lines changed: 15 additions & 1 deletion
@@ -37,6 +37,7 @@ class ViTConfig:

     img_size: int = 384
     patch_size: int = 16
+    grid_size: int = 24

     # In case we need to rescale the backbone when loading from timm.
     timm_preset: Optional[str] = None

@@ -51,13 +52,26 @@ class ViTConfig:


 VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
+    "dinov2l16_192": ViTConfig(
+        in_chans=3,
+        embed_dim=1024,
+        encoder_feature_layer_ids=[5, 11, 17, 23],
+        encoder_feature_dims=[256, 512, 1024, 1024],
+        img_size=192,
+        patch_size=16,
+        grid_size=24,
+        timm_preset="vit_large_patch14_dinov2",
+        timm_img_size=518,
+        timm_patch_size=14,
+    ),
     "dinov2l16_384": ViTConfig(
         in_chans=3,
         embed_dim=1024,
         encoder_feature_layer_ids=[5, 11, 17, 23],
         encoder_feature_dims=[256, 512, 1024, 1024],
         img_size=384,
         patch_size=16,
+        grid_size=24,
         timm_preset="vit_large_patch14_dinov2",
         timm_img_size=518,
         timm_patch_size=14,

@@ -107,7 +121,7 @@ def create_vit(
     if config.patch_size != config.timm_patch_size:
         model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
     if config.img_size != config.timm_img_size:
-        model.model = resize_vit(model.model, img_size=img_size)
+        model.model = resize_vit(model.model, img_size=img_size, grid_size=(config.grid_size, config.grid_size))

     if checkpoint_uri is not None:
         state_dict = torch.load(checkpoint_uri, map_location="cpu")
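
A small sketch (not part of the commit) of exercising the new preset through the existing factory, assuming the fork's src/ layout and a downloaded checkpoint:

    from src.depth_pro.depth_pro import create_backbone_model

    # Returns the resampled DINOv2-L/16 backbone together with its ViTConfig.
    patch_encoder, vit_config = create_backbone_model(preset="dinov2l16_192")
    print(vit_config.img_size, vit_config.patch_size, vit_config.grid_size)  # 192 16 24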
