@@ -74,8 +74,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
7474 model_dir = "./models/midas"
7575 if model_type == 0 :
7676 model_dir = "./models/leres"
77- if model_type == 10 :
78- "./models/marigold"
7977 # create paths to model if not present
8078 os .makedirs (model_dir , exist_ok = True )
8179 os .makedirs ('./models/pix2pix' , exist_ok = True )
@@ -197,9 +195,9 @@ def load_models(self, model_type, device: torch.device, boost: bool):
197195 model = build_model (conf )
198196
199197 elif model_type == 10 : # Marigold v1
200- # TODO: pass more parameters
201- model_path = f" { model_dir } /marigold_v1/"
202- from repositories . Marigold .src .model .marigold_pipeline import MarigoldPipeline
198+ model_path = "Bingxin/Marigold"
199+ print ( model_path )
200+ from Marigold .src .model .marigold_pipeline import MarigoldPipeline
203201 model = MarigoldPipeline .from_pretrained (model_path )
204202
205203 model .eval () # prepare for evaluation
@@ -301,11 +299,11 @@ def get_raw_prediction(self, input, net_width, net_height):
301299 self .resize_mode , self .normalization , self .no_half ,
302300 self .precision == "autocast" )
303301 elif self .depth_model_type == 10 :
304- raw_prediction = estimatemarigold (img , self .depth_model , net_width , net_height , self . device )
302+ raw_prediction = estimatemarigold (img , self .depth_model , net_width , net_height )
305303 else :
306304 raw_prediction = estimateboost (img , self .depth_model , self .depth_model_type , self .pix2pix_model ,
307305 self .boost_whole_size_threshold )
308- raw_prediction_invert = self .depth_model_type in [0 , 7 , 8 , 9 ]
306+ raw_prediction_invert = self .depth_model_type in [0 , 7 , 8 , 9 , 10 ]
309307 return raw_prediction , raw_prediction_invert
310308
311309
@@ -405,11 +403,11 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
405403 return prediction
406404
407405
408- def estimatemarigold (image , model , w , h , device ):
409- from repositories . Marigold .src .model .marigold_pipeline import MarigoldPipeline
410- from repositories . Marigold .src .util .ensemble import ensemble_depths
411- from repositories . Marigold .src .util .image_util import chw2hwc , colorize_depth_maps , resize_max_res
412- from repositories . Marigold .src .util .seed_all import seed_all
406+ def estimatemarigold (image , model , w , h ):
407+ from Marigold .src .model .marigold_pipeline import MarigoldPipeline
408+ from Marigold .src .util .ensemble import ensemble_depths
409+ from Marigold .src .util .image_util import chw2hwc , colorize_depth_maps , resize_max_res
410+ from Marigold .src .util .seed_all import seed_all
413411
414412 n_repeat = 10
415413 denoise_steps = 10
@@ -418,13 +416,18 @@ def estimatemarigold(image, model, w, h, device):
418416 tol = 1e-3
419417 reduction_method = "median"
420418 merging_max_res = None
419+ resize_to_max_res = None
421420
422421 # From Marigold repository run.py
423422 with torch .no_grad ():
424- rgb = np .transpose (image , (2 , 0 , 1 )) # [H, W, rgb] -> [rgb, H, W]
425- rgb_norm = rgb / 255.0
423+ if resize_to_max_res is not None :
424+ image = (image * 255 ).astype (np .uint8 )
425+ image = np .asarray (resize_max_res (
426+ Image .fromarray (image ), max_edge_resolution = resize_to_max_res
427+ )) / 255.0
428+ rgb_norm = np .transpose (image , (2 , 0 , 1 )) # [H, W, rgb] -> [rgb, H, W]
426429 rgb_norm = torch .from_numpy (rgb_norm ).unsqueeze (0 ).float ()
427- rgb_norm = rgb_norm .to (device )
430+ rgb_norm = rgb_norm .to (depthmap_device )
428431
429432 model .unet .eval ()
430433 depth_pred_ls = []
@@ -445,7 +448,7 @@ def estimatemarigold(image, model, w, h, device):
445448 tol = tol ,
446449 reduction = reduction_method ,
447450 max_res = merging_max_res ,
448- device = device ,
451+ device = depthmap_device ,
449452 )
450453 else :
451454 depth_pred = depth_preds
@@ -942,6 +945,8 @@ def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel
942945def singleestimate (img , msize , model , net_type ):
943946 if net_type == 0 :
944947 return estimateleres (img , model , msize , msize )
948+ elif net_type == 10 :
949+ return estimatemarigold (img , model , msize , msize )
945950 elif net_type >= 7 :
946951 # np to PIL
947952 return estimatezoedepth (Image .fromarray (np .uint8 (img * 255 )).convert ('RGB' ), model , msize , msize )
0 commit comments