Skip to content

Commit beab59f

Browse files
committed
Rewrite objects to modules.
1 parent 1cb85ab commit beab59f

File tree

6 files changed

+19
-32
lines changed

6 files changed

+19
-32
lines changed

torchvision/models/detection/_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import torch
33

44
from collections import OrderedDict
5-
from torch import Tensor
5+
from torch import Tensor, nn
66
from typing import List, Tuple
77

88
from torchvision.ops.misc import FrozenBatchNorm2d
99

1010

11-
class BalancedPositiveNegativeSampler(object):
11+
class BalancedPositiveNegativeSampler(nn.Module):
1212
"""
1313
This class samples batches, ensuring that they contain a fixed proportion of positives
1414
"""
@@ -20,10 +20,11 @@ def __init__(self, batch_size_per_image, positive_fraction):
2020
batch_size_per_image (int): number of elements to be selected per image
2121
positive_fraction (float): percentage of positive elements per batch
2222
"""
23+
super().__init__()
2324
self.batch_size_per_image = batch_size_per_image
2425
self.positive_fraction = positive_fraction
2526

26-
def __call__(self, matched_idxs):
27+
def forward(self, matched_idxs):
2728
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
2829
"""
2930
Args:
@@ -126,7 +127,7 @@ def encode_boxes(reference_boxes, proposals, weights):
126127
return targets
127128

128129

129-
class BoxCoder(object):
130+
class BoxCoder(nn.Module):
130131
"""
131132
This class encodes and decodes a set of bounding boxes into
132133
the representation used for training the regressors.
@@ -139,6 +140,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
139140
weights (4-element tuple)
140141
bbox_xform_clip (float)
141142
"""
143+
super().__init__()
142144
self.weights = weights
143145
self.bbox_xform_clip = bbox_xform_clip
144146

@@ -228,7 +230,7 @@ def decode_single(self, rel_codes, boxes):
228230
return pred_boxes
229231

230232

231-
class Matcher(object):
233+
class Matcher(nn.Module):
232234
"""
233235
This class assigns to each predicted "element" (e.g., a box) a ground-truth
234236
element. Each predicted element will have exactly zero or one matches; each
@@ -266,14 +268,18 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
266268
for predictions that have only low-quality match candidates. See
267269
set_low_quality_matches_ for more details.
268270
"""
271+
super().__init__()
269272
self.BELOW_LOW_THRESHOLD = -1
270273
self.BETWEEN_THRESHOLDS = -2
271274
assert low_threshold <= high_threshold
272275
self.high_threshold = high_threshold
273276
self.low_threshold = low_threshold
274277
self.allow_low_quality_matches = allow_low_quality_matches
275278

276-
def __call__(self, match_quality_matrix):
279+
def forward(self, match_quality_matrix):
280+
return self._forward_impl(match_quality_matrix)
281+
282+
def _forward_impl(self, match_quality_matrix):
277283
"""
278284
Args:
279285
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
@@ -354,8 +360,8 @@ class SSDMatcher(Matcher):
354360
def __init__(self, threshold):
355361
super().__init__(threshold, threshold, allow_low_quality_matches=False)
356362

357-
def __call__(self, match_quality_matrix):
358-
matches = super().__call__(match_quality_matrix)
363+
def forward(self, match_quality_matrix):
364+
matches = self._forward_impl(match_quality_matrix)
359365

360366
# For each gt, find the prediction with which it has the highest quality
361367
_, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)

torchvision/models/detection/retinanet.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,6 @@ class RetinaNetRegressionHead(nn.Module):
153153
in_channels (int): number of channels of the input feature
154154
num_anchors (int): number of anchors to be predicted
155155
"""
156-
__annotations__ = {
157-
'box_coder': det_utils.BoxCoder,
158-
}
159156

160157
def __init__(self, in_channels, num_anchors):
161158
super().__init__()
@@ -309,10 +306,6 @@ class RetinaNet(nn.Module):
309306
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
310307
>>> predictions = model(x)
311308
"""
312-
__annotations__ = {
313-
'box_coder': det_utils.BoxCoder,
314-
'proposal_matcher': det_utils.Matcher,
315-
}
316309

317310
def __init__(self, backbone, num_classes,
318311
# transform parameters

torchvision/models/detection/roi_heads.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,6 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
483483

484484

485485
class RoIHeads(nn.Module):
486-
__annotations__ = {
487-
'box_coder': det_utils.BoxCoder,
488-
'proposal_matcher': det_utils.Matcher,
489-
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
490-
}
491486

492487
def __init__(self,
493488
box_roi_pool,

torchvision/models/detection/rpn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ class RegionProposalNetwork(torch.nn.Module):
126126
127127
"""
128128
__annotations__ = {
129-
'box_coder': det_utils.BoxCoder,
130-
'proposal_matcher': det_utils.Matcher,
131-
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
132129
'pre_nms_top_n': Dict[str, int],
133130
'post_nms_top_n': Dict[str, int],
134131
}

torchvision/models/detection/ssd.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,6 @@ class SSD(nn.Module):
159159
proposals used during the training of the classification head. It is used to estimate the negative to
160160
positive ratio.
161161
"""
162-
__annotations__ = {
163-
'box_coder': det_utils.BoxCoder,
164-
'proposal_matcher': det_utils.Matcher,
165-
}
166162

167163
def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator,
168164
size: Tuple[int, int], num_classes: int,

torchvision/ops/poolers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def initLevelMapper(
4040
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
4141

4242

43-
class LevelMapper(object):
43+
class LevelMapper(nn.Module):
4444
"""Determine which FPN level each RoI in a set of RoIs should map to based
4545
on the heuristic in the FPN paper.
4646
@@ -60,13 +60,14 @@ def __init__(
6060
canonical_level: int = 4,
6161
eps: float = 1e-6,
6262
):
63+
super().__init__()
6364
self.k_min = k_min
6465
self.k_max = k_max
6566
self.s0 = canonical_scale
6667
self.lvl0 = canonical_level
6768
self.eps = eps
6869

69-
def __call__(self, boxlists: List[Tensor]) -> Tensor:
70+
def forward(self, boxlists: List[Tensor]) -> Tensor:
7071
"""
7172
Args:
7273
boxlists (List[Tensor])
@@ -117,8 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
117118
"""
118119

119120
__annotations__ = {
120-
'scales': Optional[List[float]],
121-
'map_levels': Optional[LevelMapper]
121+
'scales': Optional[List[float]]
122122
}
123123

124124
def __init__(
@@ -137,7 +137,7 @@ def __init__(
137137
self.sampling_ratio = sampling_ratio
138138
self.output_size = tuple(output_size)
139139
self.scales = None
140-
self.map_levels = None
140+
self.map_levels: Optional[LevelMapper] = None
141141
self.canonical_scale = canonical_scale
142142
self.canonical_level = canonical_level
143143

0 commit comments

Comments
 (0)