Skip to content

Commit 7a2d061

Browse files
authored
Added eps attribute to FrozenBatchNorm2d (#2190)
* feat: Added eps argument to FrozenBatchNorm2d * test: Added unittest for eps addition in FrozenBatchNorm2d See #2169 * fix: Reverted forward changes for JIT fuser * fix: Added back n argument for backward-compatibility * fix: Fixed FrozenBatchNorm2d forward Added back eps * feat: Specified deprecation warnings in FrozenBatchNorm2d * test: Added unittest for deprecation warninig in FrozenBatchNorm2d * style: Fixed lint * style: Fixed block comment lint
1 parent a09d129 commit 7a2d061

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

test/test_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,35 @@ def test_frozenbatchnorm2d_repr(self):
547547
expected_string = f"FrozenBatchNorm2d({num_features})"
548548
self.assertEqual(t.__repr__(), expected_string)
549549

550+
def test_frozenbatchnorm2d_eps(self):
551+
sample_size = (4, 32, 28, 28)
552+
x = torch.rand(sample_size)
553+
state_dict = dict(weight=torch.rand(sample_size[1]),
554+
bias=torch.rand(sample_size[1]),
555+
running_mean=torch.rand(sample_size[1]),
556+
running_var=torch.rand(sample_size[1]),
557+
num_batches_tracked=torch.tensor(100))
558+
559+
# Check that default eps is zero for backward-compatibility
560+
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
561+
fbn.load_state_dict(state_dict, strict=False)
562+
bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval()
563+
bn.load_state_dict(state_dict)
564+
# Difference is expected to fall in an acceptable range
565+
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
566+
567+
# Check computation for eps > 0
568+
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
569+
fbn.load_state_dict(state_dict, strict=False)
570+
bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
571+
bn.load_state_dict(state_dict)
572+
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
573+
574+
def test_frozenbatchnorm2d_n_arg(self):
575+
"""Ensure a warning is thrown when passing `n` kwarg
576+
(remove this when support of `n` is dropped)"""
577+
self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32)
578+
550579

551580
class BoxConversionTester(unittest.TestCase):
552581
@staticmethod

torchvision/ops/misc.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
import math
16+
import warnings
1617
import torch
1718
from torchvision.ops import _new_empty_tensor
1819
from torch.nn import Module, Conv2d
@@ -124,12 +125,18 @@ class FrozenBatchNorm2d(torch.nn.Module):
124125
are fixed
125126
"""
126127

127-
def __init__(self, n):
128+
def __init__(self, num_features, eps=0., n=None):
129+
# n=None for backward-compatibility
130+
if n is not None:
131+
warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
132+
DeprecationWarning)
133+
num_features = n
128134
super(FrozenBatchNorm2d, self).__init__()
129-
self.register_buffer("weight", torch.ones(n))
130-
self.register_buffer("bias", torch.zeros(n))
131-
self.register_buffer("running_mean", torch.zeros(n))
132-
self.register_buffer("running_var", torch.ones(n))
135+
self.eps = eps
136+
self.register_buffer("weight", torch.ones(num_features))
137+
self.register_buffer("bias", torch.zeros(num_features))
138+
self.register_buffer("running_mean", torch.zeros(num_features))
139+
self.register_buffer("running_var", torch.ones(num_features))
133140

134141
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
135142
missing_keys, unexpected_keys, error_msgs):
@@ -148,7 +155,7 @@ def forward(self, x):
148155
b = self.bias.reshape(1, -1, 1, 1)
149156
rv = self.running_var.reshape(1, -1, 1, 1)
150157
rm = self.running_mean.reshape(1, -1, 1, 1)
151-
scale = w * rv.rsqrt()
158+
scale = w * (rv + self.eps).rsqrt()
152159
bias = b - rm * scale
153160
return x * scale + bias
154161

0 commit comments

Comments
 (0)