@@ -3947,6 +3947,61 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue):
39473947 assert mae < 2
39483948
39493949
3950+ class TestRgbToGrayscale :
3951+ @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
3952+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
3953+ def test_kernel_image (self , dtype , device ):
3954+ check_kernel (F .rgb_to_grayscale_image , make_image (dtype = dtype , device = device ))
3955+
3956+ @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image ])
3957+ def test_functional (self , make_input ):
3958+ check_functional (F .rgb_to_grayscale , make_input ())
3959+
3960+ @pytest .mark .parametrize (
3961+ ("kernel" , "input_type" ),
3962+ [
3963+ (F .rgb_to_grayscale_image , torch .Tensor ),
3964+ (F ._rgb_to_grayscale_image_pil , PIL .Image .Image ),
3965+ (F .rgb_to_grayscale_image , tv_tensors .Image ),
3966+ ],
3967+ )
3968+ def test_functional_signature (self , kernel , input_type ):
3969+ check_functional_kernel_signature_match (F .rgb_to_grayscale , kernel = kernel , input_type = input_type )
3970+
3971+ @pytest .mark .parametrize ("transform" , [transforms .Grayscale (), transforms .RandomGrayscale (p = 1 )])
3972+ @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image ])
3973+ def test_transform (self , transform , make_input ):
3974+ check_transform (transform , make_input ())
3975+
3976+ @pytest .mark .parametrize ("num_output_channels" , [1 , 3 ])
3977+ @pytest .mark .parametrize ("fn" , [F .rgb_to_grayscale , transform_cls_to_functional (transforms .Grayscale )])
3978+ def test_image_correctness (self , num_output_channels , fn ):
3979+ image = make_image (dtype = torch .uint8 , device = "cpu" )
3980+
3981+ actual = fn (image , num_output_channels = num_output_channels )
3982+ expected = F .to_image (F .rgb_to_grayscale (F .to_pil_image (image ), num_output_channels = num_output_channels ))
3983+
3984+ assert_equal (actual , expected , rtol = 0 , atol = 1 )
3985+
3986+ @pytest .mark .parametrize ("num_input_channels" , [1 , 3 ])
3987+ def test_random_transform_correctness (self , num_input_channels ):
3988+ image = make_image (
3989+ color_space = {
3990+ 1 : "GRAY" ,
3991+ 3 : "RGB" ,
3992+ }[num_input_channels ],
3993+ dtype = torch .uint8 ,
3994+ device = "cpu" ,
3995+ )
3996+
3997+ transform = transforms .RandomGrayscale (p = 1 )
3998+
3999+ actual = transform (image )
4000+ expected = F .to_image (F .rgb_to_grayscale (F .to_pil_image (image ), num_output_channels = num_input_channels ))
4001+
4002+ assert_equal (actual , expected , rtol = 0 , atol = 1 )
4003+
4004+
39504005class TestRandomZoomOut :
39514006 @pytest .mark .parametrize (
39524007 "make_input" ,
0 commit comments