diff --git a/test/test_ops.py b/test/test_ops.py index 86e1c2b0ba7..c99b6a00ff9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -547,6 +547,35 @@ def test_frozenbatchnorm2d_repr(self): expected_string = f"FrozenBatchNorm2d({num_features})" self.assertEqual(t.__repr__(), expected_string) + def test_frozenbatchnorm2d_eps(self): + sample_size = (4, 32, 28, 28) + x = torch.rand(sample_size) + state_dict = dict(weight=torch.rand(sample_size[1]), + bias=torch.rand(sample_size[1]), + running_mean=torch.rand(sample_size[1]), + running_var=torch.rand(sample_size[1]), + num_batches_tracked=torch.tensor(100)) + + # Check that default eps is zero for backward-compatibility + fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) + fbn.load_state_dict(state_dict, strict=False) + bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval() + bn.load_state_dict(state_dict) + # Difference is expected to fall in an acceptable range + self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + + # Check computation for eps > 0 + fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5) + fbn.load_state_dict(state_dict, strict=False) + bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval() + bn.load_state_dict(state_dict) + self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + + def test_frozenbatchnorm2d_n_arg(self): + """Ensure a warning is thrown when passing `n` kwarg + (remove this when support of `n` is dropped)""" + self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 0a292342e3b..b8b71bfedb4 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -13,6 +13,7 @@ """ import math +import warnings import torch from torchvision.ops import _new_empty_tensor from torch.nn import Module, Conv2d @@ -124,12 +125,18 @@ class FrozenBatchNorm2d(torch.nn.Module): are fixed """ - def __init__(self, n): + def __init__(self, num_features, eps=0., n=None): + # n=None for backward-compatibility + if n is not None: + warnings.warn("`n` argument is deprecated and has been renamed `num_features`", + DeprecationWarning) + num_features = n super(FrozenBatchNorm2d, self).__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -148,7 +155,7 @@ def forward(self, x): b = self.bias.reshape(1, -1, 1, 1) rv = self.running_var.reshape(1, -1, 1, 1) rm = self.running_mean.reshape(1, -1, 1, 1) - scale = w * rv.rsqrt() + scale = w * (rv + self.eps).rsqrt() bias = b - rm * scale return x * scale + bias