@@ -246,10 +246,13 @@ def forward(self, x):
246246
247247correct = 0
248248total = 0
249+ # since we're not training, we don't need to calculate the gradients for our outputs
249250with torch .no_grad ():
250251 for data in testloader :
251252 images , labels = data
253+ # calculate outputs by running images through the network
252254 outputs = net (images )
255+ # the class with the highest energy is what we choose as prediction
253256 _ , predicted = torch .max (outputs .data , 1 )
254257 total += labels .size (0 )
255258 correct += (predicted == labels ).sum ().item ()
@@ -265,23 +268,28 @@ def forward(self, x):
265268# Hmmm, what are the classes that performed well, and the classes that did
266269# not perform well:
267270
268- class_correct = list (0. for i in range (10 ))
269- class_total = list (0. for i in range (10 ))
271+ # prepare to count predictions for each class
272+ correct_pred = {classname : 0 for classname in classes }
273+ total_pred = {classname : 0 for classname in classes }
274+
275+ # again no gradients needed
270276with torch .no_grad ():
271277 for data in testloader :
272- images , labels = data
273- outputs = net (images )
274- _ , predicted = torch .max (outputs , 1 )
275- c = (predicted == labels ).squeeze ()
276- for i in range (4 ):
277- label = labels [i ]
278- class_correct [label ] += c [i ].item ()
279- class_total [label ] += 1
280-
281-
282- for i in range (10 ):
283- print ('Accuracy of %5s : %2d %%' % (
284- classes [i ], 100 * class_correct [i ] / class_total [i ]))
278+ images , labels = data
279+ outputs = net (images )
280+ _ , predictions = torch .max (outputs , 1 )
281+ # collect the correct predictions for each class
282+ for label , prediction in zip (labels , predictions ):
283+ if label == prediction :
284+ correct_pred [classes [label ]] += 1
285+ total_pred [classes [label ]] += 1
286+
287+
288+ # print accuracy for each class
289+ for classname , correct_count in correct_pred .items ():
290+ accuracy = 100 * float (correct_count ) / total_pred [classname ]
291+ print ("Accuracy for class {:5s} is: {:.1f} %" .format (classname ,
292+ accuracy ))
285293
286294########################################################################
287295# Okay, so what next?
0 commit comments