99import warnings
1010from collections import OrderedDict
1111from dataclasses import dataclass , field
12- from typing import Any , Dict , List , Optional , Union
12+ from typing import Any , Dict , List , Optional , Sequence , Union
1313
1414import numpy as np
1515import torch
16+ import torch .nn .functional as F
1617from pytorch3d .implicitron .dataset .implicitron_dataset import FrameData
1718from pytorch3d .implicitron .dataset .utils import is_known_frame , is_train_frame
19+ from pytorch3d .implicitron .models .base_model import ImplicitronRender
1820from pytorch3d .implicitron .tools import vis_utils
1921from pytorch3d .implicitron .tools .camera_utils import volumetric_camera_overlaps
2022from pytorch3d .implicitron .tools .image_utils import mask_background
3133EVAL_N_SRC_VIEWS = [1 , 3 , 5 , 7 , 9 ]
3234
3335
34- @dataclass
35- class NewViewSynthesisPrediction :
36- """
37- Holds the tensors that describe a result of synthesizing new views.
38- """
39-
40- depth_render : Optional [torch .Tensor ] = None
41- image_render : Optional [torch .Tensor ] = None
42- mask_render : Optional [torch .Tensor ] = None
43- camera_distance : Optional [torch .Tensor ] = None
44-
45-
4636@dataclass
4737class _Visualizer :
4838 image_render : torch .Tensor
@@ -145,8 +135,8 @@ def show_depth(
145135
146136def eval_batch (
147137 frame_data : FrameData ,
148- nvs_prediction : NewViewSynthesisPrediction ,
149- bg_color : Union [torch .Tensor , str , float ] = "black" ,
138+ implicitron_render : ImplicitronRender ,
139+ bg_color : Union [torch .Tensor , Sequence , str , float ] = "black" ,
150140 mask_thr : float = 0.5 ,
151141 lpips_model = None ,
152142 visualize : bool = False ,
@@ -162,14 +152,14 @@ def eval_batch(
162152 is True), a new-view synthesis method (NVS) is tasked to generate new views
163153 of the scene from the viewpoint of the target views (for which
164154 frame_data.frame_type.endswith('known') is False). The resulting
165- synthesized new views, stored in `nvs_prediction `, are compared to the
155+ synthesized new views, stored in `implicitron_render`, are compared to the
166156 target ground truth in `frame_data` in terms of geometry and appearance
167157 resulting in a dictionary of metrics returned by the `eval_batch` function.
168158
169159 Args:
170160 frame_data: A FrameData object containing the input to the new view
171161 synthesis method.
172- nvs_prediction : The data describing the synthesized new views.
162+ implicitron_render: The data describing the synthesized new views.
173163 bg_color: The background color of the generated new views and the
174164 ground truth.
175165 lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@ def eval_batch(
184174 ValueError if frame_data does not have frame_type, camera, or image_rgb
185175 ValueError if the batch has a mix of training and test samples
186176 ValueError if the batch frames are not [unseen, known, known, ...]
187- ValueError if one of the required fields in nvs_prediction is missing
177+ ValueError if one of the required fields in implicitron_render is missing
188178 """
189- REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render" , "image_render" , "depth_render" ]
190179 frame_type = frame_data .frame_type
191180 if frame_type is None :
192181 raise ValueError ("Frame type has not been set." )
193182
194183 # we check that all those fields are not None but Pyre can't infer that properly
195- # TODO: assign to local variables
184+ # TODO: assign to local variables and simplify the code.
196185 if frame_data .image_rgb is None :
197186 raise ValueError ("Image is not in the evaluation batch." )
198187
199188 if frame_data .camera is None :
200189 raise ValueError ("Camera is not in the evaluation batch." )
201190
202- if any (not hasattr (nvs_prediction , k ) for k in REQUIRED_NVS_PREDICTION_FIELDS ):
203- raise ValueError ("One of the required predicted fields is missing" )
191+ # eval all results in the resolution of the frame_data image
192+ image_resol = tuple (frame_data .image_rgb .shape [2 :])
193+
194+ # Post-process the render:
195+ # 1) check implicitron_render for Nones,
196+ # 2) obtain copies to make sure we don't edit the original data,
197+ # 3) take only the 1st (target) image,
198+ # 4) resize to match the ground-truth resolution.
199+ cloned_render : Dict [str , torch .Tensor ] = {}
200+ for k in ["mask_render" , "image_render" , "depth_render" ]:
201+ field = getattr (implicitron_render , k )
202+ if field is None :
203+ raise ValueError (f"A required predicted field { k } is missing" )
204+
205+ imode = "bilinear" if k == "image_render" else "nearest"
206+ cloned_render [k ] = (
207+ F .interpolate (field [:1 ], size = image_resol , mode = imode ).detach ().clone ()
208+ )
204209
205- # obtain copies to make sure we dont edit the original data
206- nvs_prediction = copy .deepcopy (nvs_prediction )
207210 frame_data = copy .deepcopy (frame_data )
208211
209212 # mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@ def eval_batch(
226229 + " a target view while the rest should be source views."
227230 ) # TODO: do we need to enforce this?
228231
229- # take only the first (target image)
230- for k in REQUIRED_NVS_PREDICTION_FIELDS :
231- setattr (nvs_prediction , k , getattr (nvs_prediction , k )[:1 ])
232232 for k in [
233233 "depth_map" ,
234234 "image_rgb" ,
@@ -242,10 +242,6 @@ def eval_batch(
242242 if frame_data .depth_map is None or frame_data .depth_map .sum () <= 0 :
243243 warnings .warn ("Empty or missing depth map in evaluation!" )
244244
245- # eval all results in the resolution of the frame_data image
246- # pyre-fixme[16]: `Optional` has no attribute `shape`.
247- image_resol = list (frame_data .image_rgb .shape [2 :])
248-
249245 # threshold the masks to make ground truth binary masks
250246 mask_fg , mask_crop = [
251247 (getattr (frame_data , k ) >= mask_thr ) for k in ("fg_probability" , "mask_crop" )
@@ -258,29 +254,14 @@ def eval_batch(
258254 bg_color = bg_color ,
259255 )
260256
261- # resize to the target resolution
262- for k in REQUIRED_NVS_PREDICTION_FIELDS :
263- imode = "bilinear" if k == "image_render" else "nearest"
264- val = getattr (nvs_prediction , k )
265- setattr (
266- nvs_prediction ,
267- k ,
268- # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
269- # `List[typing.Any]`.
270- torch .nn .functional .interpolate (val , size = image_resol , mode = imode ),
271- )
272-
273257 # clamp predicted images
274- # pyre-fixme[16]: `Optional` has no attribute `clamp`.
275- image_render = nvs_prediction .image_render .clamp (0.0 , 1.0 )
258+ image_render = cloned_render ["image_render" ].clamp (0.0 , 1.0 )
276259
277260 if visualize :
278261 visualizer = _Visualizer (
279262 image_render = image_render ,
280263 image_rgb_masked = image_rgb_masked ,
281- # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
282- # `Optional[torch.Tensor]`.
283- depth_render = nvs_prediction .depth_render ,
264+ depth_render = cloned_render ["depth_render" ],
284265 # pyre-fixme[6]: Expected `Tensor` for 4th param but got
285266 # `Optional[torch.Tensor]`.
286267 depth_map = frame_data .depth_map ,
@@ -292,9 +273,7 @@ def eval_batch(
292273 results : Dict [str , Any ] = {}
293274
294275 results ["iou" ] = iou (
295- # pyre-fixme[6]: Expected `Tensor` for 1st param but got
296- # `Optional[torch.Tensor]`.
297- nvs_prediction .mask_render ,
276+ cloned_render ["mask_render" ],
298277 mask_fg ,
299278 mask = mask_crop ,
300279 )
@@ -321,11 +300,7 @@ def eval_batch(
321300 if name_postfix == "_fg" :
322301 # only record depth metrics for the foreground
323302 _ , abs_ = eval_depth (
324- # pyre-fixme[6]: Expected `Tensor` for 1st param but got
325- # `Optional[torch.Tensor]`.
326- nvs_prediction .depth_render ,
327- # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
328- # `Optional[torch.Tensor]`.
303+ cloned_render ["depth_render" ],
329304 frame_data .depth_map ,
330305 get_best_scale = True ,
331306 mask = loss_mask_now ,
@@ -343,7 +318,7 @@ def eval_batch(
343318 if lpips_model is not None :
344319 im1 , im2 = [
345320 2.0 * im .clamp (0.0 , 1.0 ) - 1.0
346- for im in (image_rgb_masked , nvs_prediction . image_render )
321+ for im in (image_rgb_masked , cloned_render [ " image_render" ] )
347322 ]
348323 results ["lpips" ] = lpips_model .forward (im1 , im2 ).item ()
349324
0 commit comments