@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights


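The warning fires once per call when any prompt overflows the token budget that get_prompts_with_weights is given. A rough standalone sketch of that budget, assuming CLIP's usual model_max_length of 77 and a max_embeddings_multiples of 3 (both values are assumptions, not read from this diff):

def prompt_token_budget(model_max_length: int = 77, max_embeddings_multiples: int = 3) -> int:
    # Each encoder chunk reserves two slots for BOS/EOS, leaving
    # model_max_length - 2 content tokens per chunk.
    return (model_max_length - 2) * max_embeddings_multiples

# Prompts tokenizing to more than 225 content tokens would be truncated
# (and now warned about) under these assumed defaults.
assert prompt_token_budget() == 225
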
@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
             if len(weights[i]) == 0:
                 w = [1.0] * weights_length
             else:
-                for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                for j in range(max_embeddings_multiples):
                     w.append(1.0)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                     w.append(1.0)  # weight for ending token in this chunk
                 w += [1.0] * (weights_length - len(w))
                 weights[i] = w[:]
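The slice step of chunk_length - 2 reflects that each chunk of chunk_length encoder slots reserves two for the BOS/EOS positions, so only chunk_length - 2 per-token weights belong to any one chunk. A simplified sketch of that chunking (helper name and values are illustrative; the real code also clamps the slice and pads to weights_length):

def chunk_token_weights(token_weights, max_embeddings_multiples, chunk_length=77):
    # 1.0 for the BOS/EOS slots of every chunk, real weights in between.
    w = []
    for j in range(max_embeddings_multiples):
        w.append(1.0)  # BOS slot of chunk j
        w += token_weights[j * (chunk_length - 2) : (j + 1) * (chunk_length - 2)]
        w.append(1.0)  # EOS slot of chunk j
    return w

# 150 weighted tokens spread over two 77-slot chunks: 75 content tokens each.
w = chunk_token_weights([1.1] * 150, max_embeddings_multiples=2)
assert len(w) == 2 * 77 and w[0] == 1.0 and w[1] == 1.1
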
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))

     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
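The arithmetic can be checked by hand: with a model_max_length of 77, each chunk carries 75 content tokens, so a 100-token prompt needs two chunks and the padded length becomes 75 * 2 + 2. A small sketch with illustrative values:

model_max_length = 77     # CLIP text encoder context size (assumed)
longest_prompt = 100      # tokens in the longest conditional/unconditional prompt
requested_multiples = 3

multiples = min(requested_multiples, (longest_prompt - 1) // (model_max_length - 2) + 1)
multiples = max(1, multiples)
max_length = (model_max_length - 2) * multiples + 2

assert multiples == 2 and max_length == 152
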
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

@@ -632,16 +643,29 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )

         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
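The // 8 comes from the Stable Diffusion VAE, which downsamples each spatial dimension by a factor of 8; the 4 latent channels below are typical for SD v1.x UNets and are an assumption about the loaded checkpoint, not something this diff states:

batch_size, num_images_per_prompt = 1, 2
unet_in_channels = 4      # typical for Stable Diffusion v1.x
height, width = 512, 512

latents_shape = (batch_size * num_images_per_prompt, unet_in_channels, height // 8, width // 8)
assert latents_shape == (2, 4, 64, 64)
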
@@ -684,11 +708,19 @@ def __call__(
             # add noise to latents using the timesteps
             if self.device.type == "mps":
                 # randn does not exist on mps
-                noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+                noise = torch.randn(
+                    init_latents.shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)

         t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ def __call__(
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
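For context, a hedged end-to-end sketch of how this pipeline is usually driven; it assumes the file is the lpw_stable_diffusion community pipeline and that the runwayml/stable-diffusion-v1-5 checkpoint is available (model id and prompt are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# A prompt well past 77 tokens with (emphasis:weight) syntax; raising
# max_embeddings_multiples keeps it from being truncated (and warned about).
image = pipe(
    prompt="a (very detailed:1.2) matte painting of a castle on a cliff, " * 12,
    negative_prompt="lowres, blurry",
    max_embeddings_multiples=3,
).images[0]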