11
22import unittest
3+ from typing import Union
34
45import torch
56
6- from ldm .modules .embedding_manager import TextualInversionManager
7+ from ldm .modules .textual_inversion_manager import TextualInversionManager
78
89
910KNOWN_WORDS = ['a' , 'b' , 'c' ]
@@ -53,7 +54,16 @@ def __init__(self):
5354 self .max_length = 77
5455 self .transformer = DummyTransformer ()
5556 self .tokenizer = DummyTokenizer ()
57+ self .position_embeddings_tensor = torch .randn ([77 ,768 ], dtype = torch .float32 )
5658
59+ def position_embedding (self , indices : Union [list ,torch .Tensor ]):
60+ if type (indices ) is list :
61+ indices = torch .tensor (indices , dtype = int )
62+ return torch .index_select (self .position_embeddings_tensor , 0 , indices )
63+
64+
65+ def was_embedding_overwritten_correctly (tim : TextualInversionManager , overwritten_embedding : torch .Tensor , ti_indices : list , ti_embedding : torch .Tensor ) -> bool :
66+ return torch .allclose (overwritten_embedding [ti_indices ], ti_embedding + tim .clip_embedder .position_embedding (ti_indices ))
5767
5868class TextualInversionManagerTestCase (unittest .TestCase ):
5969
@@ -270,7 +280,7 @@ def test_overwrite_textual_inversion_1v_single(self):
270280 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
271281 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
272282 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :4 ], default_prompt_embeddings [0 :4 ]))
273- self .assertTrue (torch . equal ( overwritten_prompt_embeddings [4 ], test_embedding_1v [ 0 ] ))
283+ self .assertTrue (was_embedding_overwritten_correctly ( tim , overwritten_prompt_embeddings , [4 ], test_embedding_1v ))
274284 self .assertTrue (torch .equal (overwritten_prompt_embeddings [5 :77 ], default_prompt_embeddings [5 :77 ]))
275285
276286 # at the start
@@ -283,7 +293,7 @@ def test_overwrite_textual_inversion_1v_single(self):
283293 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
284294 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
285295 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :1 ], default_prompt_embeddings [0 :1 ]))
286- self .assertTrue (torch . equal ( overwritten_prompt_embeddings [1 ], test_embedding_1v [ 0 ] ))
296+ self .assertTrue (was_embedding_overwritten_correctly ( tim , overwritten_prompt_embeddings , [1 ], test_embedding_1v ))
287297 self .assertTrue (torch .equal (overwritten_prompt_embeddings [2 :77 ], default_prompt_embeddings [2 :77 ]))
288298
289299 # in the middle
@@ -296,7 +306,7 @@ def test_overwrite_textual_inversion_1v_single(self):
296306 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
297307 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
298308 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :2 ], default_prompt_embeddings [0 :2 ]))
299- self .assertTrue (torch . equal ( overwritten_prompt_embeddings [2 ], test_embedding_1v [ 0 ] ))
309+ self .assertTrue (was_embedding_overwritten_correctly ( tim , overwritten_prompt_embeddings , [2 ], test_embedding_1v ))
300310 self .assertTrue (torch .equal (overwritten_prompt_embeddings [3 :77 ], default_prompt_embeddings [3 :77 ]))
301311
302312
@@ -326,8 +336,8 @@ def test_overwrite_textual_inversion_1v_multiple(self):
326336 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
327337 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
328338 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :4 ], default_prompt_embeddings [0 :4 ]))
329- self .assertTrue (torch . equal ( overwritten_prompt_embeddings [4 ], test_embedding_1v_1 [ 0 ] ))
330- self .assertTrue (torch . equal ( overwritten_prompt_embeddings [5 ], test_embedding_1v_2 [ 0 ] ))
339+ self .assertTrue (was_embedding_overwritten_correctly ( tim , overwritten_prompt_embeddings , [4 ], test_embedding_1v_1 ))
340+ self .assertTrue (was_embedding_overwritten_correctly ( tim , overwritten_prompt_embeddings , [5 ], test_embedding_1v_2 ))
331341 self .assertTrue (torch .equal (overwritten_prompt_embeddings [6 :77 ], default_prompt_embeddings [6 :77 ]))
332342
333343 # at the start
@@ -340,8 +350,10 @@ def test_overwrite_textual_inversion_1v_multiple(self):
340350 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
341351 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
342352 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :1 ], default_prompt_embeddings [0 :1 ]))
343- self .assertTrue (torch .equal (overwritten_prompt_embeddings [1 ], test_embedding_1v_1 [0 ]))
344- self .assertTrue (torch .equal (overwritten_prompt_embeddings [2 ], test_embedding_1v_2 [0 ]))
353+ self .assertTrue (
354+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [1 ], test_embedding_1v_1 ))
355+ self .assertTrue (
356+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [2 ], test_embedding_1v_2 ))
345357 self .assertTrue (torch .equal (overwritten_prompt_embeddings [3 :77 ], default_prompt_embeddings [3 :77 ]))
346358
347359 # clumped in the middle
@@ -354,8 +366,10 @@ def test_overwrite_textual_inversion_1v_multiple(self):
354366 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
355367 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
356368 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :2 ], default_prompt_embeddings [0 :2 ]))
357- self .assertTrue (torch .equal (overwritten_prompt_embeddings [2 ], test_embedding_1v_1 [0 ]))
358- self .assertTrue (torch .equal (overwritten_prompt_embeddings [3 ], test_embedding_1v_2 [0 ]))
369+ self .assertTrue (
370+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [2 ], test_embedding_1v_1 ))
371+ self .assertTrue (
372+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [3 ], test_embedding_1v_2 ))
359373 self .assertTrue (torch .equal (overwritten_prompt_embeddings [4 :77 ], default_prompt_embeddings [4 :77 ]))
360374
361375 # scattered
@@ -368,9 +382,11 @@ def test_overwrite_textual_inversion_1v_multiple(self):
368382 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
369383 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
370384 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :2 ], default_prompt_embeddings [0 :2 ]))
371- self .assertTrue (torch .equal (overwritten_prompt_embeddings [2 ], test_embedding_1v_1 [0 ]))
385+ self .assertTrue (
386+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [2 ], test_embedding_1v_1 ))
372387 self .assertTrue (torch .equal (overwritten_prompt_embeddings [3 ], default_prompt_embeddings [3 ]))
373- self .assertTrue (torch .equal (overwritten_prompt_embeddings [4 ], test_embedding_1v_2 [0 ]))
388+ self .assertTrue (
389+ was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings , [4 ], test_embedding_1v_2 ))
374390 self .assertTrue (torch .equal (overwritten_prompt_embeddings [5 :77 ], default_prompt_embeddings [5 :77 ]))
375391
376392 def test_overwrite_textual_inversion_4v_single (self ):
@@ -393,7 +409,9 @@ def test_overwrite_textual_inversion_4v_single(self):
393409 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
394410 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
395411 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :4 ], default_prompt_embeddings [0 :4 ]))
396- self .assertTrue (torch .equal (overwritten_prompt_embeddings [4 :8 ], test_embedding_4v ))
412+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
413+ list (range (4 ,8 )),
414+ test_embedding_4v ))
397415 self .assertTrue (torch .equal (overwritten_prompt_embeddings [8 :77 ], default_prompt_embeddings [8 :77 ]))
398416
399417 # at the start
@@ -406,7 +424,9 @@ def test_overwrite_textual_inversion_4v_single(self):
406424 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
407425 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
408426 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :1 ], default_prompt_embeddings [0 :1 ]))
409- self .assertTrue (torch .equal (overwritten_prompt_embeddings [1 :5 ], test_embedding_4v ))
427+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
428+ list (range (1 ,5 )),
429+ test_embedding_4v ))
410430 self .assertTrue (torch .equal (overwritten_prompt_embeddings [5 :77 ], default_prompt_embeddings [5 :77 ]))
411431
412432 # in the middle
@@ -419,7 +439,9 @@ def test_overwrite_textual_inversion_4v_single(self):
419439 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
420440 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
421441 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :2 ], default_prompt_embeddings [0 :2 ]))
422- self .assertTrue (torch .equal (overwritten_prompt_embeddings [2 :6 ], test_embedding_4v ))
442+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
443+ list (range (2 ,6 )),
444+ test_embedding_4v ))
423445 self .assertTrue (torch .equal (overwritten_prompt_embeddings [6 :77 ], default_prompt_embeddings [6 :77 ]))
424446
425447 def test_overwrite_textual_inversion_4v_overflow (self ):
@@ -445,8 +467,11 @@ def test_overwrite_textual_inversion_4v_overflow(self):
445467 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
446468 base_prompt_length = len (base_prompt )
447469 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :base_prompt_length + 1 ], default_prompt_embeddings [0 :base_prompt_length + 1 ]))
448- self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 :base_prompt_length + 1 + 3 ], test_embedding_4v [0 :3 ]))
449- self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 + 3 :77 ], default_prompt_embeddings [base_prompt_length + 1 + 3 :77 ]))
470+ truncated_overflowed_overwrite_count = min (75 - len (base_prompt ), test_embedding_4v .shape [0 ])
471+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
472+ list (range (base_prompt_length + 1 ,base_prompt_length + 1 + truncated_overflowed_overwrite_count )),
473+ test_embedding_4v [0 :truncated_overflowed_overwrite_count ]))
474+ self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 + 4 :77 ], default_prompt_embeddings [base_prompt_length + 1 + 4 :77 ]))
450475
451476 # at the start
452477 prompt_token_ids = [test_embedding_4v_token_id ] + base_prompt
@@ -459,7 +484,9 @@ def test_overwrite_textual_inversion_4v_overflow(self):
459484 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
460485 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
461486 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :1 ], default_prompt_embeddings [0 :1 ]))
462- self .assertTrue (torch .equal (overwritten_prompt_embeddings [1 :5 ], test_embedding_4v ))
487+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
488+ list (range (1 ,5 )),
489+ test_embedding_4v ))
463490 self .assertTrue (torch .equal (overwritten_prompt_embeddings [5 :77 ], default_prompt_embeddings [5 :77 ]))
464491
465492 # in the middle
@@ -472,7 +499,9 @@ def test_overwrite_textual_inversion_4v_overflow(self):
472499 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
473500 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
474501 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :21 ], default_prompt_embeddings [0 :21 ]))
475- self .assertTrue (torch .equal (overwritten_prompt_embeddings [21 :25 ], test_embedding_4v ))
502+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
503+ list (range (21 ,25 )),
504+ test_embedding_4v ))
476505 self .assertTrue (torch .equal (overwritten_prompt_embeddings [25 :77 ], default_prompt_embeddings [25 :77 ]))
477506
478507
@@ -504,8 +533,12 @@ def test_overwrite_textual_inversion_4v_multiple(self):
504533 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
505534 base_prompt_length = len (base_prompt )
506535 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :base_prompt_length + 1 ], default_prompt_embeddings [0 :base_prompt_length + 1 ]))
507- self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 :base_prompt_length + 1 + 4 ], test_embedding_4v_1 ))
508- self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 + 4 :base_prompt_length + 1 + 4 + 4 ], test_embedding_4v_2 ))
536+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
537+ list (range (base_prompt_length + 1 , base_prompt_length + 1 + 4 )),
538+ test_embedding_4v_1 ))
539+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
540+ list (range (base_prompt_length + 1 + 4 , base_prompt_length + 1 + 4 + 4 )),
541+ test_embedding_4v_2 ))
509542 self .assertTrue (torch .equal (overwritten_prompt_embeddings [base_prompt_length + 1 + 4 + 4 :77 ], default_prompt_embeddings [base_prompt_length + 1 + 4 + 4 :77 ]))
510543
511544 # at the start
@@ -519,8 +552,12 @@ def test_overwrite_textual_inversion_4v_multiple(self):
519552 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
520553 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
521554 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :1 ], default_prompt_embeddings [0 :1 ]))
522- self .assertTrue (torch .equal (overwritten_prompt_embeddings [1 :5 ], test_embedding_4v_1 ))
523- self .assertTrue (torch .equal (overwritten_prompt_embeddings [5 :9 ], test_embedding_4v_2 ))
555+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
556+ list (range (1 ,5 )),
557+ test_embedding_4v_1 ))
558+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
559+ list (range (5 ,9 )),
560+ test_embedding_4v_2 ))
524561 self .assertTrue (torch .equal (overwritten_prompt_embeddings [9 :77 ], default_prompt_embeddings [9 :77 ]))
525562
526563 # in the middle
@@ -533,7 +570,11 @@ def test_overwrite_textual_inversion_4v_multiple(self):
533570 overwritten_prompt_embeddings = tim .overwrite_textual_inversion_embeddings (padded_prompt_token_ids , default_prompt_embeddings )
534571 self .assertFalse (torch .equal (default_prompt_embeddings , overwritten_prompt_embeddings ))
535572 self .assertTrue (torch .equal (overwritten_prompt_embeddings [0 :11 ], default_prompt_embeddings [0 :11 ]))
536- self .assertTrue (torch .equal (overwritten_prompt_embeddings [11 :15 ], test_embedding_4v_1 ))
573+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
574+ list (range (11 ,15 )),
575+ test_embedding_4v_1 ))
537576 self .assertTrue (torch .equal (overwritten_prompt_embeddings [15 :25 ], default_prompt_embeddings [15 :25 ]))
538- self .assertTrue (torch .equal (overwritten_prompt_embeddings [25 :29 ], test_embedding_4v_2 ))
577+ self .assertTrue (was_embedding_overwritten_correctly (tim , overwritten_prompt_embeddings ,
578+ list (range (25 ,29 )),
579+ test_embedding_4v_2 ))
539580 self .assertTrue (torch .equal (overwritten_prompt_embeddings [29 :77 ], default_prompt_embeddings [29 :77 ]))
0 commit comments