Skip to content

Commit 005355b

Browse files
authored
Added eps in the __repr__ of FrozenBN (#2852)
* feat: Updated FrozenBN eps to align with BatchNorm * feat: Added eps to __repr__ of FrozenBN * test: Updated unittest of __repr__ for FrozenBN * test: Updated unittest for eps value in BN and FrozenBN * fix: Revert FrozenBN eps value * test: Revert test on eps alignment between FrozenBN and BN
1 parent 6713f03 commit 005355b

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

test/test_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,11 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
607607
class FrozenBNTester(unittest.TestCase):
608608
def test_frozenbatchnorm2d_repr(self):
609609
num_features = 32
610-
t = ops.misc.FrozenBatchNorm2d(num_features)
610+
eps = 1e-5
611+
t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)
611612

612613
# Check integrity of object __repr__ attribute
613-
expected_string = f"FrozenBatchNorm2d({num_features})"
614+
expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
614615
self.assertEqual(t.__repr__(), expected_string)
615616

616617
def test_frozenbatchnorm2d_eps(self):

torchvision/ops/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,4 @@ def forward(self, x: Tensor) -> Tensor:
9696
return x * scale + bias
9797

9898
def __repr__(self) -> str:
99-
return f"{self.__class__.__name__}({self.weight.shape[0]})"
99+
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"

0 commit comments

Comments
 (0)