33from typing import List , Tuple
44
55import torch
6- from torch import Tensor
6+ from torch import Tensor , nn
77from torchvision .ops .misc import FrozenBatchNorm2d
88
99
@@ -12,18 +12,16 @@ class BalancedPositiveNegativeSampler(object):
1212 This class samples batches, ensuring that they contain a fixed proportion of positives
1313 """
1414
15- def __init__ (self , batch_size_per_image , positive_fraction ):
16- # type: (int, float) -> None
15+ def __init__ (self , batch_size_per_image : int , positive_fraction : float ) -> None :
1716 """
1817 Args:
1918 batch_size_per_image (int): number of elements to be selected per image
20- positive_fraction (float): percentace of positive elements per batch
19+ positive_fraction (float): percentage of positive elements per batch
2120 """
2221 self .batch_size_per_image = batch_size_per_image
2322 self .positive_fraction = positive_fraction
2423
25- def __call__ (self , matched_idxs ):
26- # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
24+ def __call__ (self , matched_idxs : List [Tensor ]) -> Tuple [List [Tensor ], List [Tensor ]]:
2725 """
2826 Args:
2927 matched idxs: list of tensors containing -1, 0 or positive values.
@@ -73,8 +71,7 @@ def __call__(self, matched_idxs):
7371
7472
7573@torch .jit ._script_if_tracing
76- def encode_boxes (reference_boxes , proposals , weights ):
77- # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
74+ def encode_boxes (reference_boxes : Tensor , proposals : Tensor , weights : Tensor ) -> Tensor :
7875 """
7976 Encode a set of proposals with respect to some
8077 reference boxes
@@ -127,8 +124,9 @@ class BoxCoder(object):
127124 the representation used for training the regressors.
128125 """
129126
130- def __init__ (self , weights , bbox_xform_clip = math .log (1000.0 / 16 )):
131- # type: (Tuple[float, float, float, float], float) -> None
127+ def __init__ (
128+ self , weights : Tuple [float , float , float , float ], bbox_xform_clip : float = math .log (1000.0 / 16 )
129+ ) -> None :
132130 """
133131 Args:
134132 weights (4-element tuple)
@@ -137,15 +135,14 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
137135 self .weights = weights
138136 self .bbox_xform_clip = bbox_xform_clip
139137
140- def encode (self , reference_boxes , proposals ):
141- # type: (List[Tensor], List[Tensor]) -> List[Tensor]
138+ def encode (self , reference_boxes : List [Tensor ], proposals : List [Tensor ]) -> List [Tensor ]:
142139 boxes_per_image = [len (b ) for b in reference_boxes ]
143140 reference_boxes = torch .cat (reference_boxes , dim = 0 )
144141 proposals = torch .cat (proposals , dim = 0 )
145142 targets = self .encode_single (reference_boxes , proposals )
146143 return targets .split (boxes_per_image , 0 )
147144
148- def encode_single (self , reference_boxes , proposals ) :
145+ def encode_single (self , reference_boxes : Tensor , proposals : Tensor ) -> Tensor :
149146 """
150147 Encode a set of proposals with respect to some
151148 reference boxes
@@ -161,8 +158,7 @@ def encode_single(self, reference_boxes, proposals):
161158
162159 return targets
163160
164- def decode (self , rel_codes , boxes ):
165- # type: (Tensor, List[Tensor]) -> Tensor
161+ def decode (self , rel_codes : Tensor , boxes : List [Tensor ]) -> Tensor :
166162 assert isinstance (boxes , (list , tuple ))
167163 assert isinstance (rel_codes , torch .Tensor )
168164 boxes_per_image = [b .size (0 ) for b in boxes ]
@@ -177,7 +173,7 @@ def decode(self, rel_codes, boxes):
177173 pred_boxes = pred_boxes .reshape (box_sum , - 1 , 4 )
178174 return pred_boxes
179175
180- def decode_single (self , rel_codes , boxes ) :
176+ def decode_single (self , rel_codes : Tensor , boxes : Tensor ) -> Tensor :
181177 """
182178 From a set of original boxes and encoded relative box offsets,
183179 get the decoded boxes.
@@ -244,8 +240,7 @@ class Matcher(object):
244240 "BETWEEN_THRESHOLDS" : int ,
245241 }
246242
247- def __init__ (self , high_threshold , low_threshold , allow_low_quality_matches = False ):
248- # type: (float, float, bool) -> None
243+ def __init__ (self , high_threshold : float , low_threshold : float , allow_low_quality_matches : bool = False ) -> None :
249244 """
250245 Args:
251246 high_threshold (float): quality values greater than or equal to
@@ -266,7 +261,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
266261 self .low_threshold = low_threshold
267262 self .allow_low_quality_matches = allow_low_quality_matches
268263
269- def __call__ (self , match_quality_matrix ) :
264+ def __call__ (self , match_quality_matrix : Tensor ) -> Tensor :
270265 """
271266 Args:
272267 match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -290,7 +285,7 @@ def __call__(self, match_quality_matrix):
290285 if self .allow_low_quality_matches :
291286 all_matches = matches .clone ()
292287 else :
293- all_matches = None
288+ all_matches = None # type: ignore[assignment]
294289
295290 # Assign candidate matches with low quality to negative (unassigned) values
296291 below_low_threshold = matched_vals < self .low_threshold
@@ -304,7 +299,7 @@ def __call__(self, match_quality_matrix):
304299
305300 return matches
306301
307- def set_low_quality_matches_ (self , matches , all_matches , match_quality_matrix ) :
302+ def set_low_quality_matches_ (self , matches : Tensor , all_matches : Tensor , match_quality_matrix : Tensor ) -> None :
308303 """
309304 Produce additional matches for predictions that have only low-quality matches.
310305 Specifically, for each ground-truth find the set of predictions that have
@@ -335,10 +330,10 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
335330
336331
337332class SSDMatcher (Matcher ):
338- def __init__ (self , threshold ) :
333+ def __init__ (self , threshold : float ) -> None :
339334 super ().__init__ (threshold , threshold , allow_low_quality_matches = False )
340335
341- def __call__ (self , match_quality_matrix ) :
336+ def __call__ (self , match_quality_matrix : Tensor ) -> Tensor :
342337 matches = super ().__call__ (match_quality_matrix )
343338
344339 # For each gt, find the prediction with which it has the highest quality
@@ -350,7 +345,7 @@ def __call__(self, match_quality_matrix):
350345 return matches
351346
352347
353- def overwrite_eps (model , eps ) :
348+ def overwrite_eps (model : nn . Module , eps : float ) -> None :
354349 """
355350 This method overwrites the default eps values of all the
356351 FrozenBatchNorm2d layers of the model with the provided value.
@@ -368,7 +363,7 @@ def overwrite_eps(model, eps):
368363 module .eps = eps
369364
370365
371- def retrieve_out_channels (model , size ) :
366+ def retrieve_out_channels (model : nn . Module , size : Tuple [ int , int ]) -> List [ int ] :
372367 """
373368 This method retrieves the number of output channels of a specific model.
374369
0 commit comments