 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 
+from accelerate import Accelerator
 from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book
 
 # used for logging to TensorBoard
@@ -60,6 +61,7 @@
                     help='loss weights of distillation, should be a list of length 2, '
                          'and sum to 1.0, first for student targets loss weight, '
                          'second for teacher student loss weight.')
+parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
 parser.set_defaults(augment=True)
 
 
@@ -75,10 +77,13 @@ def set_seed(seed):
 def main():
     global args, best_prec1
     args, _ = parser.parse_known_args()
+    accelerator = Accelerator(cpu=args.no_cuda)
+
     best_prec1 = 0
     if args.seed is not None:
         set_seed(args.seed)
-    if args.tensorboard: configure("runs/%s" % (args.name))
+    with accelerator.local_main_process_first():
+        if args.tensorboard: configure("runs/%s" % (args.name))
 
     # Data loading code
     normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761])
@@ -121,9 +126,9 @@ def main():
         raise NotImplementedError('Unsupported student model type')
 
     # get the number of model parameters
-    print('Number of teacher model parameters: {}'.format(
+    accelerator.print('Number of teacher model parameters: {}'.format(
         sum([p.data.nelement() for p in teacher_model.parameters()])))
-    print('Number of student model parameters: {}'.format(
+    accelerator.print('Number of student model parameters: {}'.format(
         sum([p.data.nelement() for p in student_model.parameters()])))
 
     kwargs = {'num_workers': 0, 'pin_memory': True}
@@ -135,10 +140,10 @@ def main():
     if args.loss_weights[1] > 0:
         from tqdm import tqdm
         def get_logits(teacher_model, train_dataset):
-            print("***** Getting logits of teacher model *****")
-            print(f"  Num examples = {len(train_dataset)}")
+            accelerator.print("***** Getting logits of teacher model *****")
+            accelerator.print(f"  Num examples = {len(train_dataset)}")
             logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
-            if not os.path.exists(logits_file):
+            if not os.path.exists(logits_file) and accelerator.is_local_main_process:
                 teacher_model.eval()
                 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
                 train_dataloader = tqdm(train_dataloader, desc="Evaluating")
@@ -147,8 +152,8 @@ def get_logits(teacher_model, train_dataset):
                     outputs = teacher_model(input)
                     teacher_logits += [x for x in outputs.numpy()]
                 np.save(logits_file, np.array(teacher_logits))
-            else:
-                teacher_logits = np.load(logits_file)
+            accelerator.wait_for_everyone()
+            teacher_logits = np.load(logits_file)
             train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
                                      for l, tl in zip(train_dataset.targets, teacher_logits)]
             return train_dataset
@@ -163,29 +168,34 @@ def get_logits(teacher_model, train_dataset):
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
+            accelerator.print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             student_model.load_state_dict(checkpoint['state_dict'])
-            print("=> loaded checkpoint '{}' (epoch {})"
+            accelerator.print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+            accelerator.print("=> no checkpoint found at '{}'".format(args.resume))
 
     # define optimizer
     optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
                                 momentum=args.momentum, nesterov=args.nesterov,
                                 weight_decay=args.weight_decay)
 
     # cosine learning rate
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, len(train_loader) * args.epochs // accelerator.num_processes
+    )
+
+    student_model, teacher_model, train_loader, val_loader, optimizer = \
+        accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)
 
     def train_func(model):
-        return train(train_loader, model, scheduler, distiller, best_prec1)
+        return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)
 
     def eval_func(model):
-        return validate(val_loader, model, distiller)
+        return validate(val_loader, model, distiller, accelerator)
 
     from neural_compressor.experimental import Distillation, common
     from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
@@ -204,11 +214,12 @@ def eval_func(model):
 
     directory = "runs/%s/" % (args.name)
     os.makedirs(directory, exist_ok=True)
+    model._model = accelerator.unwrap_model(model.model)
     model.save(directory)
     # change to framework model for further use
     model = model.model
 
-def train(train_loader, model, scheduler, distiller, best_prec1):
+def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
     distiller.on_train_begin()
     for epoch in range(args.start_epoch, args.epochs):
         """Train for one epoch on the training set"""
@@ -233,13 +244,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)
 
             # measure accuracy and record loss
+            output = accelerator.gather(output)
+            target = accelerator.gather(target)
             prec1 = accuracy(output.data, target, topk=(1,))[0]
-            losses.update(loss.data.item(), input.size(0))
-            top1.update(prec1.item(), input.size(0))
+            losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes)
+            top1.update(prec1.item(), input.size(0)*accelerator.num_processes)
 
             # compute gradient and do SGD step
             distiller.optimizer.zero_grad()
-            loss.backward()
+            accelerator.backward(loss)  # loss.backward()
             distiller.optimizer.step()
             scheduler.step()
 
@@ -248,7 +261,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
             end = time.time()
 
             if i % args.print_freq == 0:
-                print('Epoch: [{0}][{1}/{2}]\t'
+                accelerator.print('Epoch: [{0}][{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
@@ -260,19 +273,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
         # remember best prec@1 and save checkpoint
         is_best = distiller.best_score > best_prec1
         best_prec1 = max(distiller.best_score, best_prec1)
-        save_checkpoint({
-            'epoch': distiller._epoch_runned + 1,
-            'state_dict': model.state_dict(),
-            'best_prec1': best_prec1,
-        }, is_best)
-        # log to TensorBoard
-        if args.tensorboard:
-            log_value('train_loss', losses.avg, epoch)
-            log_value('train_acc', top1.avg, epoch)
-            log_value('learning_rate', scheduler._last_lr[0], epoch)
+        if accelerator.is_local_main_process:
+            save_checkpoint({
+                'epoch': distiller._epoch_runned + 1,
+                'state_dict': model.state_dict(),
+                'best_prec1': best_prec1,
+            }, is_best)
+            # log to TensorBoard
+            if args.tensorboard:
+                log_value('train_loss', losses.avg, epoch)
+                log_value('train_acc', top1.avg, epoch)
+                log_value('learning_rate', scheduler._last_lr[0], epoch)
 
 
-def validate(val_loader, model, distiller):
+def validate(val_loader, model, distiller, accelerator):
     """Perform validation on the validation set"""
     batch_time = AverageMeter()
     top1 = AverageMeter()
@@ -287,6 +301,8 @@ def validate(val_loader, model, distiller):
             output = model(input)
 
         # measure accuracy
+        output = accelerator.gather(output)
+        target = accelerator.gather(target)
         prec1 = accuracy(output.data, target, topk=(1,))[0]
         top1.update(prec1.item(), input.size(0))
 
@@ -295,15 +311,15 @@ def validate(val_loader, model, distiller):
         end = time.time()
 
         if i % args.print_freq == 0:
-            print('Test: [{0}/{1}]\t'
+            accelerator.print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time,
                       top1=top1))
 
-    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
+    accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
     # log to TensorBoard
-    if args.tensorboard:
+    if accelerator.is_local_main_process and args.tensorboard:
         log_value('val_acc', top1.avg, distiller._epoch_runned)
     return top1.avg
 
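Note: the patch above follows the usual Hugging Face Accelerate migration pattern: create an Accelerator, wrap the models, dataloaders, and optimizer with accelerator.prepare(), replace loss.backward() with accelerator.backward(loss), gather() tensors across processes before computing metrics, and restrict file writes and logging to the local main process. The following is a minimal, self-contained sketch of that pattern for reference only; it is not part of this patch, and the model, data, and hyperparameters are illustrative placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

def accelerate_sketch():
    accelerator = Accelerator()  # handles device placement and multi-process setup
    model = torch.nn.Linear(32, 10)                                    # placeholder model
    data = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))
    loader = DataLoader(data, batch_size=64, shuffle=True)             # placeholder data
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # prepare() moves everything to the right device(s) and shards the dataloader per process
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    model.train()
    for inputs, targets in loader:
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        optimizer.zero_grad()
        accelerator.backward(loss)   # instead of loss.backward()
        optimizer.step()

        # collect results from all processes before computing metrics (no-op on one process)
        all_outputs = accelerator.gather(outputs)
        all_targets = accelerator.gather(targets)
        acc = (all_outputs.argmax(dim=1) == all_targets).float().mean()
        accelerator.print(f"loss {loss.item():.4f}  acc {acc.item():.3f}")

    # write files / log only once per node
    if accelerator.is_local_main_process:
        torch.save(accelerator.unwrap_model(model).state_dict(), "sketch_model.pt")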