1+ from torch import nn
2+ from typing import Any , Optional
13from .._utils import IntermediateLayerGetter
24from ..._internally_replaced_utils import load_state_dict_from_url
35from .. import mobilenetv3
2224}
2325
2426
25- def _segm_model (name , backbone_name , num_classes , aux , pretrained_backbone = True ):
27+ def _segm_model (
28+ name : str ,
29+ backbone_name : str ,
30+ num_classes : int ,
31+ aux : Optional [bool ],
32+ pretrained_backbone : bool = True
33+ ) -> nn .Module :
2634 if 'resnet' in backbone_name :
2735 backbone = resnet .__dict__ [backbone_name ](
2836 pretrained = pretrained_backbone ,
@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
6674 return model
6775
6876
69- def _load_model (arch_type , backbone , pretrained , progress , num_classes , aux_loss , ** kwargs ):
77+ def _load_model (
78+ arch_type : str ,
79+ backbone : str ,
80+ pretrained : bool ,
81+ progress : bool ,
82+ num_classes : int ,
83+ aux_loss : Optional [bool ],
84+ ** kwargs : Any
85+ ) -> nn .Module :
7086 if pretrained :
7187 aux_loss = True
7288 kwargs ["pretrained_backbone" ] = False
@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
7692 return model
7793
7894
79- def _load_weights (model , arch_type , backbone , progress ) :
95+ def _load_weights (model : nn . Module , arch_type : str , backbone : str , progress : bool ) -> None :
8096 arch = arch_type + '_' + backbone + '_coco'
8197 model_url = model_urls .get (arch , None )
8298 if model_url is None :
@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
86102 model .load_state_dict (state_dict )
87103
88104
89- def _segm_lraspp_mobilenetv3 (backbone_name , num_classes , pretrained_backbone = True ):
105+ def _segm_lraspp_mobilenetv3 (backbone_name : str , num_classes : int , pretrained_backbone : bool = True ) -> LRASPP :
90106 backbone = mobilenetv3 .__dict__ [backbone_name ](pretrained = pretrained_backbone , dilated = True ).features
91107
92108 # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru
103119 return model
104120
105121
106- def fcn_resnet50 (pretrained = False , progress = True ,
107- num_classes = 21 , aux_loss = None , ** kwargs ):
122+ def fcn_resnet50 (
123+ pretrained : bool = False ,
124+ progress : bool = True ,
125+ num_classes : int = 21 ,
126+ aux_loss : Optional [bool ] = None ,
127+ ** kwargs : Any
128+ ) -> nn .Module :
108129 """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
109130
110131 Args:
@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
117138 return _load_model ('fcn' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
118139
119140
120- def fcn_resnet101 (pretrained = False , progress = True ,
121- num_classes = 21 , aux_loss = None , ** kwargs ):
141+ def fcn_resnet101 (
142+ pretrained : bool = False ,
143+ progress : bool = True ,
144+ num_classes : int = 21 ,
145+ aux_loss : Optional [bool ] = None ,
146+ ** kwargs : Any
147+ ) -> nn .Module :
122148 """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
123149
124150 Args:
@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
131157 return _load_model ('fcn' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
132158
133159
134- def deeplabv3_resnet50 (pretrained = False , progress = True ,
135- num_classes = 21 , aux_loss = None , ** kwargs ):
160+ def deeplabv3_resnet50 (
161+ pretrained : bool = False ,
162+ progress : bool = True ,
163+ num_classes : int = 21 ,
164+ aux_loss : Optional [bool ] = None ,
165+ ** kwargs : Any
166+ ) -> nn .Module :
136167 """Constructs a DeepLabV3 model with a ResNet-50 backbone.
137168
138169 Args:
@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
145176 return _load_model ('deeplabv3' , 'resnet50' , pretrained , progress , num_classes , aux_loss , ** kwargs )
146177
147178
148- def deeplabv3_resnet101 (pretrained = False , progress = True ,
149- num_classes = 21 , aux_loss = None , ** kwargs ):
179+ def deeplabv3_resnet101 (
180+ pretrained : bool = False ,
181+ progress : bool = True ,
182+ num_classes : int = 21 ,
183+ aux_loss : Optional [bool ] = None ,
184+ ** kwargs : Any
185+ ) -> nn .Module :
150186 """Constructs a DeepLabV3 model with a ResNet-101 backbone.
151187
152188 Args:
@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
159195 return _load_model ('deeplabv3' , 'resnet101' , pretrained , progress , num_classes , aux_loss , ** kwargs )
160196
161197
162- def deeplabv3_mobilenet_v3_large (pretrained = False , progress = True ,
163- num_classes = 21 , aux_loss = None , ** kwargs ):
198+ def deeplabv3_mobilenet_v3_large (
199+ pretrained : bool = False ,
200+ progress : bool = True ,
201+ num_classes : int = 21 ,
202+ aux_loss : Optional [bool ] = None ,
203+ ** kwargs : Any
204+ ) -> nn .Module :
164205 """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
165206
166207 Args:
@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
173214 return _load_model ('deeplabv3' , 'mobilenet_v3_large' , pretrained , progress , num_classes , aux_loss , ** kwargs )
174215
175216
176- def lraspp_mobilenet_v3_large (pretrained = False , progress = True , num_classes = 21 , ** kwargs ):
217+ def lraspp_mobilenet_v3_large (
218+ pretrained : bool = False ,
219+ progress : bool = True ,
220+ num_classes : int = 21 ,
221+ ** kwargs : Any
222+ ) -> nn .Module :
177223 """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
178224
179225 Args:
0 commit comments