22import torch
33
44from collections import OrderedDict
5- from torch import Tensor
5+ from torch import Tensor , nn
66from typing import List , Tuple
77
88from torchvision .ops .misc import FrozenBatchNorm2d
99
1010
11- class BalancedPositiveNegativeSampler (object ):
11+ class BalancedPositiveNegativeSampler (nn . Module ):
1212 """
1313 This class samples batches, ensuring that they contain a fixed proportion of positives
1414 """
@@ -20,10 +20,11 @@ def __init__(self, batch_size_per_image, positive_fraction):
2020 batch_size_per_image (int): number of elements to be selected per image
2121 positive_fraction (float): percentace of positive elements per batch
2222 """
23+ super ().__init__ ()
2324 self .batch_size_per_image = batch_size_per_image
2425 self .positive_fraction = positive_fraction
2526
26- def __call__ (self , matched_idxs ):
27+ def forward (self , matched_idxs ):
2728 # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
2829 """
2930 Args:
@@ -126,7 +127,7 @@ def encode_boxes(reference_boxes, proposals, weights):
126127 return targets
127128
128129
129- class BoxCoder (object ):
130+ class BoxCoder (nn . Module ):
130131 """
131132 This class encodes and decodes a set of bounding boxes into
132133 the representation used for training the regressors.
@@ -139,6 +140,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
139140 weights (4-element tuple)
140141 bbox_xform_clip (float)
141142 """
143+ super ().__init__ ()
142144 self .weights = weights
143145 self .bbox_xform_clip = bbox_xform_clip
144146
@@ -228,7 +230,7 @@ def decode_single(self, rel_codes, boxes):
228230 return pred_boxes
229231
230232
231- class Matcher (object ):
233+ class Matcher (nn . Module ):
232234 """
233235 This class assigns to each predicted "element" (e.g., a box) a ground-truth
234236 element. Each predicted element will have exactly zero or one matches; each
@@ -266,14 +268,18 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
266268 for predictions that have only low-quality match candidates. See
267269 set_low_quality_matches_ for more details.
268270 """
271+ super ().__init__ ()
269272 self .BELOW_LOW_THRESHOLD = - 1
270273 self .BETWEEN_THRESHOLDS = - 2
271274 assert low_threshold <= high_threshold
272275 self .high_threshold = high_threshold
273276 self .low_threshold = low_threshold
274277 self .allow_low_quality_matches = allow_low_quality_matches
275278
276- def __call__ (self , match_quality_matrix ):
279+ def forward (self , match_quality_matrix ):
280+ return self ._forward_impl (match_quality_matrix )
281+
282+ def _forward_impl (self , match_quality_matrix ):
277283 """
278284 Args:
279285 match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -354,8 +360,8 @@ class SSDMatcher(Matcher):
354360 def __init__ (self , threshold ):
355361 super ().__init__ (threshold , threshold , allow_low_quality_matches = False )
356362
357- def __call__ (self , match_quality_matrix ):
358- matches = super (). __call__ (match_quality_matrix )
363+ def forward (self , match_quality_matrix ):
364+ matches = self . _forward_impl (match_quality_matrix )
359365
360366 # For each gt, find the prediction with which it has the highest quality
361367 _ , highest_quality_pred_foreach_gt = match_quality_matrix .max (dim = 1 )
0 commit comments