@@ -145,6 +145,7 @@ def __call__(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -184,6 +185,8 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -242,15 +245,15 @@ def __call__(
 
         init_latents = 0.18215 * init_latents
 
-        # Expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents
 
         # preprocess mask
         if not isinstance(mask_image, torch.FloatTensor):
             mask_image = preprocess_mask(mask_image)
         mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
 
         # check sizes
         if not mask.shape == init_latents.shape:
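
Note on the expansion above: `torch.cat([x] * batch_size * num_images_per_prompt, dim=0)` simply tiles one tensor along the batch dimension. A minimal standalone sketch of that expansion, with toy shapes assumed for illustration (not part of this PR):

import torch

batch_size, num_images_per_prompt = 2, 3
init_latents = torch.randn(1, 4, 64, 64)  # assumed latent shape for a single image
expanded = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
print(expanded.shape)  # torch.Size([6, 4, 64, 64])
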
@@ -262,7 +265,7 @@ def __call__(
         init_timestep = min(init_timestep, num_inference_steps)
 
         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -286,6 +289,9 @@ def __call__(
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
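
Note on the added `repeat_interleave` call: it groups all copies of each prompt together along dim 0, which keeps every generated image paired with its own prompt embedding. A minimal sketch with toy values (assumed, not from this PR) contrasting it with `torch.cat` tiling:

import torch

embeddings = torch.tensor([[0.0], [1.0]])         # one row per prompt
grouped = embeddings.repeat_interleave(2, dim=0)  # order: p0, p0, p1, p1
tiled = torch.cat([embeddings] * 2, dim=0)        # order: p0, p1, p0, p1
print(grouped.flatten().tolist())  # [0.0, 0.0, 1.0, 1.0]
print(tiled.flatten().tolist())    # [0.0, 1.0, 0.0, 1.0]
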
@@ -294,14 +300,14 @@ def __call__(
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -321,6 +327,9 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
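
As the closing comment describes, the unconditional and text embeddings are concatenated into one batch so the UNet runs a single forward pass per denoising step. A minimal sketch of that batching plus the standard classifier-free guidance split that follows it later in the full pipeline (shapes and the random stand-in for the UNet output are assumptions for illustration):

import torch

bsz = 2 * 3                                   # batch_size * num_images_per_prompt
uncond = torch.zeros(bsz, 77, 768)            # assumed CLIP embedding shape (77, 768)
text = torch.ones(bsz, 77, 768)
embeddings = torch.cat([uncond, text])        # single batch -> one UNet call per step
noise_pred = torch.randn(2 * bsz, 4, 64, 64)  # stand-in for the UNet noise prediction
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
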