Commit c73e609
Fix get_dummy_inputs for Stable Diffusion Inpaint Tests (#4845)
* Change StableDiffusionInpaintPipelineFastTests.get_dummy_inputs to produce a random image and a white mask_image.
* Add dummy expected slices for the test_stable_diffusion_inpaint tests.
* Remove print statement.

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2fa4b3f commit c73e609
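For background on the first change: the old helper derived mask_image from the image itself (np.uint8(image + 4)), which yields a noisy near-copy of the input rather than a meaningful inpainting mask. The new helper builds an all-white mask instead; in diffusers' inpainting convention, white pixels mark the region to repaint, so an all-white mask asks the pipeline to regenerate the whole image. A minimal standalone sketch of that convention (plain NumPy/PIL, illustrative only, not the test's exact helpers):

import numpy as np
from PIL import Image

# Random 64x64 RGB init image plus an all-white mask of the same size.
# White (255) pixels are the area the inpainting pipeline repaints.
rng = np.random.default_rng(seed=0)
init_image = Image.fromarray(rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8), mode="RGB")
mask_image = Image.fromarray(np.full((64, 64), 255, dtype=np.uint8)).convert("RGB")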

File tree

1 file changed: +22 −7 lines changed

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 22 additions & 7 deletions
@@ -144,16 +144,31 @@ def get_dummy_components(self):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0):
+    def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
         # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
-        image = image.cpu().permute(0, 2, 3, 1)[0]
-        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
-        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
+        if output_pil:
+            # Get random floats in [0, 1] as image
+            image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+            image = image.cpu().permute(0, 2, 3, 1)[0]
+            mask_image = torch.ones_like(image)
+            # Convert image and mask_image to [0, 255]
+            image = 255 * image
+            mask_image = 255 * mask_image
+            # Convert to PIL image
+            init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res))
+            mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB").resize((img_res, img_res))
+        else:
+            # Get random floats in [0, 1] as image with spatial size (img_res, img_res)
+            image = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device)
+            # Convert image to [-1, 1]
+            init_image = 2.0 * image - 1.0
+            mask_image = torch.ones((1, 1, img_res, img_res), device=device)
+
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
+
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "image": init_image,
@@ -177,7 +192,7 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
+        expected_slice = np.array([0.4703, 0.5697, 0.3879, 0.5470, 0.6042, 0.4413, 0.5078, 0.4728, 0.4469])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -357,7 +372,7 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151])
+        expected_slice = np.array([0.6584, 0.5424, 0.5649, 0.5449, 0.5897, 0.6111, 0.5404, 0.5463, 0.5214])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
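The new output_pil flag also makes a tensor-input variant easy to exercise. A hedged usage sketch, assuming the usual fast-test harness in this file (a self with get_dummy_components as above, StableDiffusionInpaintPipeline imported from diffusers, and an "output_type": "np" entry in the inputs dict so .images comes back as a NumPy array):

components = self.get_dummy_components()
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe = sd_pipe.to("cpu")
sd_pipe.set_progress_bar_config(disable=None)

# output_pil=False returns tensor inputs instead of PIL images: an init
# image rescaled to [-1, 1] and an all-ones (1, 1, 64, 64) mask.
inputs = self.get_dummy_inputs("cpu", img_res=64, output_pil=False)
image = sd_pipe(**inputs).images

assert image.shape == (1, 64, 64, 3)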
