
Commit b035ac8

add missing position_embeddings
1 parent bafb721 commit b035ac8

2 files changed: 79 additions & 38 deletions


ldm/modules/textual_inversion_manager.py

Lines changed: 13 additions & 13 deletions

@@ -21,7 +21,7 @@ def embedding_vector_length(self) -> int:
         return self.embedding.shape[0]
 
 
 class TextualInversionManager():
-    def __init__(self, clip_embedder: FrozenCLIPEmbedder, full_precision: bool):
+    def __init__(self, clip_embedder: FrozenCLIPEmbedder, full_precision: bool=True):
         self.clip_embedder = clip_embedder
         self.full_precision = full_precision
         self.hf_concepts_library = HuggingFaceConceptsLibrary()

@@ -169,17 +169,17 @@ def overwrite_textual_inversion_embeddings(self, prompt_token_ids: Union[torch.T
         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
         pad_token_id = self.clip_embedder.tokenizer.pad_token_id
         overwritten_prompt_embeddings = prompt_embeddings.clone()
-        for i, token_id in enumerate(prompt_token_ids):
-            if token_id == pad_token_id:
-                continue
-            if token_id in textual_inversion_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
-                count_to_overwrite = end_index - i
-                for j in range(0, count_to_overwrite):
-                    # only overwrite the textual inversion token id or the padding token id
-                    if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
-                        break
-                    overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
+
+        indices_of_textual_inversion_tokens_in_prompt = [index for index in range(0, len(prompt_token_ids)) if prompt_token_ids[index] in textual_inversion_token_ids]
+        eos_marker_index = self.clip_embedder.max_length-1
+        for i in indices_of_textual_inversion_tokens_in_prompt:
+            token_id = prompt_token_ids[i]
+            textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+            # don't overwrite the eos marker
+            after_end_index = min(i + textual_inversion.embedding_vector_length, eos_marker_index)
+            actual_count_to_overwrite = after_end_index - i
+            position_embeddings = self.clip_embedder.position_embedding(torch.arange(i, after_end_index, dtype=int))
+            embeddings_to_write = position_embeddings + textual_inversion.embedding[0:actual_count_to_overwrite]
+            overwritten_prompt_embeddings[i:i+actual_count_to_overwrite] = embeddings_to_write
 
         return overwritten_prompt_embeddings
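Why the position embeddings matter here: CLIP's text encoder emits each prompt slot as token embedding plus position embedding, so a textual inversion vector written over a slot after that layer must have the slot's position embedding added back, which is what the new position_embeddings term does. Below is a minimal, self-contained sketch of that arithmetic using random stand-in tensors and illustrative names; it is not the repository's code.

import torch

seq_len, dim = 77, 768
token_embeddings = torch.randn(seq_len, dim)     # stand-in for CLIP token embeddings
position_embeddings = torch.randn(seq_len, dim)  # stand-in for CLIP position embeddings
prompt_embeddings = token_embeddings + position_embeddings  # what the embedding layer emits

ti_embedding = torch.randn(4, dim)  # hypothetical 4-vector textual inversion
i = 5                               # slot where the textual inversion token starts

overwritten = prompt_embeddings.clone()
# Re-add the per-slot position embeddings, mirroring the change above; without this
# term the transformer would receive the textual inversion vectors with no positional signal.
overwritten[i:i+4] = position_embeddings[i:i+4] + ti_embedding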

tests/test_textual_inversion.py

Lines changed: 66 additions & 25 deletions

@@ -1,9 +1,10 @@
 
 import unittest
+from typing import Union
 
 import torch
 
-from ldm.modules.embedding_manager import TextualInversionManager
+from ldm.modules.textual_inversion_manager import TextualInversionManager
 
 
 KNOWN_WORDS = ['a', 'b', 'c']

@@ -53,7 +54,16 @@ def __init__(self):
         self.max_length = 77
         self.transformer = DummyTransformer()
         self.tokenizer = DummyTokenizer()
+        self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
 
+    def position_embedding(self, indices: Union[list,torch.Tensor]):
+        if type(indices) is list:
+            indices = torch.tensor(indices, dtype=int)
+        return torch.index_select(self.position_embeddings_tensor, 0, indices)
+
+
+def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
+    return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
 
 class TextualInversionManagerTestCase(unittest.TestCase):
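A note on the dummy position_embedding added above: torch.index_select pulls one row of the 77x768 position table per requested slot, and the helper accepts either a Python list (converted to a tensor first) or an integer tensor of indices. A minimal, runnable sketch of that lookup, with random stand-in values rather than real CLIP weights:

import torch

position_table = torch.randn(77, 768)  # stand-in for the CLIP position embedding table

def position_embedding(indices):
    # Accept a plain list of slot indices or an integer tensor, as the dummy above does.
    if isinstance(indices, list):
        indices = torch.tensor(indices, dtype=torch.int64)
    return torch.index_select(position_table, 0, indices)

print(position_embedding([4, 5, 6, 7]).shape)  # torch.Size([4, 768])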

@@ -270,7 +280,7 @@ def test_overwrite_textual_inversion_1v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v[0]))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [4], test_embedding_1v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
 
         # at the start

@@ -283,7 +293,7 @@ def test_overwrite_textual_inversion_1v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v[0]))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [1], test_embedding_1v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:77], default_prompt_embeddings[2:77]))
 
         # in the middle

@@ -296,7 +306,7 @@ def test_overwrite_textual_inversion_1v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v[0]))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [2], test_embedding_1v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))
 
 

@@ -326,8 +336,8 @@ def test_overwrite_textual_inversion_1v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_1[0]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5], test_embedding_1v_2[0]))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [4], test_embedding_1v_1))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [5], test_embedding_1v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
 
         # at the start

@@ -340,8 +350,10 @@ def test_overwrite_textual_inversion_1v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1], test_embedding_1v_1[0]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_2[0]))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [1], test_embedding_1v_1))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [2], test_embedding_1v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[3:77], default_prompt_embeddings[3:77]))
 
         # clumped in the middle

@@ -354,8 +366,10 @@ def test_overwrite_textual_inversion_1v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], test_embedding_1v_2[0]))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [2], test_embedding_1v_1))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [3], test_embedding_1v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:77], default_prompt_embeddings[4:77]))
 
         # scattered

@@ -368,9 +382,11 @@ def test_overwrite_textual_inversion_1v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2], test_embedding_1v_1[0]))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [2], test_embedding_1v_1))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[3], default_prompt_embeddings[3]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4], test_embedding_1v_2[0]))
+        self.assertTrue(
+            was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings, [4], test_embedding_1v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
 
     def test_overwrite_textual_inversion_4v_single(self):

@@ -393,7 +409,9 @@ def test_overwrite_textual_inversion_4v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:4], default_prompt_embeddings[0:4]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[4:8], test_embedding_4v))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(4,8)),
+                                                            test_embedding_4v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[8:77], default_prompt_embeddings[8:77]))
 
         # at the start

@@ -406,7 +424,9 @@ def test_overwrite_textual_inversion_4v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(1,5)),
+                                                            test_embedding_4v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
 
         # in the middle

@@ -419,7 +439,9 @@ def test_overwrite_textual_inversion_4v_single(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:2], default_prompt_embeddings[0:2]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[2:6], test_embedding_4v))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(2,6)),
+                                                            test_embedding_4v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[6:77], default_prompt_embeddings[6:77]))
 
     def test_overwrite_textual_inversion_4v_overflow(self):

@@ -445,8 +467,11 @@ def test_overwrite_textual_inversion_4v_overflow(self):
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         base_prompt_length = len(base_prompt)
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+3], test_embedding_4v[0:3]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+3:77], default_prompt_embeddings[base_prompt_length+1+3:77]))
+        truncated_overflowed_overwrite_count = min(75 - len(base_prompt), test_embedding_4v.shape[0])
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(base_prompt_length+1,base_prompt_length+1+truncated_overflowed_overwrite_count)),
+                                                            test_embedding_4v[0:truncated_overflowed_overwrite_count]))
+        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4:77], default_prompt_embeddings[base_prompt_length+1+4:77]))
 
         # at the start
         prompt_token_ids = [test_embedding_4v_token_id] + base_prompt

@@ -459,7 +484,9 @@ def test_overwrite_textual_inversion_4v_overflow(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(1,5)),
+                                                            test_embedding_4v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:77], default_prompt_embeddings[5:77]))
 
         # in the middle

@@ -472,7 +499,9 @@ def test_overwrite_textual_inversion_4v_overflow(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:21], default_prompt_embeddings[0:21]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[21:25], test_embedding_4v))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(21,25)),
+                                                            test_embedding_4v))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:77], default_prompt_embeddings[25:77]))
 
 

@@ -504,8 +533,12 @@ def test_overwrite_textual_inversion_4v_multiple(self):
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         base_prompt_length = len(base_prompt)
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:base_prompt_length+1], default_prompt_embeddings[0:base_prompt_length+1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1:base_prompt_length+1+4], test_embedding_4v_1))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4:base_prompt_length+1+4+4], test_embedding_4v_2))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(base_prompt_length+1, base_prompt_length+1+4)),
+                                                            test_embedding_4v_1))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(base_prompt_length+1+4, base_prompt_length+1+4+4)),
+                                                            test_embedding_4v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[base_prompt_length+1+4+4:77], default_prompt_embeddings[base_prompt_length+1+4+4:77]))
 
         # at the start

@@ -519,8 +552,12 @@ def test_overwrite_textual_inversion_4v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:1], default_prompt_embeddings[0:1]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[1:5], test_embedding_4v_1))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[5:9], test_embedding_4v_2))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(1,5)),
+                                                            test_embedding_4v_1))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(5,9)),
+                                                            test_embedding_4v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[9:77], default_prompt_embeddings[9:77]))
 
         # in the middle

@@ -533,7 +570,11 @@ def test_overwrite_textual_inversion_4v_multiple(self):
         overwritten_prompt_embeddings = tim.overwrite_textual_inversion_embeddings(padded_prompt_token_ids, default_prompt_embeddings)
         self.assertFalse(torch.equal(default_prompt_embeddings, overwritten_prompt_embeddings))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[0:11], default_prompt_embeddings[0:11]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[11:15], test_embedding_4v_1))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(11,15)),
+                                                            test_embedding_4v_1))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[15:25], default_prompt_embeddings[15:25]))
-        self.assertTrue(torch.equal(overwritten_prompt_embeddings[25:29], test_embedding_4v_2))
+        self.assertTrue(was_embedding_overwritten_correctly(tim, overwritten_prompt_embeddings,
+                                                            list(range(25,29)),
+                                                            test_embedding_4v_2))
         self.assertTrue(torch.equal(overwritten_prompt_embeddings[29:77], default_prompt_embeddings[29:77]))
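The truncation count used by the overflow test can be checked with a little arithmetic: with max_length 77, slot 0 holds the BOS marker and slot 76 the EOS marker, so a textual inversion token placed right after a base prompt of length n has only 75 - n slots available before the EOS marker. A small sketch of that bound, using a hypothetical prompt length of 72 (the test's actual base_prompt is defined elsewhere in the file):

max_length = 77
eos_marker_index = max_length - 1   # 76; the EOS marker must not be overwritten
base_prompt_length = 72             # hypothetical; fills slots 1..72 after the BOS marker
embedding_vector_length = 4         # a 4-vector textual inversion

ti_start = base_prompt_length + 1   # the textual inversion token begins at slot 73
after_end_index = min(ti_start + embedding_vector_length, eos_marker_index)

print(after_end_index - ti_start)                              # 3 vectors actually written
print(min(75 - base_prompt_length, embedding_vector_length))   # 3, matching the test's formula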
