8484
8585
8686######################################################################
87- # Load Data (section not needed as it is covered in the original tutorial)
87+ # Load Data
8888# ------------------------------------------------------------------------
8989#
90+ # ..Note :: This section is identical to the original transfer learning tutorial.
91+ #
9092# We will use ``torchvision`` and ``torch.utils.data`` packages to load
9193# the data.
9294#
@@ -360,7 +362,7 @@ def visualize_model(model, rows=3, cols=3):
360362# **Notice that when isolating the feature extractor from a quantized
361363# model, you have to place the quantizer in the beginning and in the end
362364# of it.**
363- #
365+ # We write a helper function to create a model with a custom head.
364366
365367from torch import nn
366368
@@ -394,8 +396,6 @@ def create_combined_model(model_fe):
394396 )
395397 return new_model
396398
397- new_model = create_combined_model (model_fe )
398-
399399
400400######################################################################
401401# .. warning:: Currently the quantized models can only be run on CPU.
@@ -404,6 +404,7 @@ def create_combined_model(model_fe):
404404#
405405
406406import torch .optim as optim
407+ new_model = create_combined_model (model_fe )
407408new_model = new_model .to ('cpu' )
408409
409410criterion = nn .CrossEntropyLoss ()
@@ -431,7 +432,7 @@ def create_combined_model(model_fe):
431432
432433
433434######################################################################
434- # ** Part 2. Finetuning the quantizable model**
435+ # Part 2. Finetuning the quantizable model
435436#
436437# In this part, we fine tune the feature extractor used for transfer
437438# learning, and quantize the feature extractor. Note that in both part 1
@@ -446,18 +447,21 @@ def create_combined_model(model_fe):
446447# datasets.
447448#
448449# The pretrained feature extractor must be quantizable, i.e we need to do
449- # the following: 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
450- # using torch.quantization.fuse_modules. 2. Connect the feature extractor
451- # with a custom head. This requires dequantizing the output of the feature
452- # extractor. 3. Insert fake-quantization modules at appropriate locations
453- # in the feature extractor to mimic quantization during training.
450+ # the following:
451+ # 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
452+ # using torch.quantization.fuse_modules.
453+ # 2. Connect the feature extractor
454+ # with a custom head. This requires dequantizing the output of the feature
455+ # extractor.
456+ # 3. Insert fake-quantization modules at appropriate locations
457+ # in the feature extractor to mimic quantization during training.
454458#
455459# For step (1), we use models from torchvision/models/quantization, which
456460# support a member method fuse_model, which fuses all the conv, bn, and
457461# relu modules. In general, this would require calling the
458462# torch.quantization.fuse_modules API with the list of modules to fuse.
459463#
460- # Step (2) is done by the function create_custom_model function that we
464+ # Step (2) is done by the function create_combined_model function that we
461465# used in the previous section.
462466#
463467# Step (3) is achieved by using torch.quantization.prepare_qat, which
@@ -534,4 +538,3 @@ def create_combined_model(model_fe):
534538plt .ioff ()
535539plt .tight_layout ()
536540plt .show ()
537-
0 commit comments