11Visualizing Models, Data, and Training with TensorBoard
22=======================================================
33
4- In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html >`_,
4+ In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html >`_,
55we show you how to load in data,
66feed it through a model we define as a subclass of ``nn.Module ``,
77train this model on training data, and test it on test data.
@@ -348,7 +348,7 @@ In the prior tutorial, we looked at per-class accuracy once the model
348348had been trained; here, we'll use TensorBoard to plot precision-recall
349349curves (good explanation
350350`here <https://www.scikit-yb.org/en/latest/api/classifier/prcurve.html >`__)
351- for each class.
351+ for each class.
352352
3533536. Assessing trained models with TensorBoard
354354~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -359,38 +359,37 @@ for each class.
359359 # 2. gets the preds in a test_size Tensor
360360 # takes ~10 seconds to run
361361 class_probs = []
362- class_preds = []
362+ class_label = []
363363 with torch.no_grad():
364364 for data in testloader:
365365 images, labels = data
366366 output = net(images)
367367 class_probs_batch = [F.softmax(el, dim = 0 ) for el in output]
368- _, class_preds_batch = torch.max(output, 1 )
369368
370369 class_probs.append(class_probs_batch)
371- class_preds .append(class_preds_batch )
370+ class_label .append(labels )
372371
373372 test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
374- test_preds = torch.cat(class_preds )
373+ test_label = torch.cat(class_label )
375374
376375 # helper function
377- def add_pr_curve_tensorboard (class_index , test_probs , test_preds , global_step = 0 ):
376+ def add_pr_curve_tensorboard (class_index , test_probs , test_label , global_step = 0 ):
378377 '''
379378 Takes in a "class_index" from 0 to 9 and plots the corresponding
380379 precision-recall curve
381380 '''
382- tensorboard_preds = test_preds == class_index
381+ tensorboard_truth = test_label == class_index
383382 tensorboard_probs = test_probs[:, class_index]
384383
385384 writer.add_pr_curve(classes[class_index],
386- tensorboard_preds ,
385+ tensorboard_truth ,
387386 tensorboard_probs,
388387 global_step = global_step)
389388 writer.close()
390389
391390 # plot all the pr curves
392391 for i in range (len (classes)):
393- add_pr_curve_tensorboard(i, test_probs, test_preds )
392+ add_pr_curve_tensorboard(i, test_probs, test_label )
394393
395394 You will now see a "PR Curves" tab that contains the precision-recall
396395curves for each class. Go ahead and poke around; you'll see that on
0 commit comments