@@ -247,18 +247,36 @@ def test_ten_crop(self):
247247 def test_pad (self ):
248248 script_fn = torch .jit .script (F_t .pad )
249249 tensor , pil_img = self ._create_data (7 , 8 )
250- for pad in [1 , [1 , ], [0 , 1 ], (2 , 2 ), [1 , 0 , 1 , 2 ]]:
251- padding_mode = "constant"
252- for fill in [0 , 10 , 20 ]:
253- pad_tensor = F_t .pad (tensor , pad , fill = fill , padding_mode = padding_mode )
254- pad_pil_img = F_pil .pad (pil_img , pad , fill = fill , padding_mode = padding_mode )
255- self .compareTensorToPIL (pad_tensor , pad_pil_img , msg = "{}, {}" .format (pad , fill ))
256- if isinstance (pad , int ):
257- script_pad = [pad , ]
258- else :
259- script_pad = pad
260- pad_tensor_script = script_fn (tensor , script_pad , fill = fill , padding_mode = padding_mode )
261- self .assertTrue (pad_tensor .equal (pad_tensor_script ), msg = "{}, {}" .format (pad , fill ))
250+
251+ for dt in [None , torch .float32 , torch .float64 ]:
252+ if dt is not None :
253+ # This is a trivial cast to float of uint8 data to test all cases
254+ tensor = tensor .to (dt )
255+ for pad in [2 , [3 , ], [0 , 3 ], (3 , 3 ), [4 , 2 , 4 , 3 ]]:
256+ configs = [
257+ {"padding_mode" : "constant" , "fill" : 0 },
258+ {"padding_mode" : "constant" , "fill" : 10 },
259+ {"padding_mode" : "constant" , "fill" : 20 },
260+ {"padding_mode" : "edge" },
261+ {"padding_mode" : "reflect" },
262+ ]
263+ for kwargs in configs :
264+ pad_tensor = F_t .pad (tensor , pad , ** kwargs )
265+ pad_pil_img = F_pil .pad (pil_img , pad , ** kwargs )
266+
267+ pad_tensor_8b = pad_tensor
268+ # we need to cast to uint8 to compare with PIL image
269+ if pad_tensor_8b .dtype != torch .uint8 :
270+ pad_tensor_8b = pad_tensor_8b .to (torch .uint8 )
271+
272+ self .compareTensorToPIL (pad_tensor_8b , pad_pil_img , msg = "{}, {}" .format (pad , kwargs ))
273+
274+ if isinstance (pad , int ):
275+ script_pad = [pad , ]
276+ else :
277+ script_pad = pad
278+ pad_tensor_script = script_fn (tensor , script_pad , ** kwargs )
279+ self .assertTrue (pad_tensor .equal (pad_tensor_script ), msg = "{}, {}" .format (pad , kwargs ))
262280
263281
264282if __name__ == '__main__' :
0 commit comments