From 91284afcc366d2db460e2a4098e73c9e2e4d9230 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 10 Sep 2021 10:30:03 +0100 Subject: [PATCH] Rewrite objects to modules. --- torchvision/models/detection/_utils.py | 22 ++++++++++++++-------- torchvision/models/detection/retinanet.py | 7 ------- torchvision/models/detection/roi_heads.py | 5 ----- torchvision/models/detection/rpn.py | 3 --- torchvision/models/detection/ssd.py | 4 ---- 5 files changed, 14 insertions(+), 27 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 1d3bcdba7fe..bd599190453 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -2,13 +2,13 @@ import torch from collections import OrderedDict -from torch import Tensor +from torch import Tensor, nn from typing import List, Tuple from torchvision.ops.misc import FrozenBatchNorm2d -class BalancedPositiveNegativeSampler(object): +class BalancedPositiveNegativeSampler(nn.Module): """ This class samples batches, ensuring that they contain a fixed proportion of positives """ @@ -20,10 +20,11 @@ def __init__(self, batch_size_per_image, positive_fraction): batch_size_per_image (int): number of elements to be selected per image positive_fraction (float): percentace of positive elements per batch """ + super().__init__() self.batch_size_per_image = batch_size_per_image self.positive_fraction = positive_fraction - def __call__(self, matched_idxs): + def forward(self, matched_idxs): # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] """ Args: @@ -126,7 +127,7 @@ def encode_boxes(reference_boxes, proposals, weights): return targets -class BoxCoder(object): +class BoxCoder(nn.Module): """ This class encodes and decodes a set of bounding boxes into the representation used for training the regressors. @@ -139,6 +140,7 @@ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): weights (4-element tuple) bbox_xform_clip (float) """ + super().__init__() self.weights = weights self.bbox_xform_clip = bbox_xform_clip @@ -228,7 +230,7 @@ def decode_single(self, rel_codes, boxes): return pred_boxes -class Matcher(object): +class Matcher(nn.Module): """ This class assigns to each predicted "element" (e.g., a box) a ground-truth element. Each predicted element will have exactly zero or one matches; each @@ -266,6 +268,7 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals for predictions that have only low-quality match candidates. See set_low_quality_matches_ for more details. """ + super().__init__() self.BELOW_LOW_THRESHOLD = -1 self.BETWEEN_THRESHOLDS = -2 assert low_threshold <= high_threshold @@ -273,7 +276,10 @@ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=Fals self.low_threshold = low_threshold self.allow_low_quality_matches = allow_low_quality_matches - def __call__(self, match_quality_matrix): + def forward(self, match_quality_matrix): + return self._forward_impl(match_quality_matrix) + + def _forward_impl(self, match_quality_matrix): """ Args: match_quality_matrix (Tensor[float]): an MxN tensor, containing the @@ -354,8 +360,8 @@ class SSDMatcher(Matcher): def __init__(self, threshold): super().__init__(threshold, threshold, allow_low_quality_matches=False) - def __call__(self, match_quality_matrix): - matches = super().__call__(match_quality_matrix) + def forward(self, match_quality_matrix): + matches = self._forward_impl(match_quality_matrix) # For each gt, find the prediction with which it has the highest quality _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 4dd95285dbc..d9dbcd4fae6 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -153,9 +153,6 @@ class RetinaNetRegressionHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted """ - __annotations__ = { - 'box_coder': det_utils.BoxCoder, - } def __init__(self, in_channels, num_anchors): super().__init__() @@ -309,10 +306,6 @@ class RetinaNet(nn.Module): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - } def __init__(self, backbone, num_classes, # transform parameters diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 9948d5f537f..89c8986dc8b 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -483,11 +483,6 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): class RoIHeads(nn.Module): - __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, - } def __init__(self, box_roi_pool, diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index a98eac24dd3..25adb3bab53 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -126,9 +126,6 @@ class RegionProposalNetwork(torch.nn.Module): """ __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, 'pre_nms_top_n': Dict[str, int], 'post_nms_top_n': Dict[str, int], } diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index e67c4930b30..a06adf38da5 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -159,10 +159,6 @@ class SSD(nn.Module): proposals used during the training of the classification head. It is used to estimate the negative to positive ratio. """ - __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - } def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, size: Tuple[int, int], num_classes: int,