@@ -561,39 +561,31 @@ The training process may take many minutes, depending on a number of factors, su
561561After executing the cell above, you can visualize the training and test set errors and accuracy for an instance of this training process.
562562
563563``` {code-cell}
564+ epoch_range = np.arange(epochs) + 1 # Starting from 1
565+
564566# The training set metrics.
565- y_training_error = [
566- store_training_loss[i] / float(len(training_images))
567- for i in range(len(store_training_loss))
568- ]
569- x_training_error = range(1, len(store_training_loss) + 1)
570- y_training_accuracy = [
571- store_training_accurate_pred[i] / float(len(training_images))
572- for i in range(len(store_training_accurate_pred))
573- ]
574- x_training_accuracy = range(1, len(store_training_accurate_pred) + 1)
567+ training_metrics = {
568+ "accuracy": np.asarray(store_training_accurate_pred) / len(training_images),
569+ "error": np.asarray(store_training_loss) / len(training_images),
570+ }
575571
576572# The test set metrics.
577- y_test_error = [
578- store_test_loss[i] / float(len(test_images)) for i in range(len(store_test_loss))
579- ]
580- x_test_error = range(1, len(store_test_loss) + 1)
581- y_test_accuracy = [
582- store_training_accurate_pred[i] / float(len(training_images))
583- for i in range(len(store_training_accurate_pred))
584- ]
585- x_test_accuracy = range(1, len(store_test_accurate_pred) + 1)
573+ test_metrics = {
574+ "accuracy": np.asarray(store_test_accurate_pred) / len(test_images),
575+ "error": np.asarray(store_test_loss) / len(test_images),
576+ }
586577
587578# Display the plots.
588579fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
589- axes[0].set_title("Training set error, accuracy")
590- axes[0].plot(x_training_accuracy, y_training_accuracy, label="Training set accuracy")
591- axes[0].plot(x_training_error, y_training_error, label="Training set error")
592- axes[0].set_xlabel("Epochs")
593- axes[1].set_title("Test set error, accuracy")
594- axes[1].plot(x_test_accuracy, y_test_accuracy, label="Test set accuracy")
595- axes[1].plot(x_test_error, y_test_error, label="Test set error")
596- axes[1].set_xlabel("Epochs")
580+ for ax, metrics, title in zip(
581+ axes, (training_metrics, test_metrics), ("Training set", "Test set")
582+ ):
583+ # Plot the metrics
584+ for metric, values in metrics.items():
585+ ax.plot(epoch_range, values, label=metric.capitalize())
586+ ax.set_title(title)
587+ ax.set_xlabel("Epochs")
588+ ax.legend()
597589plt.show()
598590```
599591
0 commit comments