@@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker):
397397 fn = mocker .patch ("torchvision.prototype.transforms.functional.pad" )
398398 # vfdev-5, Feature Request: let's store params as Transform attribute
399399 # This could be also helpful for users
400+ # Otherwise, we can mock transform._get_params
400401 torch .manual_seed (12 )
401402 _ = transform (inpt )
402403 torch .manual_seed (12 )
@@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
456457 inpt = mocker .MagicMock (spec = features .Image )
457458 # vfdev-5, Feature Request: let's store params as Transform attribute
458459 # This could be also helpful for users
460+ # Otherwise, we can mock transform._get_params
459461 torch .manual_seed (12 )
460462 _ = transform (inpt )
461463 torch .manual_seed (12 )
@@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
576578
577579 # vfdev-5, Feature Request: let's store params as Transform attribute
578580 # This could be also helpful for users
581+ # Otherwise, we can mock transform._get_params
579582 torch .manual_seed (12 )
580583 _ = transform (inpt )
581584 torch .manual_seed (12 )
@@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
645648
646649 # vfdev-5, Feature Request: let's store params as Transform attribute
647650 # This could be also helpful for users
651+ # Otherwise, we can mock transform._get_params
648652 torch .manual_seed (12 )
649653 _ = transform (inpt )
650654 torch .manual_seed (12 )
@@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker):
716720
717721 # vfdev-5, Feature Request: let's store params as Transform attribute
718722 # This could be also helpful for users
723+ # Otherwise, we can mock transform._get_params
719724 torch .manual_seed (12 )
720725 _ = transform (inpt )
721726 torch .manual_seed (12 )
@@ -795,10 +800,80 @@ def test__transform(self, distortion_scale, mocker):
795800 inpt .image_size = (24 , 32 )
796801 # vfdev-5, Feature Request: let's store params as Transform attribute
797802 # This could be also helpful for users
803+ # Otherwise, we can mock transform._get_params
798804 torch .manual_seed (12 )
799805 _ = transform (inpt )
800806 torch .manual_seed (12 )
801807 torch .rand (1 ) # random apply changes random state
802808 params = transform ._get_params (inpt )
803809
804810 fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
811+
812+
813+ class TestElasticTransform :
814+ def test_assertions (self ):
815+
816+ with pytest .raises (TypeError , match = "alpha should be float or a sequence of floats" ):
817+ transforms .ElasticTransform ({})
818+
819+ with pytest .raises (ValueError , match = "alpha is a sequence its length should be one of 2" ):
820+ transforms .ElasticTransform ([1.0 , 2.0 , 3.0 ])
821+
822+ with pytest .raises (ValueError , match = "alpha should be a sequence of floats" ):
823+ transforms .ElasticTransform ([1 , 2 ])
824+
825+ with pytest .raises (TypeError , match = "sigma should be float or a sequence of floats" ):
826+ transforms .ElasticTransform (1.0 , {})
827+
828+ with pytest .raises (ValueError , match = "sigma is a sequence its length should be one of 2" ):
829+ transforms .ElasticTransform (1.0 , [1.0 , 2.0 , 3.0 ])
830+
831+ with pytest .raises (ValueError , match = "sigma should be a sequence of floats" ):
832+ transforms .ElasticTransform (1.0 , [1 , 2 ])
833+
834+ with pytest .raises (TypeError , match = "Got inappropriate fill arg" ):
835+ transforms .ElasticTransform (1.0 , 2.0 , fill = "abc" )
836+
837+ def test__get_params (self , mocker ):
838+ alpha = 2.0
839+ sigma = 3.0
840+ transform = transforms .ElasticTransform (alpha , sigma )
841+ image = mocker .MagicMock (spec = features .Image )
842+ image .num_channels = 3
843+ image .image_size = (24 , 32 )
844+
845+ params = transform ._get_params (image )
846+
847+ h , w = image .image_size
848+ displacement = params ["displacement" ]
849+ assert displacement .shape == (1 , h , w , 2 )
850+ assert (- alpha / w <= displacement [0 , ..., 0 ]).all () and (displacement [0 , ..., 0 ] <= alpha / w ).all ()
851+ assert (- alpha / h <= displacement [0 , ..., 1 ]).all () and (displacement [0 , ..., 1 ] <= alpha / h ).all ()
852+
853+ @pytest .mark .parametrize ("alpha" , [5.0 , [5.0 , 10.0 ]])
854+ @pytest .mark .parametrize ("sigma" , [2.0 , [2.0 , 5.0 ]])
855+ def test__transform (self , alpha , sigma , mocker ):
856+ interpolation = InterpolationMode .BILINEAR
857+ fill = 12
858+ transform = transforms .ElasticTransform (alpha , sigma = sigma , fill = fill , interpolation = interpolation )
859+
860+ if isinstance (alpha , float ):
861+ assert transform .alpha == [alpha , alpha ]
862+ else :
863+ assert transform .alpha == alpha
864+
865+ if isinstance (sigma , float ):
866+ assert transform .sigma == [sigma , sigma ]
867+ else :
868+ assert transform .sigma == sigma
869+
870+ fn = mocker .patch ("torchvision.prototype.transforms.functional.elastic" )
871+ inpt = mocker .MagicMock (spec = features .Image )
872+ inpt .num_channels = 3
873+ inpt .image_size = (24 , 32 )
874+
875+ # Let's mock transform._get_params to control the output:
876+ transform ._get_params = mocker .MagicMock ()
877+ _ = transform (inpt )
878+ params = transform ._get_params (inpt )
879+ fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
0 commit comments