1- from typing import List , Optional , Dict , Tuple
1+ from typing import List , Optional , Dict , Tuple , cast
22
33import torch
44import torchvision
1414
1515
1616@torch .jit .unused
17- def _onnx_get_num_anchors_and_pre_nms_top_n (ob , orig_pre_nms_top_n ):
18- # type: (Tensor, int) -> Tuple[int, int]
17+ def _onnx_get_num_anchors_and_pre_nms_top_n (ob : Tensor , orig_pre_nms_top_n : int ) -> Tuple [int , int ]:
1918 from torch .onnx import operators
2019
2120 num_anchors = operators .shape_as_tensor (ob )[1 ].unsqueeze (0 )
2221 pre_nms_top_n = torch .min (torch .cat ((torch .tensor ([orig_pre_nms_top_n ], dtype = num_anchors .dtype ), num_anchors ), 0 ))
2322
24- return num_anchors , pre_nms_top_n
23+ # for mypy we cast at runtime
24+ return cast (int , num_anchors ), cast (int , pre_nms_top_n )
2525
2626
2727class RPNHead (nn .Module ):
@@ -33,18 +33,17 @@ class RPNHead(nn.Module):
3333 num_anchors (int): number of anchors to be predicted
3434 """
3535
36- def __init__ (self , in_channels , num_anchors ) :
36+ def __init__ (self , in_channels : int , num_anchors : int ) -> None :
3737 super (RPNHead , self ).__init__ ()
3838 self .conv = nn .Conv2d (in_channels , in_channels , kernel_size = 3 , stride = 1 , padding = 1 )
3939 self .cls_logits = nn .Conv2d (in_channels , num_anchors , kernel_size = 1 , stride = 1 )
4040 self .bbox_pred = nn .Conv2d (in_channels , num_anchors * 4 , kernel_size = 1 , stride = 1 )
4141
4242 for layer in self .children ():
43- torch .nn .init .normal_ (layer .weight , std = 0.01 )
44- torch .nn .init .constant_ (layer .bias , 0 )
43+ torch .nn .init .normal_ (layer .weight , std = 0.01 ) # type: ignore[arg-type]
44+ torch .nn .init .constant_ (layer .bias , 0 ) # type: ignore[arg-type]
4545
46- def forward (self , x ):
47- # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
46+ def forward (self , x : List [Tensor ]) -> Tuple [List [Tensor ], List [Tensor ]]:
4847 logits = []
4948 bbox_reg = []
5049 for feature in x :
@@ -54,16 +53,14 @@ def forward(self, x):
5453 return logits , bbox_reg
5554
5655
57- def permute_and_flatten (layer , N , A , C , H , W ):
58- # type: (Tensor, int, int, int, int, int) -> Tensor
56+ def permute_and_flatten (layer : Tensor , N : int , A : int , C : int , H : int , W : int ) -> Tensor :
5957 layer = layer .view (N , - 1 , C , H , W )
6058 layer = layer .permute (0 , 3 , 4 , 1 , 2 )
6159 layer = layer .reshape (N , - 1 , C )
6260 return layer
6361
6462
65- def concat_box_prediction_layers (box_cls , box_regression ):
66- # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
63+ def concat_box_prediction_layers (box_cls : List [Tensor ], box_regression : List [Tensor ]) -> Tuple [Tensor , Tensor ]:
6764 box_cls_flattened = []
6865 box_regression_flattened = []
6966 # for each feature level, permute the outputs to make them be in the
@@ -104,10 +101,10 @@ class RegionProposalNetwork(torch.nn.Module):
104101 for computing the loss
105102 positive_fraction (float): proportion of positive anchors in a mini-batch during training
106103 of the RPN
107- pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. It should
104+ pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
108105 contain two fields: training and testing, to allow for different values depending
109106 on training or evaluation
110- post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should
107+ post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
111108 contain two fields: training and testing, to allow for different values depending
112109 on training or evaluation
113110 nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
@@ -118,25 +115,23 @@ class RegionProposalNetwork(torch.nn.Module):
118115 "box_coder" : det_utils .BoxCoder ,
119116 "proposal_matcher" : det_utils .Matcher ,
120117 "fg_bg_sampler" : det_utils .BalancedPositiveNegativeSampler ,
121- "pre_nms_top_n" : Dict [str , int ],
122- "post_nms_top_n" : Dict [str , int ],
123118 }
124119
125120 def __init__ (
126121 self ,
127- anchor_generator ,
128- head ,
129- #
130- fg_iou_thresh ,
131- bg_iou_thresh ,
132- batch_size_per_image ,
133- positive_fraction ,
134- #
135- pre_nms_top_n ,
136- post_nms_top_n ,
137- nms_thresh ,
138- score_thresh = 0.0 ,
139- ):
122+ anchor_generator : AnchorGenerator ,
123+ head : nn . Module ,
124+ # Faster-RCNN Training
125+ fg_iou_thresh : float ,
126+ bg_iou_thresh : float ,
127+ batch_size_per_image : int ,
128+ positive_fraction : float ,
129+ # Faster-RCNN Inference
130+ pre_nms_top_n : Dict [ str , int ] ,
131+ post_nms_top_n : Dict [ str , int ] ,
132+ nms_thresh : float ,
133+ score_thresh : float = 0.0 ,
134+ ) -> None :
140135 super (RegionProposalNetwork , self ).__init__ ()
141136 self .anchor_generator = anchor_generator
142137 self .head = head
@@ -159,18 +154,20 @@ def __init__(
159154 self .score_thresh = score_thresh
160155 self .min_size = 1e-3
161156
162- def pre_nms_top_n (self ):
157+ def pre_nms_top_n (self ) -> int :
163158 if self .training :
164159 return self ._pre_nms_top_n ["training" ]
165160 return self ._pre_nms_top_n ["testing" ]
166161
167- def post_nms_top_n (self ):
162+ def post_nms_top_n (self ) -> int :
168163 if self .training :
169164 return self ._post_nms_top_n ["training" ]
170165 return self ._post_nms_top_n ["testing" ]
171166
172- def assign_targets_to_anchors (self , anchors , targets ):
173- # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
167+ def assign_targets_to_anchors (
168+ self , anchors : List [Tensor ], targets : List [Dict [str , Tensor ]]
169+ ) -> Tuple [List [Tensor ], List [Tensor ]]:
170+
174171 labels = []
175172 matched_gt_boxes = []
176173 for anchors_per_image , targets_per_image in zip (anchors , targets ):
@@ -205,8 +202,7 @@ def assign_targets_to_anchors(self, anchors, targets):
205202 matched_gt_boxes .append (matched_gt_boxes_per_image )
206203 return labels , matched_gt_boxes
207204
208- def _get_top_n_idx (self , objectness , num_anchors_per_level ):
209- # type: (Tensor, List[int]) -> Tensor
205+ def _get_top_n_idx (self , objectness : Tensor , num_anchors_per_level : List [int ]) -> Tensor :
210206 r = []
211207 offset = 0
212208 for ob in objectness .split (num_anchors_per_level , 1 ):
@@ -220,8 +216,14 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
220216 offset += num_anchors
221217 return torch .cat (r , dim = 1 )
222218
223- def filter_proposals (self , proposals , objectness , image_shapes , num_anchors_per_level ):
224- # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
219+ def filter_proposals (
220+ self ,
221+ proposals : Tensor ,
222+ objectness : Tensor ,
223+ image_shapes : List [Tuple [int , int ]],
224+ num_anchors_per_level : List [int ],
225+ ) -> Tuple [List [Tensor ], List [Tensor ]]:
226+
225227 num_images = proposals .shape [0 ]
226228 device = proposals .device
227229 # do not backprop through objectness
@@ -271,8 +273,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
271273 final_scores .append (scores )
272274 return final_boxes , final_scores
273275
274- def compute_loss (self , objectness , pred_bbox_deltas , labels , regression_targets ):
275- # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
276+ def compute_loss (
277+ self , objectness : Tensor , pred_bbox_deltas : Tensor , labels : List [Tensor ], regression_targets : List [Tensor ]
278+ ) -> Tuple [Tensor , Tensor ]:
276279 """
277280 Args:
278281 objectness (Tensor)
@@ -312,25 +315,25 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
312315
313316 def forward (
314317 self ,
315- images , # type : ImageList
316- features , # type : Dict[str, Tensor]
317- targets = None , # type : Optional[List[Dict[str, Tensor]]]
318- ):
319- # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
318+ images : ImageList ,
319+ features : Dict [str , Tensor ],
320+ targets : Optional [List [Dict [str , Tensor ]]] = None ,
321+ ) -> Tuple [ List [ Tensor ], Dict [ str , Tensor ]] :
322+
320323 """
321324 Args:
322325 images (ImageList): images for which we want to compute the predictions
323- features (OrderedDict[ Tensor]): features computed from the images that are
326+ features (Dict[str, Tensor]): features computed from the images that are
324327 used for computing the predictions. Each tensor in the list
325328 correspond to different feature levels
326- targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional).
329+ targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
327330 If provided, each element in the dict should contain a field `boxes`,
328331 with the locations of the ground-truth boxes.
329332
330333 Returns:
331334 boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
332335 image.
333- losses (Dict[Tensor]): the losses for the model during training. During
336+ losses (Dict[str, Tensor]): the losses for the model during training. During
334337 testing, it is an empty dict.
335338 """
336339 # RPN uses all feature maps that are available
0 commit comments