-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Adding multiweight support to SSD #4881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| import warnings | ||
| from typing import Any, Optional | ||
|
|
||
| from torchvision.transforms.functional import InterpolationMode | ||
|
|
||
| from ....models.detection.ssd import ( | ||
| _validate_trainable_layers, | ||
| _vgg_extractor, | ||
| DefaultBoxGenerator, | ||
| SSD, | ||
| ) | ||
| from ...transforms.presets import CocoEval | ||
| from .._api import Weights, WeightEntry | ||
| from .._meta import _COCO_CATEGORIES | ||
| from ..vgg import VGG16Weights, vgg16 | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "SSD300VGG16Weights", | ||
| "ssd300_vgg16", | ||
| ] | ||
|
|
||
|
|
||
| class SSD300VGG16Weights(Weights): | ||
| Coco_RefV1 = WeightEntry( | ||
| url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", | ||
| transforms=CocoEval, | ||
| meta={ | ||
| "size": (300, 300), | ||
| "categories": _COCO_CATEGORIES, | ||
| "interpolation": InterpolationMode.BILINEAR, | ||
| "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", | ||
| "map": 25.1, | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| def ssd300_vgg16( | ||
| weights: Optional[SSD300VGG16Weights] = None, | ||
| weights_backbone: Optional[VGG16Weights] = None, | ||
| progress: bool = True, | ||
| num_classes: int = 91, | ||
| trainable_backbone_layers: Optional[int] = None, | ||
| **kwargs: Any, | ||
| ) -> SSD: | ||
| if "pretrained" in kwargs: | ||
| warnings.warn("The argument pretrained is deprecated, please use weights instead.") | ||
| weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None | ||
| weights = SSD300VGG16Weights.verify(weights) | ||
| if "pretrained_backbone" in kwargs: | ||
| warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") | ||
| weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None | ||
| weights_backbone = VGG16Weights.verify(weights_backbone) | ||
|
|
||
| if "size" in kwargs: | ||
| warnings.warn("The size of the model is already fixed; ignoring the argument.") | ||
|
|
||
| if weights is not None: | ||
| weights_backbone = None | ||
| num_classes = len(weights.meta["categories"]) | ||
|
|
||
| trainable_backbone_layers = _validate_trainable_layers( | ||
| weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 | ||
| ) | ||
|
|
||
| # Use custom backbones more appropriate for SSD | ||
| backbone = vgg16(weights=weights_backbone, progress=progress) | ||
| backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) | ||
| anchor_generator = DefaultBoxGenerator( | ||
| [[2], [2, 3], [2, 3], [2, 3], [2], [2]], | ||
| scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], | ||
| steps=[8, 16, 32, 64, 100, 300], | ||
| ) | ||
|
|
||
| defaults = { | ||
| # Rescale the input in a way compatible to the backbone | ||
| "image_mean": [0.48235, 0.45882, 0.40784], | ||
| "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor | ||
| } | ||
| kwargs: Any = {**defaults, **kwargs} | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any type declaration needed to mypy. |
||
| model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we get the size from the weights?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the same discussion as in #4875 (comment). Note that unlike the FasterRCNN models that adapt better on variable input size, SSD doesn't. It was investigated at #3819 and that's why it's hardcoded. |
||
|
|
||
| if weights is not None: | ||
| model.load_state_dict(weights.state_dict(progress=progress)) | ||
|
|
||
| return model | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,23 @@ class VGG16Weights(Weights): | |
| "acc@5": 90.382, | ||
| }, | ||
| ) | ||
| # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the | ||
| # same input standardization method as the paper. Only the `features` weights have proper values, those on the | ||
| # `classifier` module are filled with nans. | ||
| ImageNet1K_Features = WeightEntry( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are adding the weights as a separate entry on the VGG. This one only has the features part of the model. |
||
| url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", | ||
| transforms=partial( | ||
| ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0) | ||
| ), | ||
| meta={ | ||
| "size": (224, 224), | ||
| "categories": None, | ||
| "interpolation": InterpolationMode.BILINEAR, | ||
| "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", | ||
| "acc@1": float("nan"), | ||
| "acc@5": float("nan"), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't contain a classifier, hence the nans here to denote it. |
||
| }, | ||
| ) | ||
|
|
||
|
|
||
| class VGG16BNWeights(Weights): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are using special weights here. Not the standard VGG weights of TorchVision.