@@ -4040,3 +4040,28 @@ def test_transform_params_correctness(self, side_range, make_input, device):
40404040 assert 0 <= padding [1 ] <= (side_range [1 ] - 1 ) * height
40414041 assert 0 <= padding [2 ] <= (side_range [1 ] - 1 ) * width
40424042 assert 0 <= padding [3 ] <= (side_range [1 ] - 1 ) * height
4043+
4044+
4045+ class TestRandomPhotometricDistort :
4046+ # Tests are light because this largely relies on the already tested
4047+ # `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.
4048+
4049+ @pytest .mark .parametrize (
4050+ "make_input" ,
4051+ [make_image_tensor , make_image_pil , make_image , make_video ],
4052+ )
4053+ @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
4054+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
4055+ def test_transform (self , make_input , dtype , device ):
4056+ if make_input is make_image_pil and not (dtype is torch .uint8 and device == "cpu" ):
4057+ pytest .skip (
4058+ "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
4059+ "will degenerate to that anyway."
4060+ )
4061+
4062+ check_transform (
4063+ transforms .RandomPhotometricDistort (
4064+ brightness = (0.3 , 0.4 ), contrast = (0.5 , 0.6 ), saturation = (0.7 , 0.8 ), hue = (- 0.1 , 0.2 ), p = 1
4065+ ),
4066+ make_input (dtype = dtype , device = device ),
4067+ )
0 commit comments