27 changes: 15 additions & 12 deletions intermediate_source/quantized_transfer_learning_tutorial.py
@@ -84,9 +84,11 @@


######################################################################
# Load Data (section not needed as it is covered in the original tutorial)
# Load Data
# ------------------------------------------------------------------------
#
# .. note:: This section is identical to the original transfer learning tutorial.
#
# We will use ``torchvision`` and ``torch.utils.data`` packages to load
# the data.
#
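The data-loading code this section refers to is collapsed in the diff above. A condensed sketch of what it does, following the original transfer learning tutorial (the ``hymenoptera_data`` path and the batch size are assumptions, not taken from this diff):

import os
import torch
from torchvision import datasets, transforms

# Augment the training split; only resize and normalize the validation split.
# The normalization statistics are the standard ImageNet ones used by the pretrained ResNet.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

data_dir = 'data/hymenoptera_data'  # assumed dataset location
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
class_names = image_datasets['train'].classes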
@@ -360,7 +362,7 @@ def visualize_model(model, rows=3, cols=3):
# **Notice that when isolating the feature extractor from a quantized
# model, you have to place the quantizer at the beginning and at the end
# of it.**
#
# We write a helper function to create a model with a custom head.

from torch import nn

@@ -394,8 +396,6 @@ def create_combined_model(model_fe):
)
return new_model

new_model = create_combined_model(model_fe)

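The body of ``create_combined_model`` is collapsed in the hunk above; a minimal sketch of the pattern it implements, assuming ``model_fe`` follows the layer layout of torchvision's quantizable ResNet-18 (``quant``, ``conv1`` … ``avgpool``, ``dequant``, ``fc``) and a two-class head:

from torch import nn

def create_combined_model(model_fe):
    num_ftrs = model_fe.fc.in_features  # 512 for ResNet-18

    # Keep the quantizer at the beginning of the feature extractor
    # and the dequantizer at the end, as noted above.
    model_fe_features = nn.Sequential(
        model_fe.quant,       # quantize the input
        model_fe.conv1,
        model_fe.bn1,
        model_fe.relu,
        model_fe.maxpool,
        model_fe.layer1,
        model_fe.layer2,
        model_fe.layer3,
        model_fe.layer4,
        model_fe.avgpool,
        model_fe.dequant,     # dequantize the output for the float head
    )

    # A new floating-point "head" for the target task.
    new_head = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(num_ftrs, 2),
    )

    # Stitch the feature extractor and the head together.
    new_model = nn.Sequential(
        model_fe_features,
        nn.Flatten(1),
        new_head,
    )
    return new_model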

######################################################################
# .. warning:: Currently the quantized models can only be run on CPU.
@@ -404,6 +404,7 @@ def create_combined_model(model_fe):
#

import torch.optim as optim
new_model = create_combined_model(model_fe)
new_model = new_model.to('cpu')

criterion = nn.CrossEntropyLoss()
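The rest of the training setup is collapsed in this diff; a plausible continuation, building on the ``optim`` import shown above and using an SGD optimizer with a step learning-rate scheduler as in the original transfer learning tutorial (the learning rate, momentum, and step size here are assumptions):

# Optimize all parameters of the combined model and decay the LR every few epochs.
optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)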
@@ -431,7 +432,7 @@ def create_combined_model(model_fe):


######################################################################
# **Part 2. Finetuning the quantizable model**
# Part 2. Finetuning the quantizable model
#
# In this part, we fine-tune the feature extractor used for transfer
# learning, and quantize the feature extractor. Note that in both part 1
@@ -446,18 +447,21 @@ def create_combined_model(model_fe):
# datasets.
#
# The pretrained feature extractor must be quantizable, i.e we need to do
# the following: 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
# using torch.quantization.fuse_modules. 2. Connect the feature extractor
# with a custom head. This requires dequantizing the output of the feature
# extractor. 3. Insert fake-quantization modules at appropriate locations
# in the feature extractor to mimic quantization during training.
# the following:
# 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
# using torch.quantization.fuse_modules.
# 2. Connect the feature extractor
# with a custom head. This requires dequantizing the output of the feature
# extractor.
# 3. Insert fake-quantization modules at appropriate locations
# in the feature extractor to mimic quantization during training.
#
# For step (1), we use models from torchvision/models/quantization, which
# support a member method fuse_model, which fuses all the conv, bn, and
# relu modules. In general, this would require calling the
# torch.quantization.fuse_modules API with the list of modules to fuse.
#
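For step (1), a short sketch of what the fusion looks like in practice, assuming the quantizable ResNet-18 from ``torchvision/models/quantization`` is used as the feature extractor (as elsewhere in this tutorial):

import torch
from torchvision.models import quantization as quantized_models

# Load a quantizable ResNet-18; fuse_model() folds the (Conv, BN, ReLU),
# (Conv, BN) and (Conv, ReLU) groups into single modules in place.
model_fe = quantized_models.resnet18(pretrained=True, progress=True, quantize=False)
model_fe.train()
model_fe.fuse_model()

# For a model without a fuse_model() helper, the modules to fuse would be named explicitly:
# torch.quantization.fuse_modules(model_fe, [['conv1', 'bn1', 'relu']], inplace=True)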
# Step (2) is done by the function create_custom_model function that we
# Step (2) is done by the create_combined_model function that we
# used in the previous section.
#
# Step (3) is achieved by using torch.quantization.prepare_qat, which
@@ -534,4 +538,3 @@ def create_combined_model(model_fe):
plt.ioff()
plt.tight_layout()
plt.show()