test/test_ops.py: 29 additions & 0 deletions

@@ -547,6 +547,35 @@ def test_frozenbatchnorm2d_repr(self):
         expected_string = f"FrozenBatchNorm2d({num_features})"
         self.assertEqual(t.__repr__(), expected_string)
 
+    def test_frozenbatchnorm2d_eps(self):
+        sample_size = (4, 32, 28, 28)
+        x = torch.rand(sample_size)
+        state_dict = dict(weight=torch.rand(sample_size[1]),
+                          bias=torch.rand(sample_size[1]),
+                          running_mean=torch.rand(sample_size[1]),
+                          running_var=torch.rand(sample_size[1]),
+                          num_batches_tracked=torch.tensor(100))
+
+        # Check that the default eps is zero for backward-compatibility
+        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
+        fbn.load_state_dict(state_dict, strict=False)
+        bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval()
+        bn.load_state_dict(state_dict)
+        # The difference is expected to fall within an acceptable range
+        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+
+        # Check the computation for eps > 0
+        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
+        fbn.load_state_dict(state_dict, strict=False)
+        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
+        bn.load_state_dict(state_dict)
+        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+
+    def test_frozenbatchnorm2d_n_arg(self):
+        """Ensure a DeprecationWarning is raised when passing the `n` kwarg
+        (remove this test once support for `n` is dropped)"""
+        self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32)
+
 
 if __name__ == '__main__':
     unittest.main()
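
Taken together, the new tests pin down two behaviors: with the default eps=0 the frozen layer keeps producing its old outputs, and with a matching eps > 0 it agrees with torch.nn.BatchNorm2d in eval mode. The snippet below replays that equivalence check outside the test suite. It is a minimal sketch, assuming a torchvision build that already includes this patch; the shapes, unseeded random data, and tolerance are illustrative only.

import torch
from torchvision import ops

num_channels = 32
x = torch.rand(4, num_channels, 28, 28)
state_dict = dict(weight=torch.rand(num_channels),
                  bias=torch.rand(num_channels),
                  running_mean=torch.rand(num_channels),
                  running_var=torch.rand(num_channels),
                  num_batches_tracked=torch.tensor(100))

# strict=False because FrozenBatchNorm2d keeps no num_batches_tracked buffer
fbn = ops.misc.FrozenBatchNorm2d(num_channels, eps=1e-5)
fbn.load_state_dict(state_dict, strict=False)

# eval() makes BatchNorm2d use the loaded running statistics, like the frozen layer
bn = torch.nn.BatchNorm2d(num_channels, eps=1e-5).eval()
bn.load_state_dict(state_dict)

print(torch.allclose(fbn(x), bn(x), atol=1e-6))  # expected: True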
torchvision/ops/misc.py: 13 additions & 6 deletions

@@ -13,6 +13,7 @@
"""

import math
import warnings
import torch
from torchvision.ops import _new_empty_tensor
from torch.nn import Module, Conv2d

@@ -124,12 +125,18 @@ class FrozenBatchNorm2d(torch.nn.Module):
     are fixed
     """
 
-    def __init__(self, n):
+    def __init__(self, num_features, eps=0., n=None):
+        # n=None for backward-compatibility
+        if n is not None:
+            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
+                          DeprecationWarning)
+            num_features = n
         super(FrozenBatchNorm2d, self).__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features))
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):

@@ -148,7 +155,7 @@ def forward(self, x):
         w = self.weight.reshape(1, -1, 1, 1)
         b = self.bias.reshape(1, -1, 1, 1)
         rv = self.running_var.reshape(1, -1, 1, 1)
         rm = self.running_mean.reshape(1, -1, 1, 1)
-        scale = w * rv.rsqrt()
+        scale = w * (rv + self.eps).rsqrt()
         bias = b - rm * scale
         return x * scale + bias

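The one-line change in forward is sufficient because the layer folds eval-mode batch norm into a precomputed affine transform: starting from y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, collecting terms gives scale = weight / sqrt(running_var + eps) and a shifted bias, bias - running_mean * scale, so eps only needs to enter through the rsqrt call. A minimal sketch checking that refactoring numerically (the variable names are illustrative, not part of the patch):

import torch

num_channels = 32
w, b = torch.rand(num_channels), torch.rand(num_channels)
rm, rv = torch.rand(num_channels), torch.rand(num_channels)
eps = 1e-5
x = torch.rand(4, num_channels, 28, 28)

def per_channel(t):
    # reshape a per-channel tensor so it broadcasts over an NCHW batch
    return t.reshape(1, -1, 1, 1)

# textbook eval-mode batch norm
y_ref = (x - per_channel(rm)) / torch.sqrt(per_channel(rv) + eps) * per_channel(w) + per_channel(b)

# precomputed scale/bias form used by FrozenBatchNorm2d.forward
scale = per_channel(w) * (per_channel(rv) + eps).rsqrt()
bias = per_channel(b) - per_channel(rm) * scale
y = x * scale + bias

print(torch.allclose(y_ref, y, atol=1e-6))  # expected: True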