@@ -276,12 +276,13 @@ class _BlockBuilder:
276276 """
277277
278278 def __init__ (self , channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
279- act_fn = None , se_gate_fn = torch .sigmoid , se_reduce_mid = False ,
279+ drop_connect_rate = 0. , act_fn = None , se_gate_fn = torch .sigmoid , se_reduce_mid = False ,
280280 bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
281281 folded_bn = False , padding_same = False , verbose = False ):
282282 self .channel_multiplier = channel_multiplier
283283 self .channel_divisor = channel_divisor
284284 self .channel_min = channel_min
285+ self .drop_connect_rate = drop_connect_rate
285286 self .act_fn = act_fn
286287 self .se_gate_fn = se_gate_fn
287288 self .se_reduce_mid = se_reduce_mid
@@ -310,10 +311,12 @@ def _make_block(self, ba):
310311 print ('args:' , ba )
311312 # could replace this if with lambdas or functools binding if variety increases
312313 if bt == 'ir' :
314+ ba ['drop_connect_rate' ] = self .drop_connect_rate
313315 ba ['se_gate_fn' ] = self .se_gate_fn
314316 ba ['se_reduce_mid' ] = self .se_reduce_mid
315317 block = InvertedResidual (** ba )
316318 elif bt == 'ds' or bt == 'dsa' :
319+ ba ['drop_connect_rate' ] = self .drop_connect_rate
317320 block = DepthwiseSeparableConv (** ba )
318321 elif bt == 'ca' :
319322 block = CascadeConv (** ba )
@@ -402,6 +405,19 @@ def hard_sigmoid(x):
402405 return F .relu6 (x + 3. ) / 6.
403406
404407
408+ def drop_connect (inputs , training = False , drop_connect_rate = 0. ):
409+ """Apply drop connect."""
410+ if not training :
411+ return inputs
412+
413+ keep_prob = 1 - drop_connect_rate
414+ random_tensor = keep_prob + torch .rand (
415+ (inputs .size ()[0 ], 1 , 1 , 1 ), dtype = inputs .dtype , device = inputs .device )
416+ random_tensor .floor_ () # binarize
417+ output = inputs .div (keep_prob ) * random_tensor
418+ return output
419+
420+
405421class ChannelShuffle (nn .Module ):
406422 # FIXME haven't used yet
407423 def __init__ (self , groups ):
@@ -474,13 +490,14 @@ def __init__(self, in_chs, out_chs, kernel_size,
474490 stride = 1 , act_fn = F .relu , noskip = False , pw_act = False ,
475491 se_ratio = 0. , se_gate_fn = torch .sigmoid ,
476492 bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
477- folded_bn = False , padding_same = False ):
493+ folded_bn = False , padding_same = False , drop_connect_rate = 0. ):
478494 super (DepthwiseSeparableConv , self ).__init__ ()
479495 assert stride in [1 , 2 ]
480496 self .has_se = se_ratio is not None and se_ratio > 0.
481497 self .has_residual = (stride == 1 and in_chs == out_chs ) and not noskip
482498 self .has_pw_act = pw_act # activation after point-wise conv
483499 self .act_fn = act_fn
500+ self .drop_connect_rate = drop_connect_rate
484501 dw_padding = _padding_arg (kernel_size // 2 , padding_same )
485502 pw_padding = _padding_arg (0 , padding_same )
486503
@@ -515,7 +532,9 @@ def forward(self, x):
515532 x = self .act_fn (x )
516533
517534 if self .has_residual :
518- x += residual # FIXME add drop-connect
535+ if self .drop_connect_rate > 0. :
536+ x = drop_connect (x , self .training , self .drop_connect_rate )
537+ x += residual
519538 return x
520539
521540
@@ -557,12 +576,13 @@ def __init__(self, in_chs, out_chs, kernel_size,
557576 se_ratio = 0. , se_reduce_mid = False , se_gate_fn = torch .sigmoid ,
558577 shuffle_type = None , pw_group = 1 ,
559578 bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
560- folded_bn = False , padding_same = False ):
579+ folded_bn = False , padding_same = False , drop_connect_rate = 0. ):
561580 super (InvertedResidual , self ).__init__ ()
562581 mid_chs = int (in_chs * exp_ratio )
563582 self .has_se = se_ratio is not None and se_ratio > 0.
564583 self .has_residual = (in_chs == out_chs and stride == 1 ) and not noskip
565584 self .act_fn = act_fn
585+ self .drop_connect_rate = drop_connect_rate
566586 dw_padding = _padding_arg (kernel_size // 2 , padding_same )
567587 pw_padding = _padding_arg (0 , padding_same )
568588
@@ -619,7 +639,9 @@ def forward(self, x):
619639 x = self .bn3 (x )
620640
621641 if self .has_residual :
622- x += residual # FIXME add drop-connect
642+ if self .drop_connect_rate > 0. :
643+ x = drop_connect (x , self .training , self .drop_connect_rate )
644+ x += residual
623645
624646 # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
625647
@@ -643,12 +665,14 @@ class GenMobileNet(nn.Module):
643665 def __init__ (self , block_args , num_classes = 1000 , in_chans = 3 , stem_size = 32 , num_features = 1280 ,
644666 channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
645667 bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
646- drop_rate = 0. , act_fn = F .relu , se_gate_fn = torch .sigmoid , se_reduce_mid = False ,
668+ drop_rate = 0. , drop_connect_rate = 0. , act_fn = F .relu ,
669+ se_gate_fn = torch .sigmoid , se_reduce_mid = False ,
647670 global_pool = 'avg' , head_conv = 'default' , weight_init = 'goog' ,
648- folded_bn = False , padding_same = False ):
671+ folded_bn = False , padding_same = False , ):
649672 super (GenMobileNet , self ).__init__ ()
650673 self .num_classes = num_classes
651674 self .drop_rate = drop_rate
675+ self .drop_connect_rate = drop_connect_rate
652676 self .act_fn = act_fn
653677 self .num_features = num_features
654678
@@ -661,7 +685,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
661685
662686 builder = _BlockBuilder (
663687 channel_multiplier , channel_divisor , channel_min ,
664- act_fn , se_gate_fn , se_reduce_mid ,
688+ drop_connect_rate , act_fn , se_gate_fn , se_reduce_mid ,
665689 bn_momentum , bn_eps , folded_bn , padding_same , verbose = _DEBUG )
666690 self .blocks = nn .Sequential (* builder (in_chs , block_args ))
667691 in_chs = builder .in_chs
@@ -1090,7 +1114,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
10901114
10911115
10921116def _gen_efficientnet (channel_multiplier = 1.0 , depth_multiplier = 1.0 , num_classes = 1000 , ** kwargs ):
1093- """Creates a MobileNet-V3 model.
1117+ """Creates an EfficientNet model.
10941118
10951119 Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
10961120 Paper: https://arxiv.org/abs/1905.11946
@@ -1347,7 +1371,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
13471371def efficientnet_b0 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
13481372 """ EfficientNet """
13491373 default_cfg = default_cfgs ['efficientnet_b0' ]
1350- # NOTE dropout should be 0.2 for train
1374+ # NOTE for train, drop_rate should be 0.2
13511375 model = _gen_efficientnet (
13521376 channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
13531377 num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -1360,7 +1384,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
13601384def efficientnet_b1 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
13611385 """ EfficientNet """
13621386 default_cfg = default_cfgs ['efficientnet_b1' ]
1363- # NOTE dropout should be 0.2 for train
1387+ # NOTE for train, drop_rate should be 0.2
13641388 model = _gen_efficientnet (
13651389 channel_multiplier = 1.0 , depth_multiplier = 1.1 ,
13661390 num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -1373,7 +1397,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
13731397def efficientnet_b2 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
13741398 """ EfficientNet """
13751399 default_cfg = default_cfgs ['efficientnet_b2' ]
1376- # NOTE dropout should be 0.3 for train
1400+ # NOTE for train, drop_rate should be 0.3
13771401 model = _gen_efficientnet (
13781402 channel_multiplier = 1.1 , depth_multiplier = 1.2 ,
13791403 num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -1386,7 +1410,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
13861410def efficientnet_b3 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
13871411 """ EfficientNet """
13881412 default_cfg = default_cfgs ['efficientnet_b3' ]
1389- # NOTE dropout should be 0.3 for train
1413+ # NOTE for train, drop_rate should be 0.3
13901414 model = _gen_efficientnet (
13911415 channel_multiplier = 1.2 , depth_multiplier = 1.4 ,
13921416 num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -1399,7 +1423,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
13991423def efficientnet_b4 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
14001424 """ EfficientNet """
14011425 default_cfg = default_cfgs ['efficientnet_b4' ]
1402- # NOTE dropout should be 0.4 for train
1426+ # NOTE for train, drop_rate should be 0.4
14031427 model = _gen_efficientnet (
14041428 channel_multiplier = 1.4 , depth_multiplier = 1.8 ,
14051429 num_classes = num_classes , in_chans = in_chans , ** kwargs )
0 commit comments