11import warnings
2- from typing import List , Optional
2+ from typing import Callable , Dict , Optional , List
33
4- from torch import nn
4+ from torch import nn , Tensor
55from torchvision .ops import misc as misc_nn_ops
66from torchvision .ops .feature_pyramid_network import FeaturePyramidNetwork , LastLevelMaxPool , ExtraFPNBlock
77
@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
2929 out_channels (int): the number of channels in the FPN
3030 """
3131
32- def __init__ (self , backbone , return_layers , in_channels_list , out_channels , extra_blocks = None ):
32+ def __init__ (
33+ self ,
34+ backbone : nn .Module ,
35+ return_layers : Dict [str , str ],
36+ in_channels_list : List [int ],
37+ out_channels : int ,
38+ extra_blocks : Optional [ExtraFPNBlock ] = None ,
39+ ) -> None :
3340 super (BackboneWithFPN , self ).__init__ ()
3441
3542 if extra_blocks is None :
@@ -43,20 +50,20 @@ def __init__(self, backbone, return_layers, in_channels_list, out_channels, extr
4350 )
4451 self .out_channels = out_channels
4552
46- def forward (self , x ) :
53+ def forward (self , x : Tensor ) -> Dict [ str , Tensor ] :
4754 x = self .body (x )
4855 x = self .fpn (x )
4956 return x
5057
5158
5259def resnet_fpn_backbone (
53- backbone_name ,
54- pretrained ,
55- norm_layer = misc_nn_ops .FrozenBatchNorm2d ,
56- trainable_layers = 3 ,
57- returned_layers = None ,
58- extra_blocks = None ,
59- ):
60+ backbone_name : str ,
61+ pretrained : bool ,
62+ norm_layer : Callable [..., nn . Module ] = misc_nn_ops .FrozenBatchNorm2d ,
63+ trainable_layers : int = 3 ,
64+ returned_layers : Optional [ List [ int ]] = None ,
65+ extra_blocks : Optional [ ExtraFPNBlock ] = None ,
66+ ) -> BackboneWithFPN :
6067 """
6168 Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
6269
@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
8087 backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
8188 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
8289 pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
83- norm_layer (torchvision.ops ): it is recommended to use the default value. For details visit:
90+ norm_layer (callable ): it is recommended to use the default value. For details visit:
8491 (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
8592 trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
8693 Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
101108 trainable_layers : int ,
102109 returned_layers : Optional [List [int ]],
103110 extra_blocks : Optional [ExtraFPNBlock ],
104- ):
111+ ) -> BackboneWithFPN :
112+
105113 # select layers that wont be frozen
106114 assert 0 <= trainable_layers <= 5
107115 layers_to_train = ["layer4" , "layer3" , "layer2" , "layer1" , "conv1" ][:trainable_layers ]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
125133 return BackboneWithFPN (backbone , return_layers , in_channels_list , out_channels , extra_blocks = extra_blocks )
126134
127135
128- def _validate_trainable_layers (pretrained , trainable_backbone_layers , max_value , default_value ):
129- # dont freeze any layers if pretrained model or backbone is not used
136+ def _validate_trainable_layers (
137+ pretrained : bool ,
138+ trainable_backbone_layers : Optional [int ],
139+ max_value : int ,
140+ default_value : int ,
141+ ) -> int :
142+ # don't freeze any layers if pretrained model or backbone is not used
130143 if not pretrained :
131144 if trainable_backbone_layers is not None :
132145 warnings .warn (
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
144157
145158
146159def mobilenet_backbone (
147- backbone_name ,
148- pretrained ,
149- fpn ,
150- norm_layer = misc_nn_ops .FrozenBatchNorm2d ,
151- trainable_layers = 2 ,
152- returned_layers = None ,
153- extra_blocks = None ,
154- ):
160+ backbone_name : str ,
161+ pretrained : bool ,
162+ fpn : bool ,
163+ norm_layer : Callable [..., nn .Module ] = misc_nn_ops .FrozenBatchNorm2d ,
164+ trainable_layers : int = 2 ,
165+ returned_layers : Optional [List [int ]] = None ,
166+ extra_blocks : Optional [ExtraFPNBlock ] = None ,
167+ ) -> nn .Module :
168+
155169 backbone = mobilenet .__dict__ [backbone_name ](pretrained = pretrained , norm_layer = norm_layer ).features
156170
157171 # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
185199 # depthwise linear combination of channels to reduce their size
186200 nn .Conv2d (backbone [- 1 ].out_channels , out_channels , 1 ),
187201 )
188- m .out_channels = out_channels
202+ m .out_channels = out_channels # type: ignore[assignment]
189203 return m
0 commit comments