Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import torch

from collections import OrderedDict
from torch import Tensor
from torch import Tensor, nn
from typing import List, Tuple

from torchvision.ops.misc import FrozenBatchNorm2d


class BalancedPositiveNegativeSampler(object):
class BalancedPositiveNegativeSampler(nn.Module):
"""
This class samples batches, ensuring that they contain a fixed proportion of positives
"""
Expand All @@ -20,10 +20,11 @@ def __init__(self, batch_size_per_image, positive_fraction):
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentage of positive elements per batch
"""
super().__init__()
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction

def __call__(self, matched_idxs):
def forward(self, matched_idxs):
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
"""
Args:
Expand Down Expand Up @@ -126,7 +127,7 @@ def encode_boxes(reference_boxes, proposals, weights):
return targets


class BoxCoder(object):
class BoxCoder(nn.Module):
Copy link
Contributor Author

@datumbox datumbox Sep 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class, with its encode/decode methods, does not play nicely with nn.Module's forward() approach. I'm not sure whether we should convert it, but doing so would allow us to drop its declarations from __annotations__.

"""
This class encodes and decodes a set of bounding boxes into
the representation used for training the regressors.
Expand All @@ -139,6 +140,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
weights (4-element tuple)
bbox_xform_clip (float)
"""
super().__init__()
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip

Expand Down Expand Up @@ -228,7 +230,7 @@ def decode_single(self, rel_codes, boxes):
return pred_boxes


class Matcher(object):
class Matcher(nn.Module):
"""
This class assigns to each predicted "element" (e.g., a box) a ground-truth
element. Each predicted element will have exactly zero or one matches; each
Expand Down Expand Up @@ -266,14 +268,18 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals
for predictions that have only low-quality match candidates. See
set_low_quality_matches_ for more details.
"""
super().__init__()
self.BELOW_LOW_THRESHOLD = -1
self.BETWEEN_THRESHOLDS = -2
assert low_threshold <= high_threshold
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches

def __call__(self, match_quality_matrix):
def forward(self, match_quality_matrix):
return self._forward_impl(match_quality_matrix)

def _forward_impl(self, match_quality_matrix):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Workaround for overriding forward() in subclasses. We do the same in quantization.

"""
Args:
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
Expand Down Expand Up @@ -354,8 +360,8 @@ class SSDMatcher(Matcher):
def __init__(self, threshold):
super().__init__(threshold, threshold, allow_low_quality_matches=False)

def __call__(self, match_quality_matrix):
matches = super().__call__(match_quality_matrix)
def forward(self, match_quality_matrix):
matches = self._forward_impl(match_quality_matrix)

# For each gt, find the prediction with which it has the highest quality
_, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
Expand Down
7 changes: 0 additions & 7 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ class RetinaNetRegressionHead(nn.Module):
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
}

def __init__(self, in_channels, num_anchors):
super().__init__()
Expand Down Expand Up @@ -309,10 +306,6 @@ class RetinaNet(nn.Module):
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
}

def __init__(self, backbone, num_classes,
# transform parameters
Expand Down
5 changes: 0 additions & 5 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,6 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):


class RoIHeads(nn.Module):
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
}

def __init__(self,
box_roi_pool,
Expand Down
3 changes: 0 additions & 3 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ class RegionProposalNetwork(torch.nn.Module):

"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
'pre_nms_top_n': Dict[str, int],
'post_nms_top_n': Dict[str, int],
}
Expand Down
4 changes: 0 additions & 4 deletions torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ class SSD(nn.Module):
proposals used during the training of the classification head. It is used to estimate the negative to
positive ratio.
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
}

def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator,
size: Tuple[int, int], num_classes: int,
Expand Down