Skip to content

Commit f5bb60f

Browse files
authored
Fix device mismatch issue in pytorch#1071 (pytorch#1073)
* fix device mismatch issue pytorch#1071
1 parent 35eb814 commit f5bb60f

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

imagenet/main.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
from enum import Enum
88

99
import torch
10-
import torch.nn as nn
11-
import torch.nn.parallel
1210
import torch.backends.cudnn as cudnn
1311
import torch.distributed as dist
14-
import torch.optim
15-
from torch.optim.lr_scheduler import StepLR
1612
import torch.multiprocessing as mp
13+
import torch.nn as nn
14+
import torch.nn.parallel
15+
import torch.optim
1716
import torch.utils.data
1817
import torch.utils.data.distributed
19-
import torchvision.transforms as transforms
2018
import torchvision.datasets as datasets
2119
import torchvision.models as models
20+
import torchvision.transforms as transforms
21+
from torch.optim.lr_scheduler import StepLR
2222
from torch.utils.data import Subset
2323

2424
model_names = sorted(name for name in models.__dict__
@@ -275,13 +275,12 @@ def main_worker(gpu, ngpus_per_node, args):
275275
train_sampler.set_epoch(epoch)
276276

277277
# train for one epoch
278-
train(train_loader, model, criterion, optimizer, epoch, args)
278+
train(train_loader, model, criterion, optimizer, epoch, device, args)
279279

280280
# evaluate on validation set
281281
acc1 = validate(val_loader, model, criterion, args)
282282

283283
scheduler.step()
284-
285284

286285
# remember best acc@1 and save checkpoint
287286
is_best = acc1 > best_acc1
@@ -299,7 +298,7 @@ def main_worker(gpu, ngpus_per_node, args):
299298
}, is_best)
300299

301300

302-
def train(train_loader, model, criterion, optimizer, epoch, args):
301+
def train(train_loader, model, criterion, optimizer, epoch, device, args):
303302
batch_time = AverageMeter('Time', ':6.3f')
304303
data_time = AverageMeter('Data', ':6.3f')
305304
losses = AverageMeter('Loss', ':.4e')
@@ -318,13 +317,9 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
318317
# measure data loading time
319318
data_time.update(time.time() - end)
320319

321-
if args.gpu is not None and torch.cuda.is_available():
322-
images = images.cuda(args.gpu, non_blocking=True)
323-
elif not args.gpu and torch.cuda.is_available():
324-
target = target.cuda(args.gpu, non_blocking=True)
325-
elif torch.backends.mps.is_available():
326-
images = images.to('mps')
327-
target = target.to('mps')
320+
# move data to the same device as model
321+
images = images.to(device, non_blocking=True)
322+
target = target.to(device, non_blocking=True)
328323

329324
# compute output
330325
output = model(images)

0 commit comments

Comments
 (0)