From a3f4606a3ada31f401493e8afa40c066e9e4761b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 6 Apr 2022 15:35:56 +0100 Subject: [PATCH 1/2] Add MaskRCNN improved weights --- torchvision/models/detection/mask_rcnn.py | 27 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 01e56c7a108..911eba50af8 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -349,17 +349,22 @@ def __init__(self, in_channels, dim_reduced, num_classes): # nn.init.constant_(param, 0) +_COMMON_META = { + "task": "image_object_detection", + "architecture": "MaskRCNN", + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=ObjectDetection, meta={ - "task": "image_object_detection", - "architecture": "MaskRCNN", + **_COMMON_META, "publication_year": 2017, "num_params": 44401393, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", "map": 37.9, "map_mask": 34.6, @@ -369,7 +374,19 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): - pass + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "publication_year": 2021, + "num_params": 46359409, + "recipe": "", + "map": 47.4, + "map_mask": 41.8, + }, + ) + DEFAULT = COCO_V1 @handle_legacy_interface( From 13a6ab16ccee52cb5db588fd34a37e7ba29291d8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 6 Apr 2022 15:36:41 +0100 Subject: [PATCH 2/2] Adding recipe URL --- torchvision/models/detection/mask_rcnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 911eba50af8..65c85922e2a 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -381,7 +381,7 @@ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): **_COMMON_META, "publication_year": 2021, "num_params": 46359409, - "recipe": "", + "recipe": "https://github.com/pytorch/vision/pull/5773", "map": 47.4, "map_mask": 41.8, },