@@ -61,6 +61,66 @@ def test_crop(self):
6161 assert sum2 > sum1 , "height: " + str (height ) + " width: " \
6262 + str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
6363
64+ def test_five_crop (self ):
65+ to_pil_image = transforms .ToPILImage ()
66+ h = random .randint (5 , 25 )
67+ w = random .randint (5 , 25 )
68+ for single_dim in [True , False ]:
69+ crop_h = random .randint (1 , h )
70+ crop_w = random .randint (1 , w )
71+ if single_dim :
72+ crop_h = min (crop_h , crop_w )
73+ crop_w = crop_h
74+ transform = transforms .FiveCrop (crop_h )
75+ else :
76+ transform = transforms .FiveCrop ((crop_h , crop_w ))
77+
78+ img = torch .FloatTensor (3 , h , w ).uniform_ ()
79+ results = transform (to_pil_image (img ))
80+
81+ assert len (results ) == 5
82+ for crop in results :
83+ assert crop .size == (crop_w , crop_h )
84+
85+ to_pil_image = transforms .ToPILImage ()
86+ tl = to_pil_image (img [:, 0 :crop_h , 0 :crop_w ])
87+ tr = to_pil_image (img [:, 0 :crop_h , w - crop_w :])
88+ bl = to_pil_image (img [:, h - crop_h :, 0 :crop_w ])
89+ br = to_pil_image (img [:, h - crop_h :, w - crop_w :])
90+ center = transforms .CenterCrop ((crop_h , crop_w ))(to_pil_image (img ))
91+ expected_output = (tl , tr , bl , br , center )
92+ assert results == expected_output
93+
94+ def test_ten_crop (self ):
95+ to_pil_image = transforms .ToPILImage ()
96+ h = random .randint (5 , 25 )
97+ w = random .randint (5 , 25 )
98+ for should_vflip in [True , False ]:
99+ for single_dim in [True , False ]:
100+ crop_h = random .randint (1 , h )
101+ crop_w = random .randint (1 , w )
102+ if single_dim :
103+ crop_h = min (crop_h , crop_w )
104+ crop_w = crop_h
105+ transform = transforms .TenCrop (crop_h , vflip = should_vflip )
106+ five_crop = transforms .FiveCrop (crop_h )
107+ else :
108+ transform = transforms .TenCrop ((crop_h , crop_w ), vflip = should_vflip )
109+ five_crop = transforms .FiveCrop ((crop_h , crop_w ))
110+
111+ img = to_pil_image (torch .FloatTensor (3 , h , w ).uniform_ ())
112+ results = transform (img )
113+ expected_output = five_crop (img )
114+ if should_vflip :
115+ vflipped_img = img .transpose (Image .FLIP_TOP_BOTTOM )
116+ expected_output += five_crop (vflipped_img )
117+ else :
118+ hflipped_img = img .transpose (Image .FLIP_LEFT_RIGHT )
119+ expected_output += five_crop (hflipped_img )
120+
121+ assert len (results ) == 10
122+ assert expected_output == results
123+
64124 def test_scale (self ):
65125 height = random .randint (24 , 32 ) * 2
66126 width = random .randint (24 , 32 ) * 2
0 commit comments