From 485466c4005eb372efd969a3cf85d1ac880680f5 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Thu, 7 May 2020 01:31:02 +0200 Subject: [PATCH 1/9] feat: Added eps argument to FrozenBatchNorm2d --- torchvision/ops/misc.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 0a292342e3b..bbac71122e3 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -124,12 +124,13 @@ class FrozenBatchNorm2d(torch.nn.Module): are fixed """ - def __init__(self, n): + def __init__(self, num_features, eps=0.): super(FrozenBatchNorm2d, self).__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -142,15 +143,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def forward(self, x): - # move reshapes to the beginning - # to make it fuser-friendly - w = self.weight.reshape(1, -1, 1, 1) - b = self.bias.reshape(1, -1, 1, 1) - rv = self.running_var.reshape(1, -1, 1, 1) - rm = self.running_mean.reshape(1, -1, 1, 1) - scale = w * rv.rsqrt() - bias = b - rm * scale - return x * scale + bias + # Scaling factor + scale = self.weight * (self.running_var + self.eps).rsqrt() + # Bias + bias = (self.bias - self.running_mean * scale) + + return x * scale.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) def __repr__(self): return f"{self.__class__.__name__}({self.weight.shape[0]})" From 58a9aedcb1a77404da244e9503101c953c88d4ef Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Thu, 7 May 2020 01:34:54 +0200 Subject: [PATCH 2/9] test: Added unittest for eps addition in FrozenBatchNorm2d See #2169 --- test/test_ops.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 86e1c2b0ba7..4b10d301305 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -547,6 +547,30 @@ def test_frozenbatchnorm2d_repr(self): expected_string = f"FrozenBatchNorm2d({num_features})" self.assertEqual(t.__repr__(), expected_string) + def test_frozenbatchnorm2d_eps(self): + sample_size = (4, 32, 28, 28) + x = torch.rand(sample_size) + state_dict = dict(weight=torch.rand(sample_size[1]), + bias=torch.rand(sample_size[1]), + running_mean=torch.rand(sample_size[1]), + running_var=torch.rand(sample_size[1]), + num_batches_tracked=torch.tensor(100)) + + # Check that default eps is zero for backward-compatibility + fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) + fbn.load_state_dict(state_dict, strict=False) + bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval() + bn.load_state_dict(state_dict) + # Difference is expected to fall in an acceptable range + self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + + # Check computation for eps > 0 + fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5) + fbn.load_state_dict(state_dict, strict=False) + bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval() + bn.load_state_dict(state_dict) + self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + if __name__ == '__main__': unittest.main() From ecd1ae9ded0be1bf1535072c3f1a0872f31b11c9 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 13:16:28 +0200 Subject: [PATCH 3/9] fix: Reverted forward changes for JIT fuser --- torchvision/ops/misc.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index bbac71122e3..f38ef9df4c4 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -143,12 +143,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def forward(self, x): - # Scaling factor - scale = self.weight * (self.running_var + self.eps).rsqrt() - # Bias - bias = (self.bias - self.running_mean * scale) - - return x * scale.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * rv.rsqrt() + bias = b - rm * scale + return x * scale + bias def __repr__(self): return f"{self.__class__.__name__}({self.weight.shape[0]})" From 30c29da5a6b017e6a8251d91b9f68a467602f783 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 13:19:18 +0200 Subject: [PATCH 4/9] fix: Added back n argument for backward-compatibility --- torchvision/ops/misc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index f38ef9df4c4..1f76e90b239 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -13,6 +13,7 @@ """ import math +import warnings import torch from torchvision.ops import _new_empty_tensor from torch.nn import Module, Conv2d @@ -124,7 +125,11 @@ class FrozenBatchNorm2d(torch.nn.Module): are fixed """ - def __init__(self, num_features, eps=0.): + def __init__(self, num_features, eps=0., n=None): + # n=None for backward-compatibility + if n is not None: + warnings.warn("`n` argument is deprecated and has been renamed `num_features`") + num_features = n super(FrozenBatchNorm2d, self).__init__() self.eps = eps self.register_buffer("weight", torch.ones(num_features)) From 3b47a42e38f3daff9a5419b32d2d20d4ed100799 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 14:52:25 +0200 Subject: [PATCH 5/9] fix: Fixed FrozenBatchNorm2d forward Added back eps --- torchvision/ops/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 1f76e90b239..950aeb10e93 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -154,7 +154,7 @@ def forward(self, x): b = self.bias.reshape(1, -1, 1, 1) rv = self.running_var.reshape(1, -1, 1, 1) rm = self.running_mean.reshape(1, -1, 1, 1) - scale = w * rv.rsqrt() + scale = w * (rv + self.eps).rsqrt() bias = b - rm * scale return x * scale + bias From 31bfa5ce8e611d1cd48eab0daa5989b878b6fb36 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 14:57:42 +0200 Subject: [PATCH 6/9] feat: Specified deprecation warnings in FrozenBatchNorm2d --- torchvision/ops/misc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 950aeb10e93..b8b71bfedb4 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -128,7 +128,8 @@ class FrozenBatchNorm2d(torch.nn.Module): def __init__(self, num_features, eps=0., n=None): # n=None for backward-compatibility if n is not None: - warnings.warn("`n` argument is deprecated and has been renamed `num_features`") + warnings.warn("`n` argument is deprecated and has been renamed `num_features`", + DeprecationWarning) num_features = n super(FrozenBatchNorm2d, self).__init__() self.eps = eps From 73a36c822008fd18797a26ad9641a4cf16c159a0 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 14:58:05 +0200 Subject: [PATCH 7/9] test: Added unittest for deprecation warninig in FrozenBatchNorm2d --- test/test_ops.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 4b10d301305..251bb294f1f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -571,6 +571,11 @@ def test_frozenbatchnorm2d_eps(self): bn.load_state_dict(state_dict) self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + def test_frozenbatchnorm2d_n_arg(self): + # Ensure a warning is thrown if user try to pass `n` kwarg + # Remove this test when support of `n` is dropped + self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32) + if __name__ == '__main__': unittest.main() From b05f9e2f1d2430d5784b2536d49c39ab7a3cbdca Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 15:04:00 +0200 Subject: [PATCH 8/9] style: Fixed lint --- test/test_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 251bb294f1f..2423dc6656f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -572,8 +572,7 @@ def test_frozenbatchnorm2d_eps(self): self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) def test_frozenbatchnorm2d_n_arg(self): - # Ensure a warning is thrown if user try to pass `n` kwarg - # Remove this test when support of `n` is dropped + # Ensure a warning is thrown when passing `n` kwarg (remove this when support of `n` is dropped) self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32) From 611b9b37ff70b3e580263f9d54feadb1d3b3d46b Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Mon, 11 May 2020 15:13:23 +0200 Subject: [PATCH 9/9] style: Fixed block comment lint --- test/test_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 2423dc6656f..c99b6a00ff9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -572,7 +572,8 @@ def test_frozenbatchnorm2d_eps(self): self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) def test_frozenbatchnorm2d_n_arg(self): - # Ensure a warning is thrown when passing `n` kwarg (remove this when support of `n` is dropped) + """Ensure a warning is thrown when passing `n` kwarg + (remove this when support of `n` is dropped)""" self.assertWarns(DeprecationWarning, ops.misc.FrozenBatchNorm2d, 32, eps=1e-5, n=32)