
Commit 7b83e67

Pass drop connect arg through to EfficientNet models
1 parent 31453b0 commit 7b83e67

2 files changed: +8, -4 lines

timm/models/factory.py

Lines changed: 4 additions & 3 deletions
@@ -25,12 +25,13 @@ def create_model(
     """
     margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)

-    # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
-    supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
-    if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
+    # Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args
+    is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet'])
+    if not is_efficientnet:
         kwargs.pop('bn_tf', None)
         kwargs.pop('bn_momentum', None)
         kwargs.pop('bn_eps', None)
+        kwargs.pop('drop_connect_rate', None)

     if is_model(model_name):
         create_fn = model_entrypoint(model_name)
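With this change, a drop_connect_rate keyword passed to create_model is forwarded only when the model resolves to a gen_efficientnet variant; for every other architecture it is popped along with the batchnorm overrides rather than reaching an entrypoint that would reject it. A minimal sketch of the resulting call pattern, assuming the import path and model names registered in the repo at this commit:

from timm.models import create_model

# gen_efficientnet variant: drop_connect_rate is forwarded to the model constructor
model = create_model('efficientnet_b0', pretrained=False, drop_connect_rate=0.2)

# non-EfficientNet model: create_model pops the kwarg, so the call still succeeds
model = create_model('resnet50', pretrained=False, drop_connect_rate=0.2)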

train.py

Lines changed: 4 additions & 1 deletion
@@ -65,6 +65,8 @@
                     help='input batch size for training (default: 32)')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.)')
+parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
+                    help='Drop connect rate (default: 0.)')
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -208,6 +210,7 @@ def main():
         pretrained=args.pretrained,
         num_classes=args.num_classes,
         drop_rate=args.drop,
+        drop_connect_rate=args.drop_connect,
         global_pool=args.gp,
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
@@ -253,7 +256,7 @@ def main():
             if args.local_rank == 0:
                 logging.info('Restoring NVIDIA AMP state from checkpoint')
             amp.load_state_dict(resume_state['amp'])
-    resume_state = None
+    resume_state = None  # clear it

     model_ema = None
     if args.model_ema:
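The new --drop-connect flag simply feeds args.drop_connect through create_model, so a run could be launched with e.g. --drop-connect 0.2 for an EfficientNet model. For context, drop connect (per-sample stochastic depth) randomly zeroes a block's residual branch for some samples in the batch during training and rescales the survivors so the expected activation is unchanged. The sketch below illustrates that operation; it is an illustrative stand-in, not the exact helper used by the gen_efficientnet blocks:

import torch

def drop_connect_sketch(x: torch.Tensor, drop_connect_rate: float, training: bool) -> torch.Tensor:
    """Per-sample stochastic depth over an NCHW tensor (illustrative only)."""
    # Identity at inference time or when the rate is zero
    if not training or drop_connect_rate == 0.:
        return x
    keep_prob = 1.0 - drop_connect_rate
    # One Bernoulli keep/drop decision per sample in the batch
    mask = torch.rand(x.shape[0], 1, 1, 1, dtype=x.dtype, device=x.device) < keep_prob
    # Rescale kept samples so the expected output matches the un-dropped branch
    return x * mask.to(x.dtype) / keep_prob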
