@@ -67,50 +67,58 @@ def test_crop(self):
6767 self .compareTensorToPIL (img_tensor_cropped , pil_img_cropped )
6868
6969 def test_hsv2rgb (self ):
70+ scripted_fn = torch .jit .script (F_t ._hsv2rgb )
7071 shape = (3 , 100 , 150 )
71- for _ in range (20 ):
72- img = torch .rand (* shape , dtype = torch .float )
73- ft_img = F_t ._hsv2rgb (img ).permute (1 , 2 , 0 ).flatten (0 , 1 )
72+ for _ in range (10 ):
73+ hsv_img = torch .rand (* shape , dtype = torch .float , device = self .device )
74+ rgb_img = F_t ._hsv2rgb (hsv_img )
75+ ft_img = rgb_img .permute (1 , 2 , 0 ).flatten (0 , 1 )
7476
75- h , s , v , = img .unbind (0 )
76- h = h .flatten ().numpy ()
77- s = s .flatten ().numpy ()
78- v = v .flatten ().numpy ()
77+ h , s , v , = hsv_img .unbind (0 )
78+ h = h .flatten ().cpu (). numpy ()
79+ s = s .flatten ().cpu (). numpy ()
80+ v = v .flatten ().cpu (). numpy ()
7981
8082 rgb = []
8183 for h1 , s1 , v1 in zip (h , s , v ):
8284 rgb .append (colorsys .hsv_to_rgb (h1 , s1 , v1 ))
83-
84- colorsys_img = torch .tensor (rgb , dtype = torch .float32 )
85+ colorsys_img = torch .tensor (rgb , dtype = torch .float32 , device = self .device )
8586 max_diff = (ft_img - colorsys_img ).abs ().max ()
8687 self .assertLess (max_diff , 1e-5 )
8788
89+ s_rgb_img = scripted_fn (hsv_img )
90+ self .assertTrue (rgb_img .allclose (s_rgb_img ))
91+
8892 def test_rgb2hsv (self ):
93+ scripted_fn = torch .jit .script (F_t ._rgb2hsv )
8994 shape = (3 , 150 , 100 )
90- for _ in range (20 ):
91- img = torch .rand (* shape , dtype = torch .float )
92- ft_hsv_img = F_t ._rgb2hsv (img ).permute (1 , 2 , 0 ).flatten (0 , 1 )
95+ for _ in range (10 ):
96+ rgb_img = torch .rand (* shape , dtype = torch .float , device = self .device )
97+ hsv_img = F_t ._rgb2hsv (rgb_img )
98+ ft_hsv_img = hsv_img .permute (1 , 2 , 0 ).flatten (0 , 1 )
9399
94- r , g , b , = img .unbind (0 )
95- r = r .flatten ().numpy ()
96- g = g .flatten ().numpy ()
97- b = b .flatten ().numpy ()
100+ r , g , b , = rgb_img .unbind (0 )
101+ r = r .flatten ().cpu (). numpy ()
102+ g = g .flatten ().cpu (). numpy ()
103+ b = b .flatten ().cpu (). numpy ()
98104
99105 hsv = []
100106 for r1 , g1 , b1 in zip (r , g , b ):
101107 hsv .append (colorsys .rgb_to_hsv (r1 , g1 , b1 ))
102108
103- colorsys_img = torch .tensor (hsv , dtype = torch .float32 )
109+ colorsys_img = torch .tensor (hsv , dtype = torch .float32 , device = self . device )
104110
105111 ft_hsv_img_h , ft_hsv_img_sv = torch .split (ft_hsv_img , [1 , 2 ], dim = 1 )
106112 colorsys_img_h , colorsys_img_sv = torch .split (colorsys_img , [1 , 2 ], dim = 1 )
107113
108114 max_diff_h = ((colorsys_img_h * 2 * math .pi ).sin () - (ft_hsv_img_h * 2 * math .pi ).sin ()).abs ().max ()
109115 max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv ).abs ().max ()
110116 max_diff = max (max_diff_h , max_diff_sv )
111-
112117 self .assertLess (max_diff , 1e-5 )
113118
119+ s_hsv_img = scripted_fn (rgb_img )
120+ self .assertTrue (hsv_img .allclose (s_hsv_img ))
121+
114122 def test_rgb_to_grayscale (self ):
115123 script_rgb_to_grayscale = torch .jit .script (F .rgb_to_grayscale )
116124
0 commit comments