"""PyTorch ResNet

This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.

ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
Copyright 2020 Ross Wightman
"""
@@ -442,7 +440,7 @@ def drop_blocks(drop_block_rate=0.):
442440
443441def make_blocks (
444442 block_fn , channels , block_repeats , inplanes , reduce_first = 1 , output_stride = 32 ,
445- down_kernel_size = 1 , avg_down = False , drop_block_rate = 0. , drop_path_rate = 0. , first_conv_stride = 1 , ** kwargs ):
443+ down_kernel_size = 1 , avg_down = False , drop_block_rate = 0. , drop_path_rate = 0. , ** kwargs ):
446444 stages = []
447445 feature_info = []
448446 net_num_blocks = sum (block_repeats )
@@ -451,7 +449,7 @@ def make_blocks(
451449 dilation = prev_dilation = 1
452450 for stage_idx , (planes , num_blocks , db ) in enumerate (zip (channels , block_repeats , drop_blocks (drop_block_rate ))):
453451 stage_name = f'layer{ stage_idx + 1 } ' # never liked this name, but weight compat requires it
454- stride = first_conv_stride if stage_idx == 0 else 2
452+ stride = 1 if stage_idx == 0 else 2
455453 if net_stride >= output_stride :
456454 dilation *= stride
457455 stride = 1
@@ -494,7 +492,7 @@ class ResNet(nn.Module):
494492 This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
495493 variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
496494 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
497-
495+
498496 ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
499497 * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
500498 * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
@@ -503,18 +501,18 @@ class ResNet(nn.Module):
503501 * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
504502 * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
505503 * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
506-
504+
507505 ResNeXt
508506 * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
509507 * same c,d, e, s variants as ResNet can be enabled
510-
508+
511509 SE-ResNeXt
512510 * normal - 7x7 stem, stem_width = 64
513511 * same c, d, e, s variants as ResNet can be enabled
514-
512+
515513 SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
516514 reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
517-
515+
518516 Parameters
519517 ----------
520518 block : Block
@@ -558,12 +556,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
558556 cardinality = 1 , base_width = 64 , stem_width = 64 , stem_type = '' ,
559557 output_stride = 32 , block_reduce_first = 1 , down_kernel_size = 1 , avg_down = False ,
560558 act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , aa_layer = None , drop_rate = 0.0 , drop_path_rate = 0. ,
561- drop_block_rate = 0. , global_pool = 'avg' , zero_init_last_bn = True , block_args = None , skip_stem_max_pool = False ):
559+ drop_block_rate = 0. , global_pool = 'avg' , zero_init_last_bn = True , block_args = None , replace_stem_max_pool = False ):
562560 block_args = block_args or dict ()
563561 assert output_stride in (8 , 16 , 32 )
564562 self .num_classes = num_classes
565563 self .drop_rate = drop_rate
566- self .skip_stem_max_pool = skip_stem_max_pool
564+ self .replace_stem_max_pool = replace_stem_max_pool
567565 super (ResNet , self ).__init__ ()
568566
569567 # Stem
@@ -588,25 +586,27 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
588586 self .feature_info = [dict (num_chs = inplanes , reduction = 2 , module = 'act1' )]
589587
590588 # Stem Pooling
591- if not self .skip_stem_max_pool :
592- first_conv_stride = 1
589+ if not self .replace_stem_max_pool :
593590 if aa_layer is not None :
594591 self .maxpool = nn .Sequential (* [
595592 nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 ),
596593 aa_layer (channels = inplanes , stride = 2 )])
597594 else :
598595 self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
599596 else :
600- self .maxpool = nn .Identity ()
601- first_conv_stride = 2
597+ self .maxpool = nn .Sequential (* [
598+ nn .Conv2d (inplanes , inplanes , 3 , stride = 2 , padding = 1 ),
599+ nn .BatchNorm2d (inplanes ),
600+ nn .ReLU ()
601+ ])
602602
603603 # Feature Blocks
604604 channels = [64 , 128 , 256 , 512 ]
605605 stage_modules , stage_feature_info = make_blocks (
606606 block , channels , layers , inplanes , cardinality = cardinality , base_width = base_width ,
607607 output_stride = output_stride , reduce_first = block_reduce_first , avg_down = avg_down ,
608608 down_kernel_size = down_kernel_size , act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer ,
609- drop_block_rate = drop_block_rate , drop_path_rate = drop_path_rate , first_conv_stride = first_conv_stride , ** block_args )
609+ drop_block_rate = drop_block_rate , drop_path_rate = drop_path_rate , ** block_args )
610610 for stage in stage_modules :
611611 self .add_module (* stage ) # layer1, layer2, etc
612612 self .feature_info .extend (stage_feature_info )
@@ -1078,39 +1078,39 @@ def ecaresnet50d(pretrained=False, **kwargs):
@register_model
def resnetrs50(pretrained=False, **kwargs):
    """Build a ResNet-RS-50 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs50', pretrained, **model_args)
10841084
10851085
@register_model
def resnetrs101(pretrained=False, **kwargs):
    """Build a ResNet-RS-101 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs101', pretrained, **model_args)
10921092
10931093
@register_model
def resnetrs152(pretrained=False, **kwargs):
    """Build a ResNet-RS-152 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs152', pretrained, **model_args)
11001100
11011101
@register_model
def resnetrs200(pretrained=False, **kwargs):
    """Build a ResNet-RS-200 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs200', pretrained, **model_args)
11081108
11091109
@register_model
def resnetrs270(pretrained=False, **kwargs):
    """Build a ResNet-RS-270 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs270', pretrained, **model_args)
11161116
@@ -1119,15 +1119,15 @@ def resnetrs270(pretrained=False, **kwargs):
@register_model
def resnetrs350(pretrained=False, **kwargs):
    """Build a ResNet-RS-350 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs350', pretrained, **model_args)
11251125
11261126
@register_model
def resnetrs420(pretrained=False, **kwargs):
    """Build a ResNet-RS-420 variant: deep stem (stem_width=32), avg-down shortcuts,
    SE attention blocks, and a strided-conv stem pool (replace_stem_max_pool=True).
    """
    model_args = dict(
        block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep',
        replace_stem_max_pool=True, avg_down=True,
        block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs420', pretrained, **model_args)
11331133
@@ -1373,4 +1373,4 @@ def senet154(pretrained=False, **kwargs):
13731373 model_args = dict (
13741374 block = Bottleneck , layers = [3 , 8 , 36 , 3 ], cardinality = 64 , base_width = 4 , stem_type = 'deep' ,
13751375 down_kernel_size = 3 , block_reduce_first = 2 , block_args = dict (attn_layer = 'se' ), ** kwargs )
1376- return _create_resnet ('senet154' , pretrained , ** model_args )
1376+ return _create_resnet ('senet154' , pretrained , ** model_args )
0 commit comments