Skip to content

Commit 1ae7ac3

Browse files
committed
move test num_images_per_prompt to pipeline mixin
1 parent dde7095 commit 1ae7ac3

14 files changed

+71
-325
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,30 @@ def check_inputs(
513513
f" {negative_prompt_embeds.shape}."
514514
)
515515

516-
if (indices is None) or (indices is not None and not isinstance(indices, List)):
517-
raise ValueError(f"`indices` has to be a list but is {type(indices)}")
516+
indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
517+
indices_is_list_list_ints = (
518+
isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
519+
)
520+
521+
if not indices_is_list_ints and not indices_is_list_list_ints:
522+
raise TypeError("`indices` must be a list of ints or a list of a list of ints")
523+
524+
if indices_is_list_ints:
525+
indices_batch_size = 1
526+
elif indices_is_list_list_ints:
527+
indices_batch_size = len(indices)
528+
529+
if prompt is not None and isinstance(prompt, str):
530+
prompt_batch_size = 1
531+
elif prompt is not None and isinstance(prompt, list):
532+
prompt_batch_size = len(prompt)
533+
elif prompt_embeds is not None:
534+
prompt_batch_size = prompt_embeds.shape[0]
535+
536+
if indices_batch_size != prompt_batch_size:
537+
raise ValueError(
538+
f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
539+
)
518540

519541
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
520542
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
@@ -671,7 +693,7 @@ def get_indices(self, prompt: str) -> Dict[str, int]:
671693
def __call__(
672694
self,
673695
prompt: Union[str, List[str]],
674-
token_indices: List[int],
696+
token_indices: Union[List[int], List[List[int]]],
675697
height: Optional[int] = None,
676698
width: Optional[int] = None,
677699
num_inference_steps: int = 50,
@@ -847,7 +869,9 @@ def __call__(
847869

848870
if isinstance(token_indices[0], int):
849871
token_indices = [token_indices]
872+
850873
indices = []
874+
851875
for ind in token_indices:
852876
indices = indices + [ind] * num_images_per_prompt
853877

tests/pipelines/paint_by_example/test_paint_by_example.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,8 @@ def test_paint_by_example_image_tensor(self):
163163
assert out_1.shape == (1, 64, 64, 3)
164164
assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
165165

166-
def test_paint_by_example_inpaint_with_num_images_per_prompt(self):
167-
device = "cpu"
168-
pipe = PaintByExamplePipeline(**self.get_dummy_components())
169-
pipe = pipe.to(device)
170-
pipe.set_progress_bar_config(disable=None)
171-
172-
inputs = self.get_dummy_inputs()
173-
174-
images = pipe(**inputs, num_images_per_prompt=2).images
175-
176-
# check if the output is a list of 2 images
177-
assert len(images) == 2
166+
def test_num_images_per_prompt(self):
167+
self._test_num_images_per_prompt(prompt_key=["image", "example_image", "mask_image"])
178168

179169

180170
@slow

tests/pipelines/stable_diffusion/test_cycle_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def test_stable_diffusion_cycle_fp16(self):
149149

150150
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
151151

152+
def test_num_images_per_prompt(self):
153+
self._test_num_images_per_prompt(prompt_key=["prompt", "source_prompt", "image"])
154+
152155

153156
@slow
154157
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -451,43 +451,6 @@ def test_stable_diffusion_negative_prompt(self):
451451

452452
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
453453

454-
def test_stable_diffusion_num_images_per_prompt(self):
455-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
456-
components = self.get_dummy_components()
457-
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
458-
sd_pipe = StableDiffusionPipeline(**components)
459-
sd_pipe = sd_pipe.to(device)
460-
sd_pipe.set_progress_bar_config(disable=None)
461-
462-
prompt = "A painting of a squirrel eating a burger"
463-
464-
# test num_images_per_prompt=1 (default)
465-
images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
466-
467-
assert images.shape == (1, 64, 64, 3)
468-
469-
# test num_images_per_prompt=1 (default) for batch of prompts
470-
batch_size = 2
471-
images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
472-
473-
assert images.shape == (batch_size, 64, 64, 3)
474-
475-
# test num_images_per_prompt for single prompt
476-
num_images_per_prompt = 2
477-
images = sd_pipe(
478-
prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
479-
).images
480-
481-
assert images.shape == (num_images_per_prompt, 64, 64, 3)
482-
483-
# test num_images_per_prompt for batch of prompts
484-
batch_size = 2
485-
images = sd_pipe(
486-
[prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
487-
).images
488-
489-
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
490-
491454
def test_stable_diffusion_long_prompt(self):
492455
components = self.get_dummy_components()
493456
components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)

tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -140,41 +140,8 @@ def test_stable_diffusion_img_variation_multiple_images(self):
140140

141141
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
142142

143-
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
144-
device = "cpu"
145-
components = self.get_dummy_components()
146-
sd_pipe = StableDiffusionImageVariationPipeline(**components)
147-
sd_pipe = sd_pipe.to(device)
148-
sd_pipe.set_progress_bar_config(disable=None)
149-
150-
# test num_images_per_prompt=1 (default)
151-
inputs = self.get_dummy_inputs(device)
152-
images = sd_pipe(**inputs).images
153-
154-
assert images.shape == (1, 64, 64, 3)
155-
156-
# test num_images_per_prompt=1 (default) for batch of images
157-
batch_size = 2
158-
inputs = self.get_dummy_inputs(device)
159-
inputs["image"] = batch_size * [inputs["image"]]
160-
images = sd_pipe(**inputs).images
161-
162-
assert images.shape == (batch_size, 64, 64, 3)
163-
164-
# test num_images_per_prompt for single prompt
165-
num_images_per_prompt = 2
166-
inputs = self.get_dummy_inputs(device)
167-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
168-
169-
assert images.shape == (num_images_per_prompt, 64, 64, 3)
170-
171-
# test num_images_per_prompt for batch of prompts
172-
batch_size = 2
173-
inputs = self.get_dummy_inputs(device)
174-
inputs["image"] = batch_size * [inputs["image"]]
175-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
176-
177-
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
143+
def test_num_images_per_prompt(self):
144+
self._test_num_images_per_prompt(prompt_key="image")
178145

179146

180147
@slow

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -177,42 +177,6 @@ def test_stable_diffusion_img2img_k_lms(self):
177177

178178
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
179179

180-
def test_stable_diffusion_img2img_num_images_per_prompt(self):
181-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
182-
components = self.get_dummy_components()
183-
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
184-
sd_pipe = sd_pipe.to(device)
185-
sd_pipe.set_progress_bar_config(disable=None)
186-
187-
# test num_images_per_prompt=1 (default)
188-
inputs = self.get_dummy_inputs(device)
189-
images = sd_pipe(**inputs).images
190-
191-
assert images.shape == (1, 32, 32, 3)
192-
193-
# test num_images_per_prompt=1 (default) for batch of prompts
194-
batch_size = 2
195-
inputs = self.get_dummy_inputs(device)
196-
inputs["prompt"] = [inputs["prompt"]] * batch_size
197-
images = sd_pipe(**inputs).images
198-
199-
assert images.shape == (batch_size, 32, 32, 3)
200-
201-
# test num_images_per_prompt for single prompt
202-
num_images_per_prompt = 2
203-
inputs = self.get_dummy_inputs(device)
204-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
205-
206-
assert images.shape == (num_images_per_prompt, 32, 32, 3)
207-
208-
# test num_images_per_prompt for batch of prompts
209-
batch_size = 2
210-
inputs = self.get_dummy_inputs(device)
211-
inputs["prompt"] = [inputs["prompt"]] * batch_size
212-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
213-
214-
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
215-
216180

217181
@slow
218182
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,6 @@ def test_stable_diffusion_inpaint_image_tensor(self):
148148
assert out_pil.shape == (1, 64, 64, 3)
149149
assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
150150

151-
def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
152-
device = "cpu"
153-
components = self.get_dummy_components()
154-
sd_pipe = StableDiffusionInpaintPipeline(**components)
155-
sd_pipe = sd_pipe.to(device)
156-
sd_pipe.set_progress_bar_config(disable=None)
157-
158-
inputs = self.get_dummy_inputs(device)
159-
images = sd_pipe(**inputs, num_images_per_prompt=2).images
160-
161-
# check if the output is a list of 2 images
162-
assert len(images) == 2
163-
164151

165152
@slow
166153
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -188,42 +188,6 @@ def test_stable_diffusion_pix2pix_euler(self):
188188

189189
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
190190

191-
def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
192-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
193-
components = self.get_dummy_components()
194-
sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
195-
sd_pipe = sd_pipe.to(device)
196-
sd_pipe.set_progress_bar_config(disable=None)
197-
198-
# test num_images_per_prompt=1 (default)
199-
inputs = self.get_dummy_inputs(device)
200-
images = sd_pipe(**inputs).images
201-
202-
assert images.shape == (1, 32, 32, 3)
203-
204-
# test num_images_per_prompt=1 (default) for batch of prompts
205-
batch_size = 2
206-
inputs = self.get_dummy_inputs(device)
207-
inputs["prompt"] = [inputs["prompt"]] * batch_size
208-
images = sd_pipe(**inputs).images
209-
210-
assert images.shape == (batch_size, 32, 32, 3)
211-
212-
# test num_images_per_prompt for single prompt
213-
num_images_per_prompt = 2
214-
inputs = self.get_dummy_inputs(device)
215-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
216-
217-
assert images.shape == (num_images_per_prompt, 32, 32, 3)
218-
219-
# test num_images_per_prompt for batch of prompts
220-
batch_size = 2
221-
inputs = self.get_dummy_inputs(device)
222-
inputs["prompt"] = [inputs["prompt"]] * batch_size
223-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
224-
225-
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
226-
227191

228192
@slow
229193
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -174,42 +174,6 @@ def test_stable_diffusion_panorama_pndm(self):
174174
with self.assertRaises(ValueError):
175175
_ = sd_pipe(**inputs).images
176176

177-
def test_stable_diffusion_panorama_num_images_per_prompt(self):
178-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
179-
components = self.get_dummy_components()
180-
sd_pipe = StableDiffusionPanoramaPipeline(**components)
181-
sd_pipe = sd_pipe.to(device)
182-
sd_pipe.set_progress_bar_config(disable=None)
183-
184-
# test num_images_per_prompt=1 (default)
185-
inputs = self.get_dummy_inputs(device)
186-
images = sd_pipe(**inputs).images
187-
188-
assert images.shape == (1, 64, 64, 3)
189-
190-
# test num_images_per_prompt=1 (default) for batch of prompts
191-
batch_size = 2
192-
inputs = self.get_dummy_inputs(device)
193-
inputs["prompt"] = [inputs["prompt"]] * batch_size
194-
images = sd_pipe(**inputs).images
195-
196-
assert images.shape == (batch_size, 64, 64, 3)
197-
198-
# test num_images_per_prompt for single prompt
199-
num_images_per_prompt = 2
200-
inputs = self.get_dummy_inputs(device)
201-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
202-
203-
assert images.shape == (num_images_per_prompt, 64, 64, 3)
204-
205-
# test num_images_per_prompt for batch of prompts
206-
batch_size = 2
207-
inputs = self.get_dummy_inputs(device)
208-
inputs["prompt"] = [inputs["prompt"]] * batch_size
209-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
210-
211-
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
212-
213177

214178
@slow
215179
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -195,34 +195,6 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self):
195195

196196
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
197197

198-
def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
199-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
200-
components = self.get_dummy_components()
201-
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
202-
sd_pipe = sd_pipe.to(device)
203-
sd_pipe.set_progress_bar_config(disable=None)
204-
205-
# test num_images_per_prompt=1 (default)
206-
inputs = self.get_dummy_inputs(device)
207-
images = sd_pipe(**inputs).images
208-
209-
assert images.shape == (1, 64, 64, 3)
210-
211-
# test num_images_per_prompt=2 for a single prompt
212-
num_images_per_prompt = 2
213-
inputs = self.get_dummy_inputs(device)
214-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
215-
216-
assert images.shape == (num_images_per_prompt, 64, 64, 3)
217-
218-
# test num_images_per_prompt for batch of prompts
219-
batch_size = 2
220-
inputs = self.get_dummy_inputs(device)
221-
inputs["prompt"] = [inputs["prompt"]] * batch_size
222-
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
223-
224-
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
225-
226198

227199
@slow
228200
@require_torch_gpu

0 commit comments

Comments
 (0)