 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import functional as F
 
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        if text_model_output is None:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+            text_input_ids = text_inputs.input_ids
+            text_mask = text_inputs.attention_mask.bool().to(device)
+
+            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device))
+            text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
-        text_embeddings = text_encoder_output.text_embeds
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_embeddings = text_encoder_output.text_embeds
+            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+        else:
+            batch_size = text_model_output[0].shape[0]
+            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            text_mask = text_attention_mask
 
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         if do_classifier_free_guidance:
             uncond_tokens = [""] * batch_size
 
-            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
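Note on the new bypass branch: it expects the projected text embeddings at index 0 and the encoder hidden states at index 1, which matches the field order of `CLIPTextModelOutput`. Below is a minimal sketch of how these tensors could be produced ahead of time. It assumes this file is the UnCLIP pipeline; the `UnCLIPPipeline` class, the `pipe` handle, and the `kakaobrain/karlo-v1-alpha` checkpoint are illustrative assumptions, not part of this diff.

```python
import torch
from diffusers import UnCLIPPipeline

# Illustrative checkpoint and handle names; any pipeline exposing a CLIPTokenizer
# (`pipe.tokenizer`) and a CLIPTextModelWithProjection (`pipe.text_encoder`) would do.
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")

text_inputs = pipe.tokenizer(
    ["a photograph of an astronaut riding a horse"],
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    return_tensors="pt",
)
with torch.no_grad():
    # CLIPTextModelOutput: [0] is `text_embeds`, [1] is `last_hidden_state`,
    # exactly what the `else:` branch above unpacks.
    text_model_output = pipe.text_encoder(text_inputs.input_ids)

# The bypass branch passes the mask through unchanged, so pre-cast it to bool
# the same way the in-pipeline path does.
text_attention_mask = text_inputs.attention_mask.bool()
```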
@@ -235,7 +248,7 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         prior_num_inference_steps: int = 25,
         decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ def __call__(
         prior_latents: Optional[torch.FloatTensor] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
         prior_guidance_scale: float = 4.0,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@ def __call__(
 
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+                The prompt or prompts to guide the image generation. This can only be left undefined if
+                `text_model_output` and `text_attention_mask` are passed.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,26 +303,37 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            text_model_output (`CLIPTextModelOutput`, *optional*):
+                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+                can be passed for tasks like text embedding interpolations. Make sure to also pass
+                `text_attention_mask` in this case. `prompt` can then be left as `None`.
+            text_attention_mask (`torch.Tensor`, *optional*):
+                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+                masks are necessary when passing `text_model_output`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
+        if prompt is not None:
+            if isinstance(prompt, str):
+                batch_size = 1
+            elif isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            batch_size = text_model_output[0].shape[0]
+
         device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
        )
 
         # prior
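The docstring motivates `text_model_output` with text-embedding interpolation. Here is a hedged sketch of that use case, reusing the hypothetical `pipe` handle from the sketch above. It uses plain linear interpolation for brevity; spherical interpolation is a common alternative.

```python
import torch

prompts = ["a photograph of a cat", "a photograph of a dog"]
inputs = pipe.tokenizer(
    prompts,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    return_tensors="pt",
)
with torch.no_grad():
    out = pipe.text_encoder(inputs.input_ids)

t = 0.5  # interpolation weight between the two prompts
text_embeds = torch.lerp(out.text_embeds[0:1], out.text_embeds[1:2], t)
last_hidden_state = torch.lerp(out.last_hidden_state[0:1], out.last_hidden_state[1:2], t)

image = pipe(
    prompt=None,  # allowed now that the text embeddings are supplied directly
    text_model_output=(text_embeds, last_hidden_state),
    # Reusing the first prompt's mask is a simplification for this sketch.
    text_attention_mask=inputs.attention_mask[0:1].bool(),
).images[0]
```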
@@ -315,6 +342,7 @@ def __call__(
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
+
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
@@ -378,6 +406,7 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
+
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
+
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,