1
- from collections import OrderedDict
2
-
3
1
import torch
4
2
from torch import nn
5
3
import torch .nn .functional as F
6
4
7
- from torchvision .ops import misc as misc_nn_ops
8
5
from torchvision .ops import MultiScaleRoIAlign
9
6
10
7
from ._utils import overwrite_eps
15
12
from .rpn import RPNHead , RegionProposalNetwork
16
13
from .roi_heads import RoIHeads
17
14
from .transform import GeneralizedRCNNTransform
18
- from .backbone_utils import resnet_fpn_backbone , _validate_resnet_trainable_layers
15
+ from .backbone_utils import resnet_fpn_backbone , _validate_trainable_layers , mobilenet_backbone
19
16
20
17
21
18
__all__ = [
22
- "FasterRCNN" , "fasterrcnn_resnet50_fpn" ,
19
+ "FasterRCNN" , "fasterrcnn_resnet50_fpn" , "fasterrcnn_mobilenet_v3_large_fpn"
23
20
]
24
21
25
22
@@ -291,6 +288,8 @@ def forward(self, x):
291
288
model_urls = {
292
289
'fasterrcnn_resnet50_fpn_coco' :
293
290
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth' ,
291
+ 'fasterrcnn_mobilenet_v3_large_fpn_coco' :
292
+ 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth' ,
294
293
}
295
294
296
295
@@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
353
352
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
354
353
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
355
354
"""
356
- # check default parameters and by default set it to 3 if possible
357
- trainable_backbone_layers = _validate_resnet_trainable_layers (
358
- pretrained or pretrained_backbone , trainable_backbone_layers )
355
+ trainable_backbone_layers = _validate_trainable_layers (
356
+ pretrained or pretrained_backbone , trainable_backbone_layers , 5 , 3 )
359
357
360
358
if pretrained :
361
359
# no need to download the backbone if pretrained is set
@@ -368,3 +366,48 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
368
366
model .load_state_dict (state_dict )
369
367
overwrite_eps (model , 0.0 )
370
368
return model
369
+
370
+
371
+ def fasterrcnn_mobilenet_v3_large_fpn (pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True ,
372
+ trainable_backbone_layers = None , min_size = 320 , max_size = 640 , rpn_score_thresh = 0.05 ,
373
+ ** kwargs ):
374
+ """
375
+ Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
376
+ to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
377
+
378
+ Example::
379
+
380
+ >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
381
+ >>> model.eval()
382
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
383
+ >>> predictions = model(x)
384
+
385
+ Args:
386
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017
387
+ progress (bool): If True, displays a progress bar of the download to stderr
388
+ num_classes (int): number of output classes of the model (including the background)
389
+ pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
390
+ trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
391
+ Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
392
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
393
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
394
+ rpn_score_thresh (float): during inference, only return proposals with a classification score
395
+ greater than rpn_score_thresh
396
+ """
397
+ trainable_backbone_layers = _validate_trainable_layers (
398
+ pretrained or pretrained_backbone , trainable_backbone_layers , 6 , 3 )
399
+
400
+ if pretrained :
401
+ pretrained_backbone = False
402
+ backbone = mobilenet_backbone ("mobilenet_v3_large" , pretrained_backbone , True ,
403
+ trainable_layers = trainable_backbone_layers )
404
+
405
+ anchor_sizes = ((32 , 64 , 128 , 256 , 512 , ), ) * 3
406
+ aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
407
+
408
+ model = FasterRCNN (backbone , num_classes , rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ),
409
+ min_size = min_size , max_size = max_size , rpn_score_thresh = rpn_score_thresh , ** kwargs )
410
+ if pretrained :
411
+ state_dict = load_state_dict_from_url (model_urls ['fasterrcnn_mobilenet_v3_large_fpn_coco' ], progress = progress )
412
+ model .load_state_dict (state_dict )
413
+ return model
0 commit comments