1+ from typing import Any , Optional , Union
2+
13import torch .nn .functional as F
24from torch import nn
35from torchvision .ops import MultiScaleRoIAlign
46
5- from ..._internally_replaced_utils import load_state_dict_from_url
67from ...ops import misc as misc_nn_ops
7- from ..mobilenetv3 import mobilenet_v3_large
8- from ..resnet import resnet50
8+ from ...transforms import ObjectDetectionEval , InterpolationMode
9+ from .._api import WeightsEnum , Weights
10+ from .._meta import _COCO_CATEGORIES
11+ from .._utils import handle_legacy_interface , _ovewrite_value_param
12+ from ..mobilenetv3 import MobileNet_V3_Large_Weights , mobilenet_v3_large
13+ from ..resnet import ResNet50_Weights , resnet50
914from ._utils import overwrite_eps
1015from .anchor_utils import AnchorGenerator
1116from .backbone_utils import _resnet_fpn_extractor , _validate_trainable_layers , _mobilenet_extractor
1722
# Names exported by ``from <module> import *``: the detector class, the
# pretrained-weight enums, and the model builder functions.
__all__ = [
    # Detector class
    "FasterRCNN",
    # Pretrained-weight enums (one per builder)
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    # Model builders
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]
2432
2533
@@ -307,16 +315,70 @@ def forward(self, x):
307315 return scores , bbox_deltas
308316
309317
# Metadata entries shared by every pretrained-weight definition in this module.
# Each ``Weights(...)`` entry merges these with its own per-checkpoint fields.
_COMMON_META = dict(
    task="image_object_detection",
    architecture="FasterRCNN",
    publication_year=2015,
    categories=_COCO_CATEGORIES,
    interpolation=InterpolationMode.BILINEAR,
)
315325
316326
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    """Pretrained weights accepted by ``fasterrcnn_resnet50_fpn``."""

    # Checkpoint location, the transforms to pair with it, and its metadata
    # (shared entries from ``_COMMON_META`` plus per-checkpoint fields).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetectionEval,
        meta=dict(
            _COMMON_META,
            num_params=41755286,
            recipe="https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            map=37.0,
        ),
    )
    # Alias of COCO_V1.
    DEFAULT = COCO_V1
339+
340+
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    """Pretrained weights accepted by ``fasterrcnn_mobilenet_v3_large_fpn``."""

    # Checkpoint location, the transforms to pair with it, and its metadata
    # (shared entries from ``_COMMON_META`` plus per-checkpoint fields).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetectionEval,
        meta=dict(
            _COMMON_META,
            num_params=19386354,
            recipe="https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            map=32.8,
        ),
    )
    # Alias of COCO_V1.
    DEFAULT = COCO_V1
353+
354+
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    """Pretrained weights accepted by ``fasterrcnn_mobilenet_v3_large_320_fpn``."""

    # Checkpoint location, the transforms to pair with it, and its metadata
    # (shared entries from ``_COMMON_META`` plus per-checkpoint fields).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetectionEval,
        meta=dict(
            _COMMON_META,
            num_params=19386354,
            recipe="https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            map=22.8,
        ),
    )
    # Alias of COCO_V1.
    DEFAULT = COCO_V1
