diff --git a/imagenet/main.py b/imagenet/main.py index 9cc4937c3f..276069e37e 100644 --- a/imagenet/main.py +++ b/imagenet/main.py @@ -279,6 +279,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args): # compute output output = model(input) + # for googlenet case, there has three attributes + # logits, aux_logits2, aux_logits1 + if (hasattr(output, 'logits')): + output = output.logits loss = criterion(output, target) # measure accuracy and record loss