Skip to content

Commit ddfbfd3

Browse files
balakv504facebook-github-bot
authored andcommitted
Fix fake fusion for convolutions without bias (#3353)
Summary: Fix `AttributeError` when performing fake fusion on convolution layers without bias by creating a zero-filled bias parameter instead of attempting to access requires_grad on None. Reviewed By: jerryzh168 Differential Revision: D87356763
1 parent 4f5bc7a commit ddfbfd3

File tree

3 files changed

+20
-25
lines changed

3 files changed

+20
-25
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -157,29 +157,23 @@ def test_chunked_bn_fusion(self):
157157
n_chunks = 3
158158
in_channels = 1
159159
out_channels = 32
160-
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels)
161-
m.bn.running_var = torch.nn.Parameter(
162-
torch.rand(out_channels) * 1e-2, requires_grad=False
163-
)
160+
for bias in [True, False]:
161+
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels, bias=bias)
162+
m.bn.running_var = torch.nn.Parameter(
163+
torch.rand(out_channels) * 1e-2, requires_grad=False
164+
)
164165

165-
m.eval()
166-
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
167-
ref_outputs = m(*example_inputs)
168-
traced_model = torch.export.export(m, example_inputs, strict=True).module()
169-
traced_outputs = traced_model(*example_inputs)
170-
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
171-
prepared_outputs = prepared_model(*example_inputs)
172-
173-
if isinstance(ref_outputs, (tuple, list)):
174-
for ref, prepared, traced in zip(
175-
ref_outputs, prepared_outputs, traced_outputs
176-
):
177-
torch.testing.assert_close(ref, traced)
178-
torch.testing.assert_close(traced, prepared)
179-
else:
166+
m.eval()
167+
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
168+
ref_outputs = m(*example_inputs)
169+
traced_model = torch.export.export(m, example_inputs, strict=True).module()
170+
traced_outputs = traced_model(*example_inputs)
171+
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
172+
prepared_outputs = prepared_model(*example_inputs)
180173
torch.testing.assert_close(ref_outputs, traced_outputs)
181174
torch.testing.assert_close(traced_outputs, prepared_outputs)
182175

176+
183177
def test_wo_annotate_conv_output_quantizer(self):
184178
# TODO: use OP_TO_ANNOTATOR
185179
class BackendAQuantizer(Quantizer):

torchao/quantization/pt2e/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,10 +710,11 @@ def fold_bn_weights_into_conv_node(
710710
conv_args.append(None)
711711

712712
if fake_fuse:
713-
fused_weight, fused_bias = (
714-
torch.nn.Parameter(conv_w, conv_w.requires_grad),
715-
torch.nn.Parameter(conv_b, conv_b.requires_grad),
716-
)
713+
fused_weight = torch.nn.Parameter(conv_w, conv_w.requires_grad)
714+
if conv_b is not None:
715+
fused_bias = torch.nn.Parameter(conv_b, conv_b.requires_grad)
716+
else:
717+
fused_bias = torch.nn.Parameter(torch.zeros_like(bn_rm), requires_grad=conv_w.requires_grad)
717718
else:
718719
fused_weight, fused_bias = fuse_conv_bn_weights(
719720
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose

torchao/testing/model_architectures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ def forward(self, x):
8282

8383
class ConvWithSharedWeightInExportedModel(nn.Module):
8484
def __init__(
85-
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1
85+
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
8686
) -> None:
8787
super().__init__()
8888
self.n_chunks = n_chunks
89-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
89+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
9090
self.bn = nn.BatchNorm2d(out_channels)
9191
self.relu = nn.ReLU(inplace=True)
9292

0 commit comments

Comments
 (0)