import logging
import math
import numpy as np

import coremltools as ct
# Internal coremltools helpers reused by the `cat` override below. They come from
# the coremltools torch frontend; the import paths are an assumption based on
# coremltools 8.x and may need adjusting for other versions.
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import Var
from coremltools.converters.mil.frontend._utils import promote_input_dtypes
from coremltools.converters.mil.frontend.torch.ops import (
    _get_inputs,
    _get_kwinputs,
    upsample_bilinear2d,
)
from coremltools.converters.mil.frontend.torch.utils import TorchFrontend
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)

import torch
from torch import nn
from torch.nn import functional as F

from matplotlib import pyplot as plt
from typing import List, Tuple

from src.depth_pro.depth_pro import (
    create_model_and_transforms,
    create_backbone_model,
    DepthProConfig
)
from src.depth_pro.network.decoder import MultiresConvDecoder
from src.depth_pro.network.encoder import DepthProEncoder
from src.depth_pro.network.fov import FOVNetwork
from src.depth_pro.network.vit import resize_vit, resize_patch_embed
from src.depth_pro.utils import load_rgb

from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor
)

"""
Measured fov_deg for data/example.jpg:
  default 1536x1536: 48.4297
  scaled  1024x1024: 49.8382
"""

class DepthDecoder(nn.Module):
    def __init__(self, head: nn.Module, fov: FOVNetwork):
        super(DepthDecoder, self).__init__()
        self.head = head
        self.fov = fov

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        x = inputs[0]
        features = inputs[1]
        features_0 = inputs[2]

        # fov.forward is inlined here so the interpolation scale_factor stays under our control
        # fov_deg = self.fov.forward(x, features_0.detach())
        if hasattr(self.fov, "encoder"):
            x = F.interpolate(
                x,
                size=None,
                # 0.25 x the 768-px transformed image -> 192 px, matching the
                # dinov2l16_192 FOV encoder (the default 1536-px model needs 384 here)
                scale_factor=0.25,
                mode="bilinear",
                align_corners=False,
            )
            x = self.fov.encoder(x)[:, 1:].permute(0, 2, 1)
            lowres_feature = self.fov.downsample(features_0.detach())
            x = x.reshape_as(lowres_feature) + lowres_feature
        else:
            x = features_0.detach()

        fov_deg = self.fov.head(x)
        f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)

        canonical_inverse_depth = self.head(features)
        inverse_depth = canonical_inverse_depth * f_px
        depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
        return depth

class DepthProScaled(nn.Module):
    def __init__(self, transform: nn.Module, encoder: DepthProEncoder, decoder: MultiresConvDecoder, depth: DepthDecoder):
        super().__init__()
        self.transform = transform
        self.encoder = encoder
        self.decoder = decoder
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 3:
            x = x.unsqueeze(0)
        image = self.transform(x)
        encodings = self.encoder(image)
        features, features_0 = self.decoder(encodings)
        depth = self.depth([image, features, features_0])
        return depth

class Interpolate(nn.Module):
    def __init__(self, size, mode):
        super(Interpolate, self).__init__()
        self.size = size
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=self.size, mode=self.mode, align_corners=False)
        return x

def create_scaled_model() -> DepthProScaled:
    config = DepthProConfig(
        patch_encoder_preset="dinov2l16_192",
        image_encoder_preset="dinov2l16_192",
        checkpoint_uri="./checkpoints/depth_pro.pt",
        decoder_features=256,
        use_fov_head=True,
        fov_encoder_preset="dinov2l16_192",
    )

    patch_encoder, patch_encoder_config = create_backbone_model(preset=config.patch_encoder_preset)
    image_encoder, _ = create_backbone_model(preset=config.image_encoder_preset)
    fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
    #fov_encoder = None

    dims_encoder = patch_encoder_config.encoder_feature_dims
    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
    encoder = DepthProEncoder(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        hook_block_ids=hook_block_ids,
        decoder_features=config.decoder_features,
    )

    decoder = MultiresConvDecoder(
        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
        dim_decoder=config.decoder_features,
    )

    num_features = config.decoder_features
    fov = FOVNetwork(num_features=num_features, fov_encoder=fov_encoder)
    # Create FOV head.
    fov_head0 = [
        nn.Conv2d(
            num_features, num_features // 2, kernel_size=3, stride=2, padding=3
        ),  # 128 x 24 x 24
        nn.ReLU(True),
    ]
    fov_head = [
        nn.Conv2d(
            num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=3
        ),  # 64 x 12 x 12
        nn.ReLU(True),
        nn.Conv2d(
            num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=3
        ),  # 32 x 6 x 6
        nn.ReLU(True),
        nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
    ]
    if fov_encoder is None:
        fov_head = fov_head0 + fov_head
    fov.head = nn.Sequential(*fov_head)
    #fov = None

    last_dims = (32, 1)
    dim_decoder = config.decoder_features
    head = nn.Sequential(
        nn.Conv2d(
            dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
        ),
        nn.ConvTranspose2d(
            in_channels=dim_decoder // 2,
            out_channels=dim_decoder // 2,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        ),
        nn.Conv2d(
            dim_decoder // 2,
            last_dims[0],
            kernel_size=3,
            stride=1,
            padding=1,
        ),
        nn.ReLU(True),
        nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
        nn.ReLU(),
    )

    # Set the final convolution layer's bias to be 0.
    head[4].bias.data.fill_(0)

    # Adapted from the Compose in depth_pro.py. ToTensor() and the device Lambda are
    # not nn.Modules, so they are applied outside the model (see load_and_show_example).
    transform = nn.Sequential(
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        #Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),  # alternative scaling for 0-255 inputs
        Interpolate(
            size=(encoder.img_size, encoder.img_size),
            mode="bilinear"
        ),
        ConvertImageDtype(torch.float32),
    )

    depth = DepthDecoder(head, fov)
    load_state_dict(depth, config)

    model = DepthProScaled(transform, encoder, decoder, depth)
    load_state_dict(model, config)

    return model

def load_state_dict(model: nn.Module, config: DepthProConfig):
    checkpoint_uri = config.checkpoint_uri
    state_dict = torch.load(checkpoint_uri, map_location="cpu")
    model.load_state_dict(state_dict=state_dict, strict=False)

def load_and_show_example(model: DepthProScaled):
    image, _, _ = load_rgb("data/example.jpg")
    model_run = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), model])
    depth_map = model_run(image).detach().cpu().numpy().squeeze()

    plt.ion()
    fig = plt.figure()
    ax_rgb = fig.add_subplot(121)
    ax_disp = fig.add_subplot(122)
    ax_rgb.imshow(image)
    ax_disp.imshow(depth_map, cmap="turbo")
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.show(block=True)

def save_coreml_packages(model: DepthProScaled):
    save_mlpackage(model.transform, [[1, 3, 1080, 1920]], "DepthPro_transform", True)
    save_mlpackage(model.encoder, [[1, 3, 768, 768]], "DepthPro_encoder")
    save_mlpackage(model.decoder, [[1, 256, 288, 288], [1, 256, 144, 144], [1, 512, 72, 72], [1, 1024, 24, 24], [1, 1024, 24, 24]], "DepthPro_decoder")
    save_mlpackage(model.depth, [[1, 3, 768, 768], [1, 256, 288, 288], [1, 256, 24, 24]], "DepthPro_depth")
    save_mlpackage(model.depth.head, [[1, 256, 768, 768]], "DepthPro_head")

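# Optional helper (a hypothetical sketch, not used by this script): derive the
# hard-coded shapes above from a dummy forward pass, so the export shapes stay in
# sync with encoder.img_size if the backbone presets change.
def infer_package_shapes(model: DepthProScaled, hw=(1080, 1920)) -> dict:
    x = torch.randn(1, 3, *hw)
    with torch.no_grad():
        image = model.transform(x)
        encodings = model.encoder(image)
        features, features_0 = model.decoder(encodings)
    return {
        "DepthPro_transform": [list(x.shape)],
        "DepthPro_encoder": [list(image.shape)],
        "DepthPro_decoder": [list(e.shape) for e in encodings],
        "DepthPro_depth": [list(image.shape), list(features.shape), list(features_0.shape)],
    }
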
# Map torch's anti-aliased bicubic upsample, which has no lowering in the coremltools
# torch frontend used here, onto the bilinear lowering as an approximation.
@register_torch_op()
def _upsample_bicubic2d_aa(context, node):
    upsample_bilinear2d(context, node)

# `cat` override from https://github.com/apple/coremltools/pull/2354 (coremltools 8.0 fix).
@register_torch_op(torch_alias=["concat"], override=True)
def cat(context, node):
    def is_tensor_empty(var: Var) -> bool:
        return np.any([size == 0 for size in var.shape])

    def _parse_positional_args(context, node) -> Tuple[Var]:
        inputs = _get_inputs(context, node, min_expected=1)
        nargs = len(inputs)

        xs = inputs[0]
        # PyTorch allows empty tensors in a concat and simply ignores them.
        # Core ML does not, so filter them out here.
        if np.any([is_tensor_empty(x) for x in xs]):
            filtered_xs = [x for x in xs if not is_tensor_empty(x)]
            xs = filtered_xs if len(filtered_xs) > 0 else [xs[0]]

        dim = inputs[1] if nargs > 1 else 0

        return xs, dim

    def _parse_keyword_args(context, node, dim) -> Var:
        # Only torch.export may have kwargs
        if context.frontend != TorchFrontend.TORCHEXPORT:
            return dim

        dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
        return dim

    xs, dim = _parse_positional_args(context, node)
    dim = _parse_keyword_args(context, node, dim)

    concat = mb.concat(values=promote_input_dtypes(xs), axis=dim, name=node.name)
    context.add(concat)

def save_mlpackage(G, shapes, name, image_type=False):
    G.eval()
    G_inputs = []
    convert_inputs = []
    for shape in shapes:
        G_inputs.append(torch.randn(shape))
        convert_inputs.append(
            ct.ImageType(shape=shape, color_layout=ct.colorlayout.RGB)
            if image_type
            else ct.TensorType(shape=shape, dtype=np.float32)
        )
    # Modules with several feature-map inputs take them as a single list argument,
    # hence the extra level of nesting when more than one shape is given.
    G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
    G_model = ct.convert(
        G_trace,
        inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
        minimum_deployment_target=ct.target.macOS15,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
    )
    G_model.save("out/" + name + ".mlpackage")

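# Optional sanity check (a minimal sketch, not part of the conversion flow): load a
# saved package and run random tensors through it with coremltools. Input names are
# read from the generated spec rather than guessed; shapes must be passed in the same
# order as the converted inputs. This covers the TensorType packages only; the
# DepthPro_transform package uses an ImageType input and would need a PIL.Image.
def verify_mlpackage(name, shapes):
    mlmodel = ct.models.MLModel("out/" + name + ".mlpackage")
    input_names = [inp.name for inp in mlmodel.get_spec().description.input]
    feed = {n: np.random.rand(*s).astype(np.float32) for n, s in zip(input_names, shapes)}
    outputs = mlmodel.predict(feed)
    for key, value in outputs.items():
        print(name, key, getattr(value, "shape", None))
    # e.g. verify_mlpackage("DepthPro_encoder", [[1, 3, 768, 768]])
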
if __name__ == "__main__":
    model = create_scaled_model()
    load_and_show_example(model)
    save_coreml_packages(model)