From c724c16c3d8e5cb28271796f15fc07999cf9b58e Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sat, 6 Nov 2021 10:51:48 +0000
Subject: [PATCH] Adding multiweight support to SSD.

---
 .../prototype/models/detection/__init__.py    |  1 +
 torchvision/prototype/models/detection/ssd.py | 86 +++++++++++++++++++
 torchvision/prototype/models/vgg.py           | 17 ++++
 3 files changed, 104 insertions(+)
 create mode 100644 torchvision/prototype/models/detection/ssd.py

diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py
index 79862e53cce..5369efc1f04 100644
--- a/torchvision/prototype/models/detection/__init__.py
+++ b/torchvision/prototype/models/detection/__init__.py
@@ -2,3 +2,4 @@
 from .keypoint_rcnn import *
 from .mask_rcnn import *
 from .retinanet import *
+from .ssd import *
diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py
new file mode 100644
index 00000000000..5759b8cd40f
--- /dev/null
+++ b/torchvision/prototype/models/detection/ssd.py
@@ -0,0 +1,86 @@
+import warnings
+from typing import Any, Optional
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.detection.ssd import (
+    _validate_trainable_layers,
+    _vgg_extractor,
+    DefaultBoxGenerator,
+    SSD,
+)
+from ...transforms.presets import CocoEval
+from .._api import Weights, WeightEntry
+from .._meta import _COCO_CATEGORIES
+from ..vgg import VGG16Weights, vgg16
+
+
+__all__ = [
+    "SSD300VGG16Weights",
+    "ssd300_vgg16",
+]
+
+
+class SSD300VGG16Weights(Weights):
+    Coco_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
+        transforms=CocoEval,
+        meta={
+            "size": (300, 300),
+            "categories": _COCO_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
+            "map": 25.1,
+        },
+    )
+
+
+def ssd300_vgg16(
+    weights: Optional[SSD300VGG16Weights] = None,
+    weights_backbone: Optional[VGG16Weights] = None,
+    progress: bool = True,
+    num_classes: int = 91,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> SSD:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
+    weights = SSD300VGG16Weights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = VGG16Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the argument.")
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = len(weights.meta["categories"])
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
+    )
+
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg16(weights=weights_backbone, progress=progress)
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
+    anchor_generator = DefaultBoxGenerator(
+        [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+        scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
+        steps=[8, 16, 32, 64, 100, 300],
+    )
+
+    defaults = {
+        # Rescale the input in a way compatible to the backbone
+        "image_mean": [0.48235, 0.45882, 0.40784],
+        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
+    }
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py
index d031eece194..0b034a40f51 100644
--- a/torchvision/prototype/models/vgg.py
+++ b/torchvision/prototype/models/vgg.py
@@ -106,6 +106,23 @@ class VGG16Weights(Weights):
             "acc@5": 90.382,
         },
     )
+    # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
+    # same input standardization method as the paper. Only the `features` weights have proper values, those on the
+    # `classifier` module are filled with nans.
+    ImageNet1K_Features = WeightEntry(
+        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
+        transforms=partial(
+            ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
+        ),
+        meta={
+            "size": (224, 224),
+            "categories": None,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
+            "acc@1": float("nan"),
+            "acc@5": float("nan"),
+        },
+    )
 
 
 class VGG16BNWeights(Weights):
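
For reference, a minimal usage sketch of the multi-weight API this patch introduces (not part of the diff). It assumes the prototype namespace shown above; the num_classes=4 value is a hypothetical example.

    from torchvision.prototype.models.detection import ssd300_vgg16, SSD300VGG16Weights
    from torchvision.prototype.models.vgg import VGG16Weights

    # Fully pre-trained detector: num_classes is taken from the weights' meta-data and
    # weights_backbone is reset to None because the checkpoint covers the whole model.
    model = ssd300_vgg16(weights=SSD300VGG16Weights.Coco_RefV1)

    # Custom dataset: initialize only the VGG16 feature extractor ported from amdegroot.
    model = ssd300_vgg16(
        weights=None,
        weights_backbone=VGG16Weights.ImageNet1K_Features,
        num_classes=4,  # hypothetical number of classes for the example
    )

    # The legacy flag still works but emits a deprecation warning and maps to Coco_RefV1.
    model = ssd300_vgg16(pretrained=True)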