1+ from torch import Tensor
12import torch .nn as nn
3+ from typing import Tuple , Optional , Callable , List , Type , Any , Union
24
35from ..._internally_replaced_utils import load_state_dict_from_url
46
1315
1416
1517class Conv3DSimple (nn .Conv3d ):
16- def __init__ (self ,
17- in_planes ,
18- out_planes ,
19- midplanes = None ,
20- stride = 1 ,
21- padding = 1 ):
18+ def __init__ (
19+ self ,
20+ in_planes : int ,
21+ out_planes : int ,
22+ midplanes : Optional [int ] = None ,
23+ stride : int = 1 ,
24+ padding : int = 1
25+ ) -> None :
2226
2327 super (Conv3DSimple , self ).__init__ (
2428 in_channels = in_planes ,
@@ -29,18 +33,20 @@ def __init__(self,
2933 bias = False )
3034
3135 @staticmethod
32- def get_downsample_stride (stride ) :
36+ def get_downsample_stride (stride : int ) -> Tuple [ int , int , int ] :
3337 return stride , stride , stride
3438
3539
3640class Conv2Plus1D (nn .Sequential ):
3741
38- def __init__ (self ,
39- in_planes ,
40- out_planes ,
41- midplanes ,
42- stride = 1 ,
43- padding = 1 ):
42+ def __init__ (
43+ self ,
44+ in_planes : int ,
45+ out_planes : int ,
46+ midplanes : int ,
47+ stride : int = 1 ,
48+ padding : int = 1
49+ ) -> None :
4450 super (Conv2Plus1D , self ).__init__ (
4551 nn .Conv3d (in_planes , midplanes , kernel_size = (1 , 3 , 3 ),
4652 stride = (1 , stride , stride ), padding = (0 , padding , padding ),
@@ -52,18 +58,20 @@ def __init__(self,
5258 bias = False ))
5359
5460 @staticmethod
55- def get_downsample_stride (stride ) :
61+ def get_downsample_stride (stride : int ) -> Tuple [ int , int , int ] :
5662 return stride , stride , stride
5763
5864
5965class Conv3DNoTemporal (nn .Conv3d ):
6066
61- def __init__ (self ,
62- in_planes ,
63- out_planes ,
64- midplanes = None ,
65- stride = 1 ,
66- padding = 1 ):
67+ def __init__ (
68+ self ,
69+ in_planes : int ,
70+ out_planes : int ,
71+ midplanes : Optional [int ] = None ,
72+ stride : int = 1 ,
73+ padding : int = 1
74+ ) -> None :
6775
6876 super (Conv3DNoTemporal , self ).__init__ (
6977 in_channels = in_planes ,
@@ -74,15 +82,22 @@ def __init__(self,
7482 bias = False )
7583
7684 @staticmethod
77- def get_downsample_stride (stride ) :
85+ def get_downsample_stride (stride : int ) -> Tuple [ int , int , int ] :
7886 return 1 , stride , stride
7987
8088
8189class BasicBlock (nn .Module ):
8290
8391 expansion = 1
8492
85- def __init__ (self , inplanes , planes , conv_builder , stride = 1 , downsample = None ):
93+ def __init__ (
94+ self ,
95+ inplanes : int ,
96+ planes : int ,
97+ conv_builder : Callable [..., nn .Module ],
98+ stride : int = 1 ,
99+ downsample : Optional [nn .Module ] = None ,
100+ ) -> None :
86101 midplanes = (inplanes * planes * 3 * 3 * 3 ) // (inplanes * 3 * 3 + 3 * planes )
87102
88103 super (BasicBlock , self ).__init__ ()
@@ -99,7 +114,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
99114 self .downsample = downsample
100115 self .stride = stride
101116
102- def forward (self , x ) :
117+ def forward (self , x : Tensor ) -> Tensor :
103118 residual = x
104119
105120 out = self .conv1 (x )
@@ -116,7 +131,14 @@ def forward(self, x):
116131class Bottleneck (nn .Module ):
117132 expansion = 4
118133
119- def __init__ (self , inplanes , planes , conv_builder , stride = 1 , downsample = None ):
134+ def __init__ (
135+ self ,
136+ inplanes : int ,
137+ planes : int ,
138+ conv_builder : Callable [..., nn .Module ],
139+ stride : int = 1 ,
140+ downsample : Optional [nn .Module ] = None ,
141+ ) -> None :
120142
121143 super (Bottleneck , self ).__init__ ()
122144 midplanes = (inplanes * planes * 3 * 3 * 3 ) // (inplanes * 3 * 3 + 3 * planes )
@@ -143,7 +165,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
143165 self .downsample = downsample
144166 self .stride = stride
145167
146- def forward (self , x ) :
168+ def forward (self , x : Tensor ) -> Tensor :
147169 residual = x
148170
149171 out = self .conv1 (x )
@@ -162,7 +184,7 @@ def forward(self, x):
162184class BasicStem (nn .Sequential ):
163185 """The default conv-batchnorm-relu stem
164186 """
165- def __init__ (self ):
187+ def __init__ (self ) -> None :
166188 super (BasicStem , self ).__init__ (
167189 nn .Conv3d (3 , 64 , kernel_size = (3 , 7 , 7 ), stride = (1 , 2 , 2 ),
168190 padding = (1 , 3 , 3 ), bias = False ),
@@ -173,7 +195,7 @@ def __init__(self):
173195class R2Plus1dStem (nn .Sequential ):
174196 """R(2+1)D stem is different than the default one as it uses separated 3D convolution
175197 """
176- def __init__ (self ):
198+ def __init__ (self ) -> None :
177199 super (R2Plus1dStem , self ).__init__ (
178200 nn .Conv3d (3 , 45 , kernel_size = (1 , 7 , 7 ),
179201 stride = (1 , 2 , 2 ), padding = (0 , 3 , 3 ),
@@ -189,16 +211,23 @@ def __init__(self):
189211
190212class VideoResNet (nn .Module ):
191213
192- def __init__ (self , block , conv_makers , layers ,
193- stem , num_classes = 400 ,
194- zero_init_residual = False ):
214+ def __init__ (
215+ self ,
216+ block : Type [Union [BasicBlock , Bottleneck ]],
217+ conv_makers : List [Type [Union [Conv3DSimple , Conv3DNoTemporal , Conv2Plus1D ]]],
218+ layers : List [int ],
219+ stem : Callable [..., nn .Module ],
220+ num_classes : int = 400 ,
221+ zero_init_residual : bool = False ,
222+ ) -> None :
195223 """Generic resnet video generator.
196224
197225 Args:
198- block (nn.Module): resnet building block
199- conv_makers (list(functions)): generator function for each layer
226+ block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
227+ conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
228+ function for each layer
200229 layers (List[int]): number of blocks per layer
201- stem (nn.Module, optional ): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None .
230+ stem (Callable[..., nn.Module] ): module specifying the ResNet stem .
202231 num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
203232 zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
204233 """
@@ -221,9 +250,9 @@ def __init__(self, block, conv_makers, layers,
221250 if zero_init_residual :
222251 for m in self .modules ():
223252 if isinstance (m , Bottleneck ):
224- nn .init .constant_ (m .bn3 .weight , 0 )
253+ nn .init .constant_ (m .bn3 .weight , 0 ) # type: ignore[union-attr, arg-type]
225254
226- def forward (self , x ) :
255+ def forward (self , x : Tensor ) -> Tensor :
227256 x = self .stem (x )
228257
229258 x = self .layer1 (x )
@@ -238,7 +267,14 @@ def forward(self, x):
238267
239268 return x
240269
241- def _make_layer (self , block , conv_builder , planes , blocks , stride = 1 ):
270+ def _make_layer (
271+ self ,
272+ block : Type [Union [BasicBlock , Bottleneck ]],
273+ conv_builder : Type [Union [Conv3DSimple , Conv3DNoTemporal , Conv2Plus1D ]],
274+ planes : int ,
275+ blocks : int ,
276+ stride : int = 1
277+ ) -> nn .Sequential :
242278 downsample = None
243279
244280 if stride != 1 or self .inplanes != planes * block .expansion :
@@ -257,7 +293,7 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
257293
258294 return nn .Sequential (* layers )
259295
260- def _initialize_weights (self ):
296+ def _initialize_weights (self ) -> None :
261297 for m in self .modules ():
262298 if isinstance (m , nn .Conv3d ):
263299 nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' ,
@@ -272,7 +308,7 @@ def _initialize_weights(self):
272308 nn .init .constant_ (m .bias , 0 )
273309
274310
275- def _video_resnet (arch , pretrained = False , progress = True , ** kwargs ) :
311+ def _video_resnet (arch : str , pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VideoResNet :
276312 model = VideoResNet (** kwargs )
277313
278314 if pretrained :
@@ -282,7 +318,7 @@ def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
282318 return model
283319
284320
285- def r3d_18 (pretrained = False , progress = True , ** kwargs ) :
321+ def r3d_18 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VideoResNet :
286322 """Construct 18 layer Resnet3D model as in
287323 https://arxiv.org/abs/1711.11248
288324
@@ -302,7 +338,7 @@ def r3d_18(pretrained=False, progress=True, **kwargs):
302338 stem = BasicStem , ** kwargs )
303339
304340
305- def mc3_18 (pretrained = False , progress = True , ** kwargs ) :
341+ def mc3_18 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VideoResNet :
306342 """Constructor for 18 layer Mixed Convolution network as in
307343 https://arxiv.org/abs/1711.11248
308344
@@ -316,12 +352,12 @@ def mc3_18(pretrained=False, progress=True, **kwargs):
316352 return _video_resnet ('mc3_18' ,
317353 pretrained , progress ,
318354 block = BasicBlock ,
319- conv_makers = [Conv3DSimple ] + [Conv3DNoTemporal ] * 3 ,
355+ conv_makers = [Conv3DSimple ] + [Conv3DNoTemporal ] * 3 , # type: ignore[list-item]
320356 layers = [2 , 2 , 2 , 2 ],
321357 stem = BasicStem , ** kwargs )
322358
323359
324- def r2plus1d_18 (pretrained = False , progress = True , ** kwargs ) :
360+ def r2plus1d_18 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> VideoResNet :
325361 """Constructor for the 18 layer deep R(2+1)D network as in
326362 https://arxiv.org/abs/1711.11248
327363
0 commit comments