Skip to content

Commit 14e37fc

Browse files
balakv504facebook-github-bot
authored andcommitted
Fix fake fusion for convolutions without bias (#3353)
Summary: Fix `AttributeError` when performing fake fusion on convolution layers without bias by creating a zero-filled bias parameter instead of attempting to access requires_grad on None. Differential Revision: D87356763
1 parent 4f5bc7a commit 14e37fc

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
# ruff: noqa: F841
1010

1111

12+
13+
14+
15+
1216
import unittest
1317

1418
import torch
@@ -157,28 +161,30 @@ def test_chunked_bn_fusion(self):
157161
n_chunks = 3
158162
in_channels = 1
159163
out_channels = 32
160-
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels)
161-
m.bn.running_var = torch.nn.Parameter(
162-
torch.rand(out_channels) * 1e-2, requires_grad=False
163-
)
164+
for bias in [True, False]:
165+
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels, bias=bias)
166+
m.bn.running_var = torch.nn.Parameter(
167+
torch.rand(out_channels) * 1e-2, requires_grad=False
168+
)
169+
170+
m.eval()
171+
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
172+
ref_outputs = m(*example_inputs)
173+
traced_model = torch.export.export(m, example_inputs, strict=True).module()
174+
traced_outputs = traced_model(*example_inputs)
175+
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
176+
prepared_outputs = prepared_model(*example_inputs)
177+
178+
if isinstance(ref_outputs, (tuple, list)):
179+
for ref, prepared, traced in zip(
180+
ref_outputs, prepared_outputs, traced_outputs
181+
):
182+
torch.testing.assert_close(ref, traced)
183+
torch.testing.assert_close(traced, prepared)
184+
else:
185+
torch.testing.assert_close(ref_outputs, traced_outputs)
186+
torch.testing.assert_close(traced_outputs, prepared_outputs)
164187

165-
m.eval()
166-
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
167-
ref_outputs = m(*example_inputs)
168-
traced_model = torch.export.export(m, example_inputs, strict=True).module()
169-
traced_outputs = traced_model(*example_inputs)
170-
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
171-
prepared_outputs = prepared_model(*example_inputs)
172-
173-
if isinstance(ref_outputs, (tuple, list)):
174-
for ref, prepared, traced in zip(
175-
ref_outputs, prepared_outputs, traced_outputs
176-
):
177-
torch.testing.assert_close(ref, traced)
178-
torch.testing.assert_close(traced, prepared)
179-
else:
180-
torch.testing.assert_close(ref_outputs, traced_outputs)
181-
torch.testing.assert_close(traced_outputs, prepared_outputs)
182188

183189
def test_wo_annotate_conv_output_quantizer(self):
184190
# TODO: use OP_TO_ANNOTATOR

torchao/quantization/pt2e/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,10 +710,11 @@ def fold_bn_weights_into_conv_node(
710710
conv_args.append(None)
711711

712712
if fake_fuse:
713-
fused_weight, fused_bias = (
714-
torch.nn.Parameter(conv_w, conv_w.requires_grad),
715-
torch.nn.Parameter(conv_b, conv_b.requires_grad),
716-
)
713+
fused_weight = torch.nn.Parameter(conv_w, conv_w.requires_grad)
714+
if conv_b is not None:
715+
fused_bias = torch.nn.Parameter(conv_b, conv_b.requires_grad)
716+
else:
717+
fused_bias = torch.nn.Parameter(torch.zeros_like(bn_rm), requires_grad=conv_w.requires_grad)
717718
else:
718719
fused_weight, fused_bias = fuse_conv_bn_weights(
719720
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose

torchao/testing/model_architectures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ def forward(self, x):
8282

8383
class ConvWithSharedWeightInExportedModel(nn.Module):
8484
def __init__(
85-
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1
85+
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
8686
) -> None:
8787
super().__init__()
8888
self.n_chunks = n_chunks
89-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
89+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
9090
self.bn = nn.BatchNorm2d(out_channels)
9191
self.relu = nn.ReLU(inplace=True)
9292

0 commit comments

Comments
 (0)