@@ -74,6 +74,8 @@ def load_models(self, model_type, device: torch.device, boost: bool):
7474 model_dir = "./models/midas"
7575 if model_type == 0 :
7676 model_dir = "./models/leres"
if model_type == 10:
    # Marigold keeps its weights in a dedicated directory. The path has to be
    # assigned to model_dir (a bare string expression is a no-op), otherwise
    # the makedirs call below and the Marigold model_path would still use the
    # default directory chosen above.
    model_dir = "./models/marigold"
7779 # create paths to model if not present
7880 os .makedirs (model_dir , exist_ok = True )
7981 os .makedirs ('./models/pix2pix' , exist_ok = True )
@@ -194,6 +196,12 @@ def load_models(self, model_type, device: torch.device, boost: bool):
194196 conf = get_config ("zoedepth_nk" , "infer" )
195197 model = build_model (conf )
196198
199+ elif model_type == 10 : # Marigold v1
200+ # TODO: pass more parameters
201+ model_path = f"{ model_dir } /marigold_v1/"
202+ from repositories .Marigold .src .model .marigold_pipeline import MarigoldPipeline
203+ model = MarigoldPipeline .from_pretrained (model_path )
204+
197205 model .eval () # prepare for evaluation
198206 # optimize
199207 if device == torch .device ("cuda" ) and model_type in [0 , 1 , 2 , 3 , 4 , 5 , 6 ]:
@@ -288,10 +296,12 @@ def get_raw_prediction(self, input, net_width, net_height):
288296 raw_prediction = estimateleres (img , self .depth_model , net_width , net_height )
289297 elif self .depth_model_type in [7 , 8 , 9 ]:
290298 raw_prediction = estimatezoedepth (input , self .depth_model , net_width , net_height )
291- else :
299+ elif self . depth_model_type in [ 1 , 2 , 3 , 4 , 5 , 6 ] :
292300 raw_prediction = estimatemidas (img , self .depth_model , net_width , net_height ,
293301 self .resize_mode , self .normalization , self .no_half ,
294302 self .precision == "autocast" )
303+ elif self .depth_model_type == 10 :
304+ raw_prediction = estimatemarigold (img , self .depth_model , net_width , net_height , self .device )
295305 else :
296306 raw_prediction = estimateboost (img , self .depth_model , self .depth_model_type , self .pix2pix_model ,
297307 self .boost_whole_size_threshold )
@@ -395,6 +405,52 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
395405 return prediction
396406
397407
def estimatemarigold(image, model, w, h, device, *, n_repeat=10, denoise_steps=10,
                     regularizer_strength=0.02, max_iter=5, tol=1e-3,
                     reduction_method="median", merging_max_res=None):
    """Estimate a depth map for a single image with a Marigold pipeline.

    The model is run ``n_repeat`` times on the same input; each raw prediction
    is clipped to [-1, 1] and, when more than one prediction was made, the
    predictions are merged with Marigold's ``ensemble_depths`` utility.

    Args:
        image: numpy array laid out as [H, W, rgb]. It is divided by 255.0
            before inference (presumably 0..255 pixel values - confirm
            against the caller).
        model: object exposing ``unet.eval()`` and
            ``forward(rgb, num_inference_steps=..., init_depth_latent=...)``
            returning a torch tensor.
        w: unused; kept so existing callers can keep passing it.
        h: unused; kept so existing callers can keep passing it.
        device: torch device the input tensor is moved to; also forwarded to
            ``ensemble_depths``.
        n_repeat: number of inference passes to run.
        denoise_steps: forwarded to the model as ``num_inference_steps``.
        regularizer_strength: forwarded to ``ensemble_depths``.
        max_iter: forwarded to ``ensemble_depths``.
        tol: forwarded to ``ensemble_depths``.
        reduction_method: forwarded to ``ensemble_depths`` as ``reduction``.
        merging_max_res: forwarded to ``ensemble_depths`` as ``max_res``.

    Returns:
        The squeezed single prediction when ``n_repeat == 1``; otherwise the
        first element of the ``ensemble_depths`` result.

    Raises:
        ValueError: if ``n_repeat`` is smaller than 1.
    """
    if n_repeat < 1:
        raise ValueError(f"n_repeat must be at least 1, got {n_repeat}")

    if n_repeat > 1:
        # Only the ensembling helper is needed from the Marigold repository,
        # and only when several predictions have to be merged. Importing it
        # up front makes a missing repository fail before the slow inference.
        from repositories.Marigold.src.util.ensemble import ensemble_depths

    # Adapted from the Marigold repository's run.py
    with torch.no_grad():
        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = torch.from_numpy(rgb / 255.0).unsqueeze(0).float()
        rgb_norm = rgb_norm.to(device)

        model.unet.eval()
        depth_pred_ls = []
        for _ in range(n_repeat):
            depth_pred_raw = model.forward(
                rgb_norm, num_inference_steps=denoise_steps, init_depth_latent=None
            )
            # clip prediction
            depth_pred_raw = torch.clip(depth_pred_raw, -1.0, 1.0)
            depth_pred_ls.append(depth_pred_raw.detach().cpu().numpy().copy())

        # Stack the passes along the first axis and drop singleton dimensions.
        depth_preds = np.concatenate(depth_pred_ls, axis=0).squeeze()
        if n_repeat == 1:
            return depth_preds

        # The second element of the result (named pred_uncert upstream) is not used.
        depth_pred, _ = ensemble_depths(
            depth_preds,
            regularizer_strength=regularizer_strength,
            max_iter=max_iter,
            tol=tol,
            reduction=reduction_method,
            max_res=merging_max_res,
            device=device,
        )
        return depth_pred
453+
398454class ImageandPatchs :
399455 def __init__ (self , root_dir , name , patchsinfo , rgb_image , scale = 1 ):
400456 self .root_dir = root_dir
0 commit comments