Skip to content

Commit 64339af

Browse files
damian0815lstein
authored andcommitted
restrict to 75 tokens and correctly handle blends
1 parent 5d20f47 commit 64339af

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

ldm/invoke/conditioning.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,16 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
3636
return prompt, negative_prompt
3737

3838

39-
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
39+
def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
4040
text_fragments = [x.text if type(x) is Fragment else
4141
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
4242
str(x))
4343
for x in parsed_prompt.children]
4444
text = " ".join(text_fragments)
4545
tokens = model.cond_stage_model.tokenizer.tokenize(text)
46+
if truncate_if_too_long:
47+
max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75
48+
tokens = tokens[0:max_tokens_length]
4649
return tokens
4750

4851

@@ -116,8 +119,12 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
116119
">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
117120
cac_args = None
118121

119-
eos_token_index = 1
120-
if type(parsed_prompt) is not Blend:
122+
if type(parsed_prompt) is Blend:
123+
blend: Blend = parsed_prompt
124+
all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
125+
longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
126+
eos_token_index = len(longest_token_sequence)+1
127+
else:
121128
tokens = get_tokens_for_prompt(model, parsed_prompt)
122129
eos_token_index = len(tokens)+1
123130
return (

0 commit comments

Comments
 (0)