@@ -36,13 +36,16 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False
     return prompt, negative_prompt


-def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
+def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
     text_fragments = [x.text if type(x) is Fragment else
                       (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                        str(x))
                       for x in parsed_prompt.children]
     text = " ".join(text_fragments)
     tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    if truncate_if_too_long:
+        max_tokens_length = model.cond_stage_model.max_length - 2  # typically 75
+        tokens = tokens[0:max_tokens_length]
     return tokens

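A note on the truncation added above: in Stable Diffusion-style models, `model.cond_stage_model` wraps a frozen CLIP text encoder whose `max_length` is CLIP's 77-position context window; BOS and EOS each occupy one position, leaving 75 for prompt content. A minimal standalone sketch of the same behaviour, assuming a HuggingFace `CLIPTokenizer` (the checkpoint name below is illustrative, not taken from this repo):

```python
# Standalone sketch of the truncation above, assuming a HuggingFace
# CLIPTokenizer (checkpoint name is illustrative only).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

text = "a photograph of an astronaut riding a horse " * 20  # deliberately too long
tokens = tokenizer.tokenize(text)

# CLIP's context window is 77 positions; BOS and EOS each take one,
# leaving 75 positions for actual prompt content.
max_tokens_length = tokenizer.model_max_length - 2  # typically 75
tokens = tokens[0:max_tokens_length]

assert len(tokens) <= 75
```

Without this cap, `get_tokens_for_prompt` could report more tokens than the encoder will actually consume, which matters in the next hunk, where the token count is used to locate the EOS position.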
@@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
             ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
         cac_args = None

-    eos_token_index = 1
-    if type(parsed_prompt) is not Blend:
+    if type(parsed_prompt) is Blend:
+        blend: Blend = parsed_prompt
+        all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
+        longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
+        eos_token_index = len(longest_token_sequence) + 1
+    else:
         tokens = get_tokens_for_prompt(model, parsed_prompt)
         eos_token_index = len(tokens) + 1
     return (
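The rewritten branch above takes the blend's EOS position from its longest member, since the combined embedding's last content token comes from that sub-prompt. A toy sketch of the index arithmetic (hypothetical token lists, not real tokenizer output):

```python
# Toy sketch of the eos_token_index arithmetic above, using hypothetical
# token lists rather than real tokenizer output.
token_sequences = [
    ["a", "cat"],                      # 2 content tokens
    ["a", "fluffy", "orange", "cat"],  # 4 content tokens
]

# With a leading BOS at index 0, content tokens occupy indices 1..len(tokens),
# so the EOS marker lands at index len(tokens) + 1.
longest = max(token_sequences, key=lambda t: len(t))
eos_token_index = len(longest) + 1

assert eos_token_index == 5
```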