     amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
+    if model_ema:
+        model_ema.update_parameters(model)
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
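
utils.ExponentialMovingAverage itself is defined outside this diff. For reference, a minimal sketch of such a helper, assuming it wraps torch.optim.swa_utils.AveragedModel (the decay formula and argument names below are illustrative, not taken from this commit):

    import torch

    class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
        """EMA of model parameters: ema = decay * ema + (1 - decay) * param."""

        def __init__(self, model, device='cpu', decay=0.99):
            def ema_avg(avg_param, param, num_averaged):
                # AveragedModel.update_parameters() copies the parameters on its
                # first call and applies this average on every call after that.
                return decay * avg_param + (1 - decay) * param

            super().__init__(model, device=device, avg_fn=ema_avg)
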
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
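
One caveat in the resume path above: a checkpoint written before this change has no 'model_ema' key, so resuming it with --model-ema enabled would raise a KeyError. A defensive variant of the new resume lines (a sketch, not part of the commit):

    if model_ema and 'model_ema' in checkpoint:
        model_ema.load_state_dict(checkpoint['model_ema'])
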
@@ -215,16 +225,20 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
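
Passing model_ema straight to evaluate() works if utils.ExponentialMovingAverage subclasses AveragedModel as sketched earlier, since AveragedModel.forward delegates to the wrapped module. Because the averaged weights are also saved under the 'model_ema' key, they can later be restored for inference; a sketch, assuming model_ema was constructed as in main() and a hypothetical checkpoint path:

    import torch

    checkpoint = torch.load('output/model_89.pth', map_location='cpu')  # hypothetical path
    model_ema.load_state_dict(checkpoint['model_ema'])
    model_ema.eval()
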
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')
 
     return parser
 
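
With the two new flags wired in, EMA tracking stays opt-in. A single-process example run (dataset path and other flags omitted for brevity; resnet18 is the script's usual default model):

    python train.py --model resnet18 --model-ema --model-ema-decay 0.99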