Skip to content

Commit 4304bb4

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] fix prototype RandomErasing test (#6472)
Reviewed By: datumbox Differential Revision: D39013678 fbshipit-source-id: 6e7c7ef4d2c8a15eae9b427128aef9172eca6ad2
1 parent 72d7788 commit 4304bb4

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

test/test_prototype_transforms.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -972,8 +972,9 @@ def test__get_params(self, value, mocker):
972972
assert 0 <= i <= image.image_size[0] - h
973973
assert 0 <= j <= image.image_size[1] - w
974974

975-
def test__transform(self, mocker):
976-
transform = transforms.RandomErasing()
975+
@pytest.mark.parametrize("p", [0, 1])
976+
def test__transform(self, mocker, p):
977+
transform = transforms.RandomErasing(p=p)
977978
transform._transformed_types = (mocker.MagicMock,)
978979

979980
i_sentinel = mocker.MagicMock()
@@ -989,11 +990,15 @@ def test__transform(self, mocker):
989990
inpt_sentinel = mocker.MagicMock()
990991

991992
mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase")
992-
transform(inpt_sentinel)
993+
output = transform(inpt_sentinel)
993994

994-
mock.assert_called_once_with(
995-
inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel
996-
)
995+
if p:
996+
mock.assert_called_once_with(
997+
inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel
998+
)
999+
else:
1000+
mock.assert_not_called()
1001+
assert output is inpt_sentinel
9971002

9981003

9991004
class TestTransform:

0 commit comments

Comments
 (0)