
Commit eab6064

kazhang authored and facebook-github-bot committed
[fbsync] Adding multiweight support to SSD. (#4881)
Reviewed By: datumbox

Differential Revision: D32298973

fbshipit-source-id: e74e8cfc564f13681466a4006a3a879731b597ec
1 parent 8f8fbab commit eab6064

File tree

3 files changed: +104 additions, 0 deletions


torchvision/prototype/models/detection/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@
 from .keypoint_rcnn import *
 from .mask_rcnn import *
 from .retinanet import *
+from .ssd import *
 from .ssdlite import *
torchvision/prototype/models/detection/ssd.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import warnings
+from typing import Any, Optional
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.detection.ssd import (
+    _validate_trainable_layers,
+    _vgg_extractor,
+    DefaultBoxGenerator,
+    SSD,
+)
+from ...transforms.presets import CocoEval
+from .._api import Weights, WeightEntry
+from .._meta import _COCO_CATEGORIES
+from ..vgg import VGG16Weights, vgg16
+
+
+__all__ = [
+    "SSD300VGG16Weights",
+    "ssd300_vgg16",
+]
+
+
+class SSD300VGG16Weights(Weights):
+    Coco_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
+        transforms=CocoEval,
+        meta={
+            "size": (300, 300),
+            "categories": _COCO_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
+            "map": 25.1,
+        },
+    )
+
+
+def ssd300_vgg16(
+    weights: Optional[SSD300VGG16Weights] = None,
+    weights_backbone: Optional[VGG16Weights] = None,
+    progress: bool = True,
+    num_classes: int = 91,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> SSD:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
+    weights = SSD300VGG16Weights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = VGG16Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the argument.")
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = len(weights.meta["categories"])
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
+    )
+
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg16(weights=weights_backbone, progress=progress)
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
+    anchor_generator = DefaultBoxGenerator(
+        [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+        scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
+        steps=[8, 16, 32, 64, 100, 300],
+    )
+
+    defaults = {
+        # Rescale the input in a way compatible to the backbone
+        "image_mean": [0.48235, 0.45882, 0.40784],
+        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
+    }
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
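
The snippet below is not part of the diff; it is a minimal usage sketch of the new builder, assuming the prototype detection package re-exports the names via the star-import added to __init__.py above.

from torchvision.prototype.models.detection import SSD300VGG16Weights, ssd300_vgg16

# New multi-weight API: request the COCO weights explicitly through the enum.
model = ssd300_vgg16(weights=SSD300VGG16Weights.Coco_RefV1)
model.eval()

# The legacy flag still works: the kwargs handling above maps pretrained=True
# onto Coco_RefV1 and emits a deprecation warning.
legacy_model = ssd300_vgg16(pretrained=True)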

torchvision/prototype/models/vgg.py

Lines changed: 17 additions & 0 deletions
@@ -106,6 +106,23 @@ class VGG16Weights(Weights):
             "acc@5": 90.382,
         },
     )
+    # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
+    # same input standardization method as the paper. Only the `features` weights have proper values, those on the
+    # `classifier` module are filled with nans.
+    ImageNet1K_Features = WeightEntry(
+        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
+        transforms=partial(
+            ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
+        ),
+        meta={
+            "size": (224, 224),
+            "categories": None,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
+            "acc@1": float("nan"),
+            "acc@5": float("nan"),
+        },
+    )
 
 
 class VGG16BNWeights(Weights):
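
For illustration only (not part of the diff): this backbone-only entry is what ssd300_vgg16 above falls back to when the deprecated pretrained_backbone=True flag is passed, and it can also be requested directly from the module shown in this file.

from torchvision.prototype.models.vgg import VGG16Weights, vgg16

# Only the `features` weights carry real values; the classifier weights are
# filled with NaNs, so this entry is intended purely as a detection backbone.
backbone = vgg16(weights=VGG16Weights.ImageNet1K_Features)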
