@@ -242,10 +242,13 @@ def forward(self, x):
242242
243243correct = 0
244244total = 0
245+ # since we're not training, we don't need to calculate the gradients for our outputs
245246with torch .no_grad ():
246247 for data in testloader :
247248 images , labels = data
249+ # calculate outputs by running images through the network
248250 outputs = net (images )
251+ # the class with the highest energy is what we choose as prediction
249252 _ , predicted = torch .max (outputs .data , 1 )
250253 total += labels .size (0 )
251254 correct += (predicted == labels ).sum ().item ()
@@ -261,23 +264,28 @@ def forward(self, x):
261264# Hmmm, what are the classes that performed well, and the classes that did
262265# not perform well:
263266
264- class_correct = list (0. for i in range (10 ))
265- class_total = list (0. for i in range (10 ))
267+ # prepare to count predictions for each class
268+ correct_pred = {classname : 0 for classname in classes }
269+ total_pred = {classname : 0 for classname in classes }
270+
271+ # again no gradients needed
266272with torch .no_grad ():
267273 for data in testloader :
268- images , labels = data
269- outputs = net (images )
270- _ , predicted = torch .max (outputs , 1 )
271- c = (predicted == labels ).squeeze ()
272- for i in range (4 ):
273- label = labels [i ]
274- class_correct [label ] += c [i ].item ()
275- class_total [label ] += 1
276-
277-
278- for i in range (10 ):
279- print ('Accuracy of %5s : %2d %%' % (
280- classes [i ], 100 * class_correct [i ] / class_total [i ]))
274+ images , labels = data
275+ outputs = net (images )
276+ _ , predictions = torch .max (outputs , 1 )
277+ # collect the correct predictions for each class
278+ for label , prediction in zip (labels , predictions ):
279+ if label == prediction :
280+ correct_pred [classes [label ]] += 1
281+ total_pred [classes [label ]] += 1
282+
283+
284+ # print accuracy for each class
285+ for classname , correct_count in correct_pred .items ():
286+ accuracy = 100 * float (correct_count ) / total_pred [classname ]
287+ print ("Accuracy for class {:5s} is: {:.1f} %" .format (classname ,
288+ accuracy ))
281289
282290########################################################################
283291# Okay, so what next?
0 commit comments