From 4850b4eaeac88bb54901a1bad776cb10930bcac1 Mon Sep 17 00:00:00 2001 From: YASH <139299779+Yash-2707@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:30:42 +0530 Subject: [PATCH] Update run.py Enhanced error handling and logging Performance improvements Added a --batch-size argument for potential batch processing (not implemented in the inference loop, but prepared for future enhancement). Updated matplotlib figure creation for better layout. and added colorbar to depth map visualization Added file extension check to process only image files. --- src/depth_pro/cli/run.py | 174 +++++++++++++++++---------------------- 1 file changed, 74 insertions(+), 100 deletions(-) diff --git a/src/depth_pro/cli/run.py b/src/depth_pro/cli/run.py index 3545a999..309cedab 100755 --- a/src/depth_pro/cli/run.py +++ b/src/depth_pro/cli/run.py @@ -4,10 +4,10 @@ Copyright (C) 2024 Apple Inc. All Rights Reserved. """ - import argparse import logging from pathlib import Path +import time import numpy as np import PIL.Image @@ -19,136 +19,110 @@ LOGGER = logging.getLogger(__name__) - def get_torch_device() -> torch.device: """Get the Torch device.""" - device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda:0") + return torch.device("cuda:0") elif torch.backends.mps.is_available(): - device = torch.device("mps") - return device - + return torch.device("mps") + return torch.device("cpu") def run(args): - """Run Depth Pro on a sample image.""" - if args.verbose: - logging.basicConfig(level=logging.INFO) - - # Load model. - model, transform = create_model_and_transforms( - device=get_torch_device(), - precision=torch.half, - ) + """Run Depth Pro on a sample image or directory of images.""" + logging.basicConfig(level=logging.INFO if args.verbose else logging.WARNING) + + device = get_torch_device() + LOGGER.info(f"Using device: {device}") + + # Load model + start_time = time.time() + model, transform = create_model_and_transforms(device=device, precision=torch.half) model.eval() + LOGGER.info(f"Model loaded in {time.time() - start_time:.2f} seconds") - image_paths = [args.image_path] - if args.image_path.is_dir(): - image_paths = args.image_path.glob("**/*") - relative_path = args.image_path - else: - relative_path = args.image_path.parent + image_paths = list(args.image_path.glob("**/*")) if args.image_path.is_dir() else [args.image_path] + relative_path = args.image_path if args.image_path.is_dir() else args.image_path.parent if not args.skip_display: plt.ion() - fig = plt.figure() - ax_rgb = fig.add_subplot(121) - ax_disp = fig.add_subplot(122) + fig, (ax_rgb, ax_disp) = plt.subplots(1, 2, figsize=(12, 6)) + + total_inference_time = 0 + for image_path in tqdm(image_paths, desc="Processing images"): + if not image_path.is_file() or image_path.suffix.lower() not in {'.jpg', '.jpeg', '.png', '.bmp'}: + continue - for image_path in tqdm(image_paths): - # Load image and focal length from exif info (if found.). try: LOGGER.info(f"Loading image {image_path} ...") image, _, f_px = load_rgb(image_path) except Exception as e: - LOGGER.error(str(e)) + LOGGER.error(f"Error loading {image_path}: {str(e)}") continue - # Run prediction. If `f_px` is provided, it is used to estimate the final metric depth, - # otherwise the model estimates `f_px` to compute the depth metricness. - prediction = model.infer(transform(image), f_px=f_px) - # Extract the depth and focal length. - depth = prediction["depth"].detach().cpu().numpy().squeeze() + # Run prediction + start_time = time.time() + with torch.no_grad(): + prediction = model.infer(transform(image), f_px=f_px) + inference_time = time.time() - start_time + total_inference_time += inference_time + + depth = prediction["depth"].cpu().numpy().squeeze() if f_px is not None: - LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}") + LOGGER.info(f"Focal length (from exif): {f_px:0.2f}") elif prediction["focallength_px"] is not None: - focallength_px = prediction["focallength_px"].detach().cpu().item() - LOGGER.info(f"Estimated focal length: {focallength_px}") - - inverse_depth = 1 / depth - # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization. - max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1) - min_invdepth_vizu = max(1 / 250, inverse_depth.min()) - inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / ( - max_invdepth_vizu - min_invdepth_vizu - ) - - # Save Depth as npz file. - if args.output_path is not None: - output_file = ( - args.output_path - / image_path.relative_to(relative_path).parent - / image_path.stem - ) - LOGGER.info(f"Saving depth map to: {str(output_file)}") + focallength_px = prediction["focallength_px"].cpu().item() + LOGGER.info(f"Estimated focal length: {focallength_px:0.2f}") + + inverse_depth = 1 / np.clip(depth, 0.1, 250) + inverse_depth_normalized = (inverse_depth - inverse_depth.min()) / (inverse_depth.max() - inverse_depth.min()) + + if args.output_path: + output_file = args.output_path / image_path.relative_to(relative_path).with_suffix('') output_file.parent.mkdir(parents=True, exist_ok=True) + + # Save depth as NPZ np.savez_compressed(output_file, depth=depth) + LOGGER.info(f"Depth map saved to: {output_file}.npz") + + # Save color-mapped depth as JPG + cmap = plt.get_cmap('turbo') + color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8) + color_map_output_file = output_file.with_suffix('.jpg') + PIL.Image.fromarray(color_depth).save(color_map_output_file, format="JPEG", quality=95) + LOGGER.info(f"Color-mapped depth saved to: {color_map_output_file}") - # Save as color-mapped "turbo" jpg image. - cmap = plt.get_cmap("turbo") - color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype( - np.uint8 - ) - color_map_output_file = str(output_file) + ".jpg" - LOGGER.info(f"Saving color-mapped depth to: : {color_map_output_file}") - PIL.Image.fromarray(color_depth).save( - color_map_output_file, format="JPEG", quality=90 - ) - - # Display the image and estimated depth map. if not args.skip_display: + ax_rgb.clear() + ax_disp.clear() ax_rgb.imshow(image) - ax_disp.imshow(inverse_depth_normalized, cmap="turbo") - fig.canvas.draw() - fig.canvas.flush_events() + ax_rgb.set_title("Input Image") + ax_rgb.axis('off') + im = ax_disp.imshow(inverse_depth_normalized, cmap="turbo") + ax_disp.set_title(f"Depth Map (Inference time: {inference_time:.3f}s)") + ax_disp.axis('off') + plt.colorbar(im, ax=ax_disp, label="Inverse Depth", orientation="horizontal", pad=0.05) + fig.tight_layout() + plt.pause(0.1) + + avg_inference_time = total_inference_time / len(image_paths) + LOGGER.info(f"Average inference time: {avg_inference_time:.3f} seconds per image") + LOGGER.info("Depth estimation completed!") - LOGGER.info("Done predicting depth!") if not args.skip_display: - plt.show(block=True) - + plt.ioff() + plt.show() def main(): """Run DepthPro inference example.""" - parser = argparse.ArgumentParser( - description="Inference scripts of DepthPro with PyTorch models." - ) - parser.add_argument( - "-i", - "--image-path", - type=Path, - default="./data/example.jpg", - help="Path to input image.", - ) - parser.add_argument( - "-o", - "--output-path", - type=Path, - help="Path to store output files.", - ) - parser.add_argument( - "--skip-display", - action="store_true", - help="Skip matplotlib display.", - ) - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Show verbose output." - ) + parser = argparse.ArgumentParser(description="Inference script for DepthPro with PyTorch models.") + parser.add_argument("-i", "--image-path", type=Path, required=True, help="Path to input image or directory.") + parser.add_argument("-o", "--output-path", type=Path, help="Path to store output files.") + parser.add_argument("--skip-display", action="store_true", help="Skip matplotlib display.") + parser.add_argument("-v", "--verbose", action="store_true", help="Show verbose output.") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for inference (default: 1)") - run(parser.parse_args()) - + args = parser.parse_args() + run(args) if __name__ == "__main__": main()