@@ -972,8 +972,9 @@ def test__get_params(self, value, mocker):
972972 assert 0 <= i <= image .image_size [0 ] - h
973973 assert 0 <= j <= image .image_size [1 ] - w
974974
975- def test__transform (self , mocker ):
976- transform = transforms .RandomErasing ()
975+ @pytest .mark .parametrize ("p" , [0 , 1 ])
976+ def test__transform (self , mocker , p ):
977+ transform = transforms .RandomErasing (p = p )
977978 transform ._transformed_types = (mocker .MagicMock ,)
978979
979980 i_sentinel = mocker .MagicMock ()
@@ -989,11 +990,15 @@ def test__transform(self, mocker):
989990 inpt_sentinel = mocker .MagicMock ()
990991
991992 mock = mocker .patch ("torchvision.prototype.transforms._augment.F.erase" )
992- transform (inpt_sentinel )
993+ output = transform (inpt_sentinel )
993994
994- mock .assert_called_once_with (
995- inpt_sentinel , i = i_sentinel , j = j_sentinel , h = h_sentinel , w = w_sentinel , v = v_sentinel
996- )
995+ if p :
996+ mock .assert_called_once_with (
997+ inpt_sentinel , i = i_sentinel , j = j_sentinel , h = h_sentinel , w = w_sentinel , v = v_sentinel
998+ )
999+ else :
1000+ mock .assert_not_called ()
1001+ assert output is inpt_sentinel
9971002
9981003
9991004class TestTransform :
0 commit comments