Commit c6f31e5

fix off-by-one bug in cross-attention-control (#1774)
Prompt token sequences begin with a "beginning-of-sequence" marker <bos> and are padded out with a repeated "end-of-sequence" marker <eos>, giving a default layout of <bos> + 75 prompt tokens + <eos>. The .swap() code was failing to take the column for <bos> at index 0 into account, so an edit intended for the token at index n actually landed on the token at index n-1. The changes here fix that, and also add explicit handling for a single <eos> (which may be redundant, but is included for completeness).

Based on my understanding, and some assumptions about how this all works, the reason .swap() nevertheless seemed to do roughly the right thing is that over multiple steps the conditioning process in Stable Diffusion operates as a feedback loop: a change to token n-1 has flow-on effects on how the [1x4x64x64] latent tensor is modified by all the tokens after it, and, as the next step is processed, by all the tokens before it as well. Intuitively, a token's conditioning effects "echo" throughout the whole length of the prompt. So even though the token at n-1 was being edited when what the user actually wanted was to edit the token at n, the edit still had some non-negligible effect, in roughly the right direction, often enough that it seemed to be working properly.
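Two illustrative sketches follow; neither is part of the commit. First, a minimal model of the index mapping, assuming the CLIP-style layout described above (the prompt and its tokenization are hypothetical):

# <bos> occupies column 0 of the conditioning tensor, so prompt-token
# index n actually lives in column n + 1.
prompt_tokens = ["a", "cat", "sitting"]          # hypothetical tokenization
columns = ["<bos>"] + prompt_tokens + ["<eos>"]  # what the model sees

n = 1                        # the user wants to edit token n ("cat")
buggy_column = n             # pre-fix: lands one column early, on "a"
fixed_column = n + 1         # post-fix: lands on "cat"

assert columns[buggy_column] == "a"      # i.e. token n-1 was being edited
assert columns[fixed_column] == "cat"

Second, a toy cross-attention computation (plain NumPy, nothing InvokeAI-specific) showing why even an off-target edit still has broad effects: perturbing a single conditioning column changes the attention output at every latent position, which is the "echo" described above.

import numpy as np

rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 8))   # 4 latent positions (toy sizes)
keys    = rng.normal(size=(6, 8))   # 6 conditioning columns
values  = rng.normal(size=(6, 8))

def attend(q, k, v):
    # standard scaled dot-product attention, softmax over the columns
    logits = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

out_before = attend(queries, keys, values)
perturbed = values.copy()
perturbed[2] += 1.0                 # edit a single conditioning column
out_after = attend(queries, keys, perturbed)

# nonzero at all 4 latent positions, not just "near" column 2
print(np.abs(out_after - out_before).max(axis=-1))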
1 parent f3570d8 commit c6f31e5

File tree

1 file changed: +12 lines, -1 line


ldm/invoke/conditioning.py

Lines changed: 12 additions & 1 deletion
@@ -77,8 +77,13 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
     original_token_count = 0
     edited_token_count = 0
-    edit_opcodes = []
     edit_options = []
+    edit_opcodes = []
+    # beginning of sequence
+    edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
     for fragment in flattened_prompt.children:
         if type(fragment) is CrossAttentionControlSubstitute:
             original_prompt.append(fragment.original)
@@ -105,6 +110,12 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             edit_options.append(None)
             original_token_count += count
             edited_token_count += count
+    # end of sequence
+    edit_opcodes.append(('equal', original_token_count, original_token_count+1, edited_token_count, edited_token_count+1))
+    edit_options.append(None)
+    original_token_count += 1
+    edited_token_count += 1
+
     original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
                                                                                             original_prompt,
                                                                                             log_tokens=log_tokens,
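
For a concrete picture of what the patched code builds, here is a hypothetical opcode list for a prompt along the lines of "a cat".swap("dog"), assuming "a", "cat", and "dog" each tokenize to a single token. The tuples follow the difflib-style (tag, original_start, original_end, edited_start, edited_end) convention visible in the diff; the 'replace' tag for the swapped fragment is an assumption on my part, and per the comment in the code only 'equal' opcodes are parsed downstream.

# columns: 0 = <bos>, 1 = "a", 2 = "cat"/"dog", 3 = <eos>
edit_opcodes = [
    ('equal', 0, 1, 0, 1),    # <bos> (added by this fix)
    ('equal', 1, 2, 1, 2),    # "a", unchanged prompt text
    ('replace', 2, 3, 2, 3),  # "cat" -> "dog" (tag assumed; not parsed downstream)
    ('equal', 3, 4, 3, 4),    # <eos> (added by this fix)
]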
