@@ -25,7 +25,7 @@ class GeneralizedRCNN(nn.Module):
2525 the model
2626 """
2727
28- def __init__ (self , backbone , rpn , roi_heads , transform ) :
28+ def __init__ (self , backbone : nn . Module , rpn : nn . Module , roi_heads : nn . Module , transform : nn . Module ) -> None :
2929 super ().__init__ ()
3030 _log_api_usage_once (self )
3131 self .transform = transform
@@ -36,19 +36,26 @@ def __init__(self, backbone, rpn, roi_heads, transform):
3636 self ._has_warned = False
3737
3838 @torch .jit .unused
39- def eager_outputs (self , losses , detections ):
40- # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
39+ def eager_outputs (
40+ self ,
41+ losses : Dict [str , Tensor ],
42+ detections : List [Dict [str , Tensor ]],
43+ ) -> Union [Dict [str , Tensor ], List [Dict [str , Tensor ]]]:
44+
4145 if self .training :
4246 return losses
4347
4448 return detections
4549
46- def forward (self , images , targets = None ):
47- # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
50+ def forward (
51+ self ,
52+ images : List [Tensor ],
53+ targets : Optional [List [Dict [str , Tensor ]]] = None ,
54+ ) -> Union [Tuple [Dict [str , Tensor ], List [Dict [str , Tensor ]]], Dict [str , Tensor ], List [Dict [str , Tensor ]]]:
4855 """
4956 Args:
5057 images (list[Tensor]): images to be processed
51- targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
58+ targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
5259
5360 Returns:
5461 result (list[BoxList] or dict[Tensor]): the output from the model.
@@ -97,7 +104,7 @@ def forward(self, images, targets=None):
97104 features = OrderedDict ([("0" , features )])
98105 proposals , proposal_losses = self .rpn (images , features , targets )
99106 detections , detector_losses = self .roi_heads (features , proposals , images .image_sizes , targets )
100- detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
107+ detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes ) # type: ignore[operator]
101108
102109 losses = {}
103110 losses .update (detector_losses )
0 commit comments