@@ -144,16 +144,31 @@ def get_dummy_components(self):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0):
+    def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
         # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
-        image = image.cpu().permute(0, 2, 3, 1)[0]
-        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
-        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
+        if output_pil:
+            # Get random floats in [0, 1] as image
+            image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+            image = image.cpu().permute(0, 2, 3, 1)[0]
+            mask_image = torch.ones_like(image)
+            # Convert image and mask_image to [0, 255]
+            image = 255 * image
+            mask_image = 255 * mask_image
+            # Convert to PIL image
+            init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res))
+            mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB").resize((img_res, img_res))
+        else:
+            # Get random floats in [0, 1] as image with spatial size (img_res, img_res)
+            image = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device)
+            # Convert image to [-1, 1]
+            init_image = 2.0 * image - 1.0
+            mask_image = torch.ones((1, 1, img_res, img_res), device=device)
+
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
+
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "image": init_image,
@@ -177,7 +192,7 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
+        expected_slice = np.array([0.4703, 0.5697, 0.3879, 0.5470, 0.6042, 0.4413, 0.5078, 0.4728, 0.4469])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
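Both slice updates in this commit reflect the intentional change to the dummy mask (all ones instead of the old `image + 4`), not a loosened check: the 1e-2 elementwise tolerance is unchanged. A small self-contained sketch of how that tolerance gate behaves; the simulated `image_slice` is a stand-in, not real pipeline output:

import numpy as np

# The test compares a 3x3 corner slice of the output against recorded values
# and passes as long as every element stays within 1e-2 of its reference.
expected_slice = np.array([0.4703, 0.5697, 0.3879, 0.5470, 0.6042, 0.4413, 0.5078, 0.4728, 0.4469])
image_slice = expected_slice + 0.005  # stand-in output, within tolerance
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
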
@@ -357,7 +372,7 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151])
+        expected_slice = np.array([0.6584, 0.5424, 0.5649, 0.5449, 0.5897, 0.6111, 0.5404, 0.5463, 0.5214])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 