From c6111978e460ec18cc3fdbebfb8002272050bcb6 Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 12:40:35 +0800 Subject: [PATCH 1/7] fix type hints and move degenerate boxes to a function in torchvision.models.detection.generalized_rcnn --- torchvision/models/detection/faster_rcnn.py | 16 +++------- .../models/detection/generalized_rcnn.py | 32 +++++++++++-------- torchvision/models/detection/image_list.py | 3 +- torchvision/ops/poolers.py | 4 ++- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index c7e6c6d12db..c17e8eb9ad1 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,20 +1,14 @@ -from collections import OrderedDict - import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign - -from ..utils import load_state_dict_from_url - +from .backbone_utils import resnet_fpn_backbone from .generalized_rcnn import GeneralizedRCNN -from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads +from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone - +from ..utils import load_state_dict_from_url __all__ = [ "FasterRCNN", "fasterrcnn_resnet50_fpn", @@ -347,7 +341,7 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ - assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + assert 0 <= trainable_backbone_layers <= 5 # dont freeze any layers if pretrained model or backbone is not used if not (pretrained or pretrained_backbone): trainable_backbone_layers = 5 diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 1ee0542c9c6..d10d38ce956 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -4,6 +4,7 @@ """ from collections import OrderedDict +from typing import Union import torch from torch import nn import warnings @@ -11,6 +12,19 @@ from torch import Tensor +def _check_for_degenerate_boxes(targets): + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) + + class GeneralizedRCNN(nn.Module): """ Main class for Generalized R-CNN. @@ -35,7 +49,7 @@ def __init__(self, backbone, rpn, roi_heads, transform): @torch.jit.unused def eager_outputs(self, losses, detections): - # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] if self.training: return losses @@ -65,7 +79,7 @@ def forward(self, images, targets=None): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError("Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format( - boxes.shape)) + boxes.shape)) else: raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) @@ -79,18 +93,8 @@ def forward(self, images, targets=None): images, targets = self.transform(images, targets) # Check for degenerate boxes - # TODO: Move this to a function if targets is not None: - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenrate box - bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] - degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invaid box {} for target at index {}." - .format(degen_bb, target_idx)) + _check_for_degenerate_boxes(targets) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): @@ -107,6 +111,6 @@ def forward(self, images, targets=None): if not self._has_warned: warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") self._has_warned = True - return (losses, detections) + return losses, detections else: return self.eager_outputs(losses, detections) diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index dc8987a9f83..4c446bf77ea 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -1,7 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -import torch -from torch.jit.annotations import List, Tuple from torch import Tensor +from torch.jit.annotations import List, Tuple class ImageList(object): diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index 32734cff86a..bf9412af056 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,4 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from typing import Union + import torch import torch.nn.functional as F from torch import nn, Tensor @@ -119,7 +121,7 @@ class MultiScaleRoIAlign(nn.Module): def __init__( self, featmap_names: List[str], - output_size: List[int], + output_size: Union[int, Tuple[int], List[int]], sampling_ratio: int, ): super(MultiScaleRoIAlign, self).__init__() From 2d0118ddb17d4661f266c845c3dee70a937d2101 Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 13:53:28 +0800 Subject: [PATCH 2/7] format code --- torchvision/models/detection/generalized_rcnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index d10d38ce956..ef4b11cf006 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -79,7 +79,7 @@ def forward(self, images, targets=None): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError("Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format( - boxes.shape)) + boxes.shape)) else: raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) From 5f6823d34e48fbae439634da1c35623ca05370eb Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 13:55:40 +0800 Subject: [PATCH 3/7] format code --- torchvision/models/detection/generalized_rcnn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index ef4b11cf006..4104ca1eca2 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -78,8 +78,7 @@ def forward(self, images, targets=None): if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + "of shape [N, 4], got {:}.".format(boxes.shape)) else: raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) From 13d841ddf0add39f2cc31236562ad43370cdd5ec Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 15:00:58 +0800 Subject: [PATCH 4/7] changed to static method --- torchvision/models/detection/faster_rcnn.py | 13 ++++++--- .../models/detection/generalized_rcnn.py | 28 +++++++++---------- torchvision/models/detection/image_list.py | 1 + 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index c17e8eb9ad1..d9845d887ab 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,14 +1,19 @@ +from collections import OrderedDict + import torch -import torch.nn.functional as F from torch import nn +import torch.nn.functional as F +from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign -from .backbone_utils import resnet_fpn_backbone + +from ..utils import load_state_dict_from_url + from .generalized_rcnn import GeneralizedRCNN -from .roi_heads import RoIHeads from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from .roi_heads import RoIHeads from .transform import GeneralizedRCNNTransform -from ..utils import load_state_dict_from_url +from .backbone_utils import resnet_fpn_backbone __all__ = [ "FasterRCNN", "fasterrcnn_resnet50_fpn", diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 4104ca1eca2..8329c37171e 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -12,19 +12,6 @@ from torch import Tensor -def _check_for_degenerate_boxes(targets): - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenerate box - bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] - degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) - - class GeneralizedRCNN(nn.Module): """ Main class for Generalized R-CNN. @@ -93,7 +80,7 @@ def forward(self, images, targets=None): # Check for degenerate boxes if targets is not None: - _check_for_degenerate_boxes(targets) + GeneralizedRCNN._check_for_degenerate_boxes(targets) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): @@ -113,3 +100,16 @@ def forward(self, images, targets=None): return losses, detections else: return self.eager_outputs(losses, detections) + + @staticmethod + def _check_for_degenerate_boxes(targets): + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index 4c446bf77ea..c471fa076b8 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -1,4 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch from torch import Tensor from torch.jit.annotations import List, Tuple From f15cf60028e207b28df65fb1e24b3cf795df53bd Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 15:05:40 +0800 Subject: [PATCH 5/7] revert imports --- torchvision/models/detection/faster_rcnn.py | 1 + torchvision/models/detection/image_list.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index d9845d887ab..95419939c88 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -15,6 +15,7 @@ from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone + __all__ = [ "FasterRCNN", "fasterrcnn_resnet50_fpn", ] diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index c471fa076b8..dc8987a9f83 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch -from torch import Tensor from torch.jit.annotations import List, Tuple +from torch import Tensor class ImageList(object): From e6058cf02399ba1cc93cc0e803f7dedcf3ac987b Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 15:11:10 +0800 Subject: [PATCH 6/7] changed to method --- torchvision/models/detection/generalized_rcnn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 8329c37171e..d1f520f9e97 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -80,7 +80,7 @@ def forward(self, images, targets=None): # Check for degenerate boxes if targets is not None: - GeneralizedRCNN._check_for_degenerate_boxes(targets) + self._check_for_degenerate_boxes(targets) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): @@ -101,8 +101,7 @@ def forward(self, images, targets=None): else: return self.eager_outputs(losses, detections) - @staticmethod - def _check_for_degenerate_boxes(targets): + def _check_for_degenerate_boxes(self, targets): for target_idx, target in enumerate(targets): boxes = target["boxes"] degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] From 110687bed1cfaeaab8403bb335039af679cf3d56 Mon Sep 17 00:00:00 2001 From: ChiangYintso <392711804@qq.com> Date: Tue, 4 Aug 2020 15:35:53 +0800 Subject: [PATCH 7/7] revert procedure for degenerating boxes --- torchvision/models/detection/faster_rcnn.py | 2 +- .../models/detection/generalized_rcnn.py | 29 +++++++++---------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 95419939c88..c7e6c6d12db 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -347,7 +347,7 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ - assert 0 <= trainable_backbone_layers <= 5 + assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 # dont freeze any layers if pretrained model or backbone is not used if not (pretrained or pretrained_backbone): trainable_backbone_layers = 5 diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index d1f520f9e97..4a8fadfe345 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -65,7 +65,8 @@ def forward(self, images, targets=None): if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format(boxes.shape)) + "of shape [N, 4], got {:}.".format( + boxes.shape)) else: raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) @@ -79,8 +80,18 @@ def forward(self, images, targets=None): images, targets = self.transform(images, targets) # Check for degenerate boxes + # TODO: Move this to a function if targets is not None: - self._check_for_degenerate_boxes(targets) + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): @@ -97,18 +108,6 @@ def forward(self, images, targets=None): if not self._has_warned: warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") self._has_warned = True - return losses, detections + return (losses, detections) else: return self.eager_outputs(losses, detections) - - def _check_for_degenerate_boxes(self, targets): - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenerate box - bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] - degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx))