
Commit c119dc4

allow multiple generations per prompt (#741)
* compute text embeds per prompt
* don't repeat uncond prompts
* repeat separately
* update image2image
* fix repeat uncond embeds
* adapt inpaint pipeline
* fix uncond tokens in img2img
* add tests and fix uncond embeds in img2img and inpaint pipelines
1 parent 367a671 commit c119dc4
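
The core of the change: each pipeline now encodes every distinct prompt once and duplicates the resulting embedding along the batch dimension, instead of asking the caller to repeat prompts. A minimal sketch of the batching scheme in plain PyTorch (shapes are the usual CLIP text-encoder shapes for Stable Diffusion; this is an illustration, not the pipeline code itself):

    import torch

    batch_size, num_images_per_prompt = 2, 3
    seq_len, hidden_dim = 77, 768  # CLIP text-encoder output shape

    # one embedding row per distinct prompt
    text_embeddings = torch.randn(batch_size, seq_len, hidden_dim)

    # duplicate each prompt's embedding, keeping copies of a prompt
    # adjacent: [p0, p0, p0, p1, p1, p1]
    text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
    assert text_embeddings.shape[0] == batch_size * num_images_per_prompt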

File tree: 4 files changed, +236 −13 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 12 additions & 3 deletions
@@ -117,6 +117,7 @@ def __call__(
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -148,6 +149,8 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -215,6 +218,9 @@ def __call__(
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
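
Note on the `repeat_interleave` choice (the "repeat separately" item in the commit message): it keeps all copies of a given prompt adjacent, whereas `Tensor.repeat` would cycle through the prompts. A toy comparison:

    import torch

    x = torch.tensor([[1], [2]])       # two "prompts"
    x.repeat_interleave(2, dim=0)      # tensor([[1], [1], [2], [2]])
    x.repeat(2, 1)                     # tensor([[1], [2], [1], [2]])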
@@ -223,14 +229,14 @@ def __call__(
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -250,6 +256,9 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
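
Since `uncond_tokens` is now a single entry when `negative_prompt` is `None`, the text encoder runs once for the unconditional branch and the result (batch dimension 1) is expanded to match the conditional embeddings. Shape math for that default case, with illustrative sizes:

    import torch

    batch_size, num_images_per_prompt = 2, 3

    # negative_prompt is None -> uncond_tokens = [""], one encoder pass
    uncond = torch.randn(1, 77, 768)
    uncond = uncond.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
    assert uncond.shape == (6, 77, 768)  # matches the conditional batch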
@@ -260,7 +269,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
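
With this in place, generating several images per prompt is a single call. A hedged usage sketch (the checkpoint name is a placeholder; any Stable Diffusion checkpoint should behave the same):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to("cuda")

    # 2 prompts x 3 images each -> 6 images, grouped per prompt
    out = pipe(
        prompt=["a photo of an astronaut", "a watercolor fox"],
        num_images_per_prompt=3,
    )
    images = out.images  # list of 6 PIL images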

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 13 additions & 4 deletions
@@ -129,6 +129,7 @@ def __call__(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -164,6 +165,8 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -220,15 +223,15 @@ def __call__(
         init_latents = 0.18215 * init_latents
 
         # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
 
         # get the original timestep using init_timestep
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
 
         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
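
Worth spelling out: `[init_latents] * batch_size * num_images_per_prompt` is plain Python list replication, so `torch.cat` stacks that many copies of the encoded init image along dim 0; the timesteps tensor is widened the same way. For instance:

    import torch

    init_latents = torch.randn(1, 4, 64, 64)  # one encoded init image
    batch_size, num_images_per_prompt = 2, 3

    # [t] * 2 * 3 is a 6-element list, so the batch dimension becomes 6
    expanded = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
    assert expanded.shape == (6, 4, 64, 64)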
@@ -252,6 +255,9 @@ def __call__(
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -260,14 +266,14 @@ def __call__(
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
             else:
@@ -283,6 +289,9 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
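
Usage mirrors the text-to-image pipeline. A sketch assuming the `init_image` argument name this version of the img2img pipeline uses (file path and checkpoint are placeholders):

    from PIL import Image
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to("cuda")

    init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))

    # one init image, one prompt, four variations
    out = pipe(
        prompt="a fantasy landscape",
        init_image=init_image,
        strength=0.75,
        num_images_per_prompt=4,
    )
    images = out.images  # 4 PIL images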

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 15 additions & 6 deletions
@@ -145,6 +145,7 @@ def __call__(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -184,6 +185,8 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -242,15 +245,15 @@ def __call__(
 
         init_latents = 0.18215 * init_latents
 
-        # Expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents
 
         # preprocess mask
         if not isinstance(mask_image, torch.FloatTensor):
             mask_image = preprocess_mask(mask_image)
         mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
 
         # check sizes
         if not mask.shape == init_latents.shape:
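
The mask must be replicated exactly like the latents so the `mask.shape == init_latents.shape` check below still passes. A toy shape check under the same assumptions (illustrative sizes):

    import torch

    batch = 2 * 3  # batch_size * num_images_per_prompt
    init_latents = torch.cat([torch.randn(1, 4, 64, 64)] * batch, dim=0)
    mask = torch.cat([torch.randn(1, 4, 64, 64)] * batch)  # dim=0 is the default

    assert mask.shape == init_latents.shape  # mirrors the pipeline's size check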
@@ -262,7 +265,7 @@ def __call__(
         init_timestep = min(init_timestep, num_inference_steps)
 
         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -286,6 +289,9 @@ def __call__(
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -294,14 +300,14 @@ def __call__(
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -321,6 +327,9 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
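
And correspondingly for inpainting; a hedged sketch (the `init_image`/`mask_image` argument names follow this version of the pipeline, and the file paths are placeholders):

    from PIL import Image
    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to("cuda")

    init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
    mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))

    out = pipe(
        prompt="a bouquet of flowers",
        init_image=init_image,
        mask_image=mask_image,
        num_images_per_prompt=2,
    )
    images = out.images  # 2 PIL images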
