2020
2121
2222class ObjectDetectionEval (nn .Module ):
23- def forward (
24- self , img : Tensor , target : Optional [Dict [str , Tensor ]] = None
25- ) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
23+ def forward (self , img : Tensor ) -> Tensor :
2624 if not isinstance (img , Tensor ):
2725 img = F .pil_to_tensor (img )
28- return F .convert_image_dtype (img , torch .float ), target
26+ return F .convert_image_dtype (img , torch .float )
2927
3028
3129class ImageClassificationEval (nn .Module ):
@@ -95,28 +93,22 @@ def __init__(
9593 self ._interpolation = interpolation
9694 self ._interpolation_target = interpolation_target
9795
98- def forward (self , img : Tensor , target : Optional [ Tensor ] = None ) -> Tuple [ Tensor , Optional [ Tensor ]] :
96+ def forward (self , img : Tensor ) -> Tensor :
9997 if isinstance (self ._size , list ):
10098 img = F .resize (img , self ._size , interpolation = self ._interpolation )
10199 if not isinstance (img , Tensor ):
102100 img = F .pil_to_tensor (img )
103101 img = F .convert_image_dtype (img , torch .float )
104102 img = F .normalize (img , mean = self ._mean , std = self ._std )
105- if target :
106- if isinstance (self ._size , list ):
107- target = F .resize (target , self ._size , interpolation = self ._interpolation_target )
108- if not isinstance (target , Tensor ):
109- target = F .pil_to_tensor (target )
110- target = target .squeeze (0 ).to (torch .int64 )
111- return img , target
103+ return img
112104
113105
114106class OpticalFlowEval (nn .Module ):
115- def forward (
116- self , img1 : Tensor , img2 : Tensor , flow : Optional [ Tensor ] = None , valid_flow_mask : Optional [ Tensor ] = None
117- ) -> Tuple [ Tensor , Tensor , Optional [ Tensor ], Optional [ Tensor ]]:
118-
119- img1 , img2 , flow , valid_flow_mask = self . _pil_or_numpy_to_tensor ( img1 , img2 , flow , valid_flow_mask )
107+ def forward (self , img1 : Tensor , img2 : Tensor ) -> Tuple [ Tensor , Tensor ]:
108+ if not isinstance ( img1 , Tensor ):
109+ img1 = F . pil_to_tensor ( img1 )
110+ if not isinstance ( img2 , Tensor ):
111+ img2 = F . pil_to_tensor ( img2 )
120112
121113 img1 = F .convert_image_dtype (img1 , torch .float32 )
122114 img2 = F .convert_image_dtype (img2 , torch .float32 )
@@ -128,19 +120,4 @@ def forward(
128120 img1 = img1 .contiguous ()
129121 img2 = img2 .contiguous ()
130122
131- return img1 , img2 , flow , valid_flow_mask
132-
133- def _pil_or_numpy_to_tensor (
134- self , img1 : Tensor , img2 : Tensor , flow : Optional [Tensor ], valid_flow_mask : Optional [Tensor ]
135- ) -> Tuple [Tensor , Tensor , Optional [Tensor ], Optional [Tensor ]]:
136- if not isinstance (img1 , Tensor ):
137- img1 = F .pil_to_tensor (img1 )
138- if not isinstance (img2 , Tensor ):
139- img2 = F .pil_to_tensor (img2 )
140-
141- if flow is not None and not isinstance (flow , Tensor ):
142- flow = torch .from_numpy (flow )
143- if valid_flow_mask is not None and not isinstance (valid_flow_mask , Tensor ):
144- valid_flow_mask = torch .from_numpy (valid_flow_mask )
145-
146- return img1 , img2 , flow , valid_flow_mask
123+ return img1 , img2
0 commit comments