77from enum import Enum
88
99import torch
10- import torch .nn as nn
11- import torch .nn .parallel
1210import torch .backends .cudnn as cudnn
1311import torch .distributed as dist
14- import torch .optim
15- from torch .optim .lr_scheduler import StepLR
1612import torch .multiprocessing as mp
13+ import torch .nn as nn
14+ import torch .nn .parallel
15+ import torch .optim
1716import torch .utils .data
1817import torch .utils .data .distributed
19- import torchvision .transforms as transforms
2018import torchvision .datasets as datasets
2119import torchvision .models as models
20+ import torchvision .transforms as transforms
21+ from torch .optim .lr_scheduler import StepLR
2222from torch .utils .data import Subset
2323
2424model_names = sorted (name for name in models .__dict__
@@ -275,13 +275,12 @@ def main_worker(gpu, ngpus_per_node, args):
275275 train_sampler .set_epoch (epoch )
276276
277277 # train for one epoch
278- train (train_loader , model , criterion , optimizer , epoch , args )
278+ train (train_loader , model , criterion , optimizer , epoch , device , args )
279279
280280 # evaluate on validation set
281281 acc1 = validate (val_loader , model , criterion , args )
282282
283283 scheduler .step ()
284-
285284
286285 # remember best acc@1 and save checkpoint
287286 is_best = acc1 > best_acc1
@@ -299,7 +298,7 @@ def main_worker(gpu, ngpus_per_node, args):
299298 }, is_best )
300299
301300
302- def train (train_loader , model , criterion , optimizer , epoch , args ):
301+ def train (train_loader , model , criterion , optimizer , epoch , device , args ):
303302 batch_time = AverageMeter ('Time' , ':6.3f' )
304303 data_time = AverageMeter ('Data' , ':6.3f' )
305304 losses = AverageMeter ('Loss' , ':.4e' )
@@ -318,13 +317,9 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
318317 # measure data loading time
319318 data_time .update (time .time () - end )
320319
321- if args .gpu is not None and torch .cuda .is_available ():
322- images = images .cuda (args .gpu , non_blocking = True )
323- elif not args .gpu and torch .cuda .is_available ():
324- target = target .cuda (args .gpu , non_blocking = True )
325- elif torch .backends .mps .is_available ():
326- images = images .to ('mps' )
327- target = target .to ('mps' )
320+ # move data to the same device as model
321+ images = images .to (device , non_blocking = True )
322+ target = target .to (device , non_blocking = True )
328323
329324 # compute output
330325 output = model (images )
0 commit comments