@@ -232,9 +232,9 @@ def main(args):
232232 custom_keys_weight_decay = []
233233 if args.bias_weight_decay is not None:
234234 custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
235- if args.transformer_weight_decay is not None:
235+ if args.transformer_embedding_decay is not None:
236236 for key in ["class_token", "position_embedding", "relative_position_bias"]:
237- custom_keys_weight_decay.append((key, args.transformer_weight_decay))
237+ custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
238238 parameters = utils.set_weight_decay(
239239 model,
240240 args.weight_decay,
@@ -406,10 +406,10 @@ def get_args_parser(add_help=True):
406406 help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
407407 )
408408 parser.add_argument(
409- "--transformer-weight-decay",
409+ "--transformer-embedding-decay",
410410 default=None,
411411 type=float,
412- help="weight decay for special parameters for vision transformer models (default: None, same value as --wd)",
412+ help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
413413 )
414414 parser.add_argument(
415415 "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
0 commit comments