@@ -64,22 +64,24 @@ def test_hflip(self):
6464
6565 def test_crop (self ):
6666 script_crop = torch .jit .script (F_t .crop )
67- img_tensor = torch .randint (0 , 255 , (3 , 16 , 16 ), dtype = torch .uint8 )
68- img_tensor_clone = img_tensor .clone ()
69- top = random .randint (0 , 15 )
70- left = random .randint (0 , 15 )
71- height = random .randint (1 , 16 - top )
72- width = random .randint (1 , 16 - left )
73- img_cropped = F_t .crop (img_tensor , top , left , height , width )
74- img_PIL = transforms .ToPILImage ()(img_tensor )
75- img_PIL_cropped = F .crop (img_PIL , top , left , height , width )
76- img_cropped_GT = transforms .ToTensor ()(img_PIL_cropped )
77- self .assertTrue (torch .equal (img_tensor , img_tensor_clone ))
78- self .assertTrue (torch .equal (img_cropped , (img_cropped_GT * 255 ).to (torch .uint8 )),
79- "functional_tensor crop not working" )
80- # scriptable function test
81- cropped_img_script = script_crop (img_tensor , top , left , height , width )
82- self .assertTrue (torch .equal (img_cropped , cropped_img_script ))
67+
68+ img_tensor , pil_img = self ._create_data (16 , 18 )
69+
70+ test_configs = [
71+ (1 , 2 , 4 , 5 ), # crop inside top-left corner
72+ (2 , 12 , 3 , 4 ), # crop inside top-right corner
73+ (8 , 3 , 5 , 6 ), # crop inside bottom-left corner
74+ (8 , 11 , 4 , 3 ), # crop inside bottom-right corner
75+ ]
76+
77+ for top , left , height , width in test_configs :
78+ pil_img_cropped = F .crop (pil_img , top , left , height , width )
79+
80+ img_tensor_cropped = F .crop (img_tensor , top , left , height , width )
81+ self .compareTensorToPIL (img_tensor_cropped , pil_img_cropped )
82+
83+ img_tensor_cropped = script_crop (img_tensor , top , left , height , width )
84+ self .compareTensorToPIL (img_tensor_cropped , pil_img_cropped )
8385
8486 def test_hsv2rgb (self ):
8587 shape = (3 , 100 , 150 )
@@ -198,71 +200,47 @@ def test_rgb_to_grayscale(self):
198200 self .assertTrue (torch .equal (grayscale_script , grayscale_tensor ))
199201
200202 def test_center_crop (self ):
201- script_center_crop = torch .jit .script (F_t .center_crop )
202- img_tensor = torch . randint ( 0 , 255 , ( 1 , 32 , 32 ), dtype = torch . uint8 )
203- img_tensor_clone = img_tensor . clone ( )
204- cropped_tensor = F_t . center_crop ( img_tensor , [ 10 , 10 ])
205- cropped_pil_image = F .center_crop (transforms . ToPILImage ()( img_tensor ) , [10 , 10 ])
206- cropped_pil_tensor = ( transforms . ToTensor ()( cropped_pil_image ) * 255 ). to ( torch . uint8 )
207- self . assertTrue ( torch . equal ( cropped_tensor , cropped_pil_tensor ) )
208- self .assertTrue ( torch . equal ( img_tensor , img_tensor_clone ) )
209- # scriptable function test
210- cropped_script = script_center_crop (img_tensor , [10 , 10 ])
211- self .assertTrue ( torch . equal ( cropped_script , cropped_tensor ) )
203+ script_center_crop = torch .jit .script (F .center_crop )
204+
205+ img_tensor , pil_img = self . _create_data ( 32 , 34 )
206+
207+ cropped_pil_image = F .center_crop (pil_img , [10 , 11 ])
208+
209+ cropped_tensor = F . center_crop ( img_tensor , [ 10 , 11 ] )
210+ self .compareTensorToPIL ( cropped_tensor , cropped_pil_image )
211+
212+ cropped_tensor = script_center_crop (img_tensor , [10 , 11 ])
213+ self .compareTensorToPIL ( cropped_tensor , cropped_pil_image )
212214
213215 def test_five_crop (self ):
214- script_five_crop = torch .jit .script (F_t .five_crop )
215- img_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 )
216- img_tensor_clone = img_tensor .clone ()
217- cropped_tensor = F_t .five_crop (img_tensor , [10 , 10 ])
218- cropped_pil_image = F .five_crop (transforms .ToPILImage ()(img_tensor ), [10 , 10 ])
219- self .assertTrue (torch .equal (cropped_tensor [0 ],
220- (transforms .ToTensor ()(cropped_pil_image [0 ]) * 255 ).to (torch .uint8 )))
221- self .assertTrue (torch .equal (cropped_tensor [1 ],
222- (transforms .ToTensor ()(cropped_pil_image [2 ]) * 255 ).to (torch .uint8 )))
223- self .assertTrue (torch .equal (cropped_tensor [2 ],
224- (transforms .ToTensor ()(cropped_pil_image [1 ]) * 255 ).to (torch .uint8 )))
225- self .assertTrue (torch .equal (cropped_tensor [3 ],
226- (transforms .ToTensor ()(cropped_pil_image [3 ]) * 255 ).to (torch .uint8 )))
227- self .assertTrue (torch .equal (cropped_tensor [4 ],
228- (transforms .ToTensor ()(cropped_pil_image [4 ]) * 255 ).to (torch .uint8 )))
229- self .assertTrue (torch .equal (img_tensor , img_tensor_clone ))
230- # scriptable function test
231- cropped_script = script_five_crop (img_tensor , [10 , 10 ])
232- for cropped_script_img , cropped_tensor_img in zip (cropped_script , cropped_tensor ):
233- self .assertTrue (torch .equal (cropped_script_img , cropped_tensor_img ))
216+ script_five_crop = torch .jit .script (F .five_crop )
217+
218+ img_tensor , pil_img = self ._create_data (32 , 34 )
219+
220+ cropped_pil_images = F .five_crop (pil_img , [10 , 11 ])
221+
222+ cropped_tensors = F .five_crop (img_tensor , [10 , 11 ])
223+ for i in range (5 ):
224+ self .compareTensorToPIL (cropped_tensors [i ], cropped_pil_images [i ])
225+
226+ cropped_tensors = script_five_crop (img_tensor , [10 , 11 ])
227+ for i in range (5 ):
228+ self .compareTensorToPIL (cropped_tensors [i ], cropped_pil_images [i ])
234229
235230 def test_ten_crop (self ):
236- script_ten_crop = torch .jit .script (F_t .ten_crop )
237- img_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 )
238- img_tensor_clone = img_tensor .clone ()
239- cropped_tensor = F_t .ten_crop (img_tensor , [10 , 10 ])
240- cropped_pil_image = F .ten_crop (transforms .ToPILImage ()(img_tensor ), [10 , 10 ])
241- self .assertTrue (torch .equal (cropped_tensor [0 ],
242- (transforms .ToTensor ()(cropped_pil_image [0 ]) * 255 ).to (torch .uint8 )))
243- self .assertTrue (torch .equal (cropped_tensor [1 ],
244- (transforms .ToTensor ()(cropped_pil_image [2 ]) * 255 ).to (torch .uint8 )))
245- self .assertTrue (torch .equal (cropped_tensor [2 ],
246- (transforms .ToTensor ()(cropped_pil_image [1 ]) * 255 ).to (torch .uint8 )))
247- self .assertTrue (torch .equal (cropped_tensor [3 ],
248- (transforms .ToTensor ()(cropped_pil_image [3 ]) * 255 ).to (torch .uint8 )))
249- self .assertTrue (torch .equal (cropped_tensor [4 ],
250- (transforms .ToTensor ()(cropped_pil_image [4 ]) * 255 ).to (torch .uint8 )))
251- self .assertTrue (torch .equal (cropped_tensor [5 ],
252- (transforms .ToTensor ()(cropped_pil_image [5 ]) * 255 ).to (torch .uint8 )))
253- self .assertTrue (torch .equal (cropped_tensor [6 ],
254- (transforms .ToTensor ()(cropped_pil_image [7 ]) * 255 ).to (torch .uint8 )))
255- self .assertTrue (torch .equal (cropped_tensor [7 ],
256- (transforms .ToTensor ()(cropped_pil_image [6 ]) * 255 ).to (torch .uint8 )))
257- self .assertTrue (torch .equal (cropped_tensor [8 ],
258- (transforms .ToTensor ()(cropped_pil_image [8 ]) * 255 ).to (torch .uint8 )))
259- self .assertTrue (torch .equal (cropped_tensor [9 ],
260- (transforms .ToTensor ()(cropped_pil_image [9 ]) * 255 ).to (torch .uint8 )))
261- self .assertTrue (torch .equal (img_tensor , img_tensor_clone ))
262- # scriptable function test
263- cropped_script = script_ten_crop (img_tensor , [10 , 10 ])
264- for cropped_script_img , cropped_tensor_img in zip (cropped_script , cropped_tensor ):
265- self .assertTrue (torch .equal (cropped_script_img , cropped_tensor_img ))
231+ script_ten_crop = torch .jit .script (F .ten_crop )
232+
233+ img_tensor , pil_img = self ._create_data (32 , 34 )
234+
235+ cropped_pil_images = F .ten_crop (pil_img , [10 , 11 ])
236+
237+ cropped_tensors = F .ten_crop (img_tensor , [10 , 11 ])
238+ for i in range (10 ):
239+ self .compareTensorToPIL (cropped_tensors [i ], cropped_pil_images [i ])
240+
241+ cropped_tensors = script_ten_crop (img_tensor , [10 , 11 ])
242+ for i in range (10 ):
243+ self .compareTensorToPIL (cropped_tensors [i ], cropped_pil_images [i ])
266244
267245 def test_pad (self ):
268246 script_fn = torch .jit .script (F_t .pad )
0 commit comments