367+
368+
369+ @handle_legacy_interface (
370+ weights = ("pretrained" , FasterRCNN_ResNet50_FPN_Weights .COCO_V1 ),
371+ weights_backbone = ("pretrained_backbone" , ResNet50_Weights .IMAGENET1K_V1 ),
372+ )
317373def fasterrcnn_resnet50_fpn (
318- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
319- ):
374+ * ,
375+ weights : Optional [FasterRCNN_ResNet50_FPN_Weights ] = None ,
376+ progress : bool = True ,
377+ num_classes : Optional [int ] = None ,
378+ weights_backbone : Optional [ResNet50_Weights ] = None ,
379+ trainable_backbone_layers : Optional [int ] = None ,
380+ ** kwargs : Any ,
381+ ) -> FasterRCNN :
320382 """
321383 Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
322384
@@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn(
375437 >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
376438
377439 Args:
378- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
440+ weights (FasterRCNN_ResNet50_FPN_Weights, optional ): The pretrained weights for the model
379441 progress (bool): If True, displays a progress bar of the download to stderr
380- num_classes (int): number of output classes of the model (including the background)
381- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
382- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
442+ num_classes (int, optional ): number of output classes of the model (including the background)
443+ weights_backbone (ResNet50_Weights, optional ): The pretrained weights for the backbone
444+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
383445 Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
384446 passed (the default) this value is set to 3.
385447 """
386- is_trained = pretrained or pretrained_backbone
448+ weights = FasterRCNN_ResNet50_FPN_Weights .verify (weights )
449+ weights_backbone = ResNet50_Weights .verify (weights_backbone )
450+
451+ if weights is not None :
452+ weights_backbone = None
453+ num_classes = _ovewrite_value_param (num_classes , len (weights .meta ["categories" ]))
454+ elif num_classes is None :
455+ num_classes = 91
456+
457+ is_trained = weights is not None or weights_backbone is not None
387458 trainable_backbone_layers = _validate_trainable_layers (is_trained , trainable_backbone_layers , 5 , 3 )
388459 norm_layer = misc_nn_ops .FrozenBatchNorm2d if is_trained else nn .BatchNorm2d
389460
390- if pretrained :
391- # no need to download the backbone if pretrained is set
392- pretrained_backbone = False
393-
394- backbone = resnet50 (pretrained = pretrained_backbone , progress = progress , norm_layer = norm_layer )
461+ backbone = resnet50 (weights = weights_backbone , progress = progress , norm_layer = norm_layer )
395462 backbone = _resnet_fpn_extractor (backbone , trainable_backbone_layers )
396- model = FasterRCNN (backbone , num_classes , ** kwargs )
397- if pretrained :
398- state_dict = load_state_dict_from_url (model_urls ["fasterrcnn_resnet50_fpn_coco" ], progress = progress )
399- model .load_state_dict (state_dict )
400- overwrite_eps (model , 0.0 )
463+ model = FasterRCNN (backbone , num_classes = num_classes , ** kwargs )
464+
465+ if weights is not None :
466+ model .load_state_dict (weights .get_state_dict (progress = progress ))
467+ if weights == FasterRCNN_ResNet50_FPN_Weights .COCO_V1 :
468+ overwrite_eps (model , 0.0 )
469+
401470 return model
402471
403472
404473def _fasterrcnn_mobilenet_v3_large_fpn (
405- weights_name ,
406- pretrained = False ,
407- progress = True ,
408- num_classes = 91 ,
409- pretrained_backbone = True ,
410- trainable_backbone_layers = None ,
411- ** kwargs ,
412- ):
413- is_trained = pretrained or pretrained_backbone
474+ * ,
475+ weights : Optional [Union [FasterRCNN_MobileNet_V3_Large_FPN_Weights , FasterRCNN_MobileNet_V3_Large_320_FPN_Weights ]],
476+ progress : bool ,
477+ num_classes : Optional [int ],
478+ weights_backbone : Optional [MobileNet_V3_Large_Weights ],
479+ trainable_backbone_layers : Optional [int ],
480+ ** kwargs : Any ,
481+ ) -> FasterRCNN :
482+ if weights is not None :
483+ weights_backbone = None
484+ num_classes = _ovewrite_value_param (num_classes , len (weights .meta ["categories" ]))
485+ elif num_classes is None :
486+ num_classes = 91
487+
488+ is_trained = weights is not None or weights_backbone is not None
414489 trainable_backbone_layers = _validate_trainable_layers (is_trained , trainable_backbone_layers , 6 , 3 )
415490 norm_layer = misc_nn_ops .FrozenBatchNorm2d if is_trained else nn .BatchNorm2d
416491
417- if pretrained :
418- pretrained_backbone = False
419-
420- backbone = mobilenet_v3_large (pretrained = pretrained_backbone , progress = progress , norm_layer = norm_layer )
492+ backbone = mobilenet_v3_large (weights = weights_backbone , progress = progress , norm_layer = norm_layer )
421493 backbone = _mobilenet_extractor (backbone , True , trainable_backbone_layers )
422-
423494 anchor_sizes = (
424495 (
425496 32 ,
@@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
430501 ),
431502 ) * 3
432503 aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
433-
434504 model = FasterRCNN (
435505 backbone , num_classes , rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ), ** kwargs
436506 )
437- if pretrained :
438- if model_urls .get (weights_name , None ) is None :
439- raise ValueError (f"No checkpoint is available for model { weights_name } " )
440- state_dict = load_state_dict_from_url (model_urls [weights_name ], progress = progress )
441- model .load_state_dict (state_dict )
507+
508+ if weights is not None :
509+ model .load_state_dict (weights .get_state_dict (progress = progress ))
510+
442511 return model
443512
444513
514+ @handle_legacy_interface (
515+ weights = ("pretrained" , FasterRCNN_MobileNet_V3_Large_320_FPN_Weights .COCO_V1 ),
516+ weights_backbone = ("pretrained_backbone" , MobileNet_V3_Large_Weights .IMAGENET1K_V1 ),
517+ )
445518def fasterrcnn_mobilenet_v3_large_320_fpn (
446- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
447- ):
519+ * ,
520+ weights : Optional [FasterRCNN_MobileNet_V3_Large_320_FPN_Weights ] = None ,
521+ progress : bool = True ,
522+ num_classes : Optional [int ] = None ,
523+ weights_backbone : Optional [MobileNet_V3_Large_Weights ] = None ,
524+ trainable_backbone_layers : Optional [int ] = None ,
525+ ** kwargs : Any ,
526+ ) -> FasterRCNN :
448527 """
449528 Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
450529 It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
459538 >>> predictions = model(x)
460539
461540 Args:
462- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
541+ weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional ): The pretrained weights for the model
463542 progress (bool): If True, displays a progress bar of the download to stderr
464- num_classes (int): number of output classes of the model (including the background)
465- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
466- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
543+ num_classes (int, optional ): number of output classes of the model (including the background)
544+ weights_backbone (MobileNet_V3_Large_Weights, optional ): The pretrained weights for the backbone
545+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
467546 Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
468547 passed (the default) this value is set to 3.
469548 """
470- weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
549+ weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights .verify (weights )
550+ weights_backbone = MobileNet_V3_Large_Weights .verify (weights_backbone )
551+
471552 defaults = {
472553 "min_size" : 320 ,
473554 "max_size" : 640 ,
@@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
478559
479560 kwargs = {** defaults , ** kwargs }
480561 return _fasterrcnn_mobilenet_v3_large_fpn (
481- weights_name ,
482- pretrained = pretrained ,
562+ weights = weights ,
483563 progress = progress ,
484564 num_classes = num_classes ,
485- pretrained_backbone = pretrained_backbone ,
565+ weights_backbone = weights_backbone ,
486566 trainable_backbone_layers = trainable_backbone_layers ,
487567 ** kwargs ,
488568 )
489569
490570
571+ @handle_legacy_interface (
572+ weights = ("pretrained" , FasterRCNN_MobileNet_V3_Large_FPN_Weights .COCO_V1 ),
573+ weights_backbone = ("pretrained_backbone" , MobileNet_V3_Large_Weights .IMAGENET1K_V1 ),
574+ )
491575def fasterrcnn_mobilenet_v3_large_fpn (
492- pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True , trainable_backbone_layers = None , ** kwargs
493- ):
576+ * ,
577+ weights : Optional [FasterRCNN_MobileNet_V3_Large_FPN_Weights ] = None ,
578+ progress : bool = True ,
579+ num_classes : Optional [int ] = None ,
580+ weights_backbone : Optional [MobileNet_V3_Large_Weights ] = None ,
581+ trainable_backbone_layers : Optional [int ] = None ,
582+ ** kwargs : Any ,
583+ ) -> FasterRCNN :
494584 """
495585 Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
496586 It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn(
505595 >>> predictions = model(x)
506596
507597 Args:
508- pretrained (bool ): If True, returns a model pre-trained on COCO train2017
598+ weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional ): The pretrained weights for the model
509599 progress (bool): If True, displays a progress bar of the download to stderr
510- num_classes (int): number of output classes of the model (including the background)
511- pretrained_backbone (bool ): If True, returns a model with backbone pre-trained on Imagenet
512- trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
600+ num_classes (int, optional ): number of output classes of the model (including the background)
601+ weights_backbone (MobileNet_V3_Large_Weights, optional ): The pretrained weights for the backbone
602+ trainable_backbone_layers (int, optional ): number of trainable (not frozen) layers starting from final block.
513603 Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
514604 passed (the default) this value is set to 3.
515605 """
516- weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
606+ weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights .verify (weights )
607+ weights_backbone = MobileNet_V3_Large_Weights .verify (weights_backbone )
608+
517609 defaults = {
518610 "rpn_score_thresh" : 0.05 ,
519611 }
520612
521613 kwargs = {** defaults , ** kwargs }
522614 return _fasterrcnn_mobilenet_v3_large_fpn (
523- weights_name ,
524- pretrained = pretrained ,
615+ weights = weights ,
525616 progress = progress ,
526617 num_classes = num_classes ,
527- pretrained_backbone = pretrained_backbone ,
618+ weights_backbone = weights_backbone ,
528619 trainable_backbone_layers = trainable_backbone_layers ,
529620 ** kwargs ,
530621 )
0 commit comments