 
 import sys
 
-from ldm.invoke.concepts_lib import Concepts
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.data.personalized import per_img_token_list
 from transformers import CLIPTokenizer
 from functools import partial
@@ -31,157 +31,6 @@ def get_embedding_for_clip_token_id(embedder, token_id):
     token_id = torch.tensor(token_id, dtype=torch.int)
     return embedder(token_id.unsqueeze(0))[0, 0]
 
-@dataclass
-class TextualInversion:
-    trigger_string: str
-    token_id: int
-    embedding: torch.Tensor
-
-    @property
-    def embedding_vector_length(self) -> int:
-        return self.embedding.shape[0]
-
-class TextualInversionManager():
-    def __init__(self, clip_embedder):
-        self.clip_embedder = clip_embedder
-        default_textual_inversions: list[TextualInversion] = []
-        self.textual_inversions = default_textual_inversions
-
-    def load_textual_inversion(self, ckpt_path, full_precision=True):
-
-        scan_result = scan_file_path(ckpt_path)
-        if scan_result.infected_files == 1:
-            print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
-            print('### For your safety, InvokeAI will not load this embed.')
-            return
-
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-
-        # Handle .pt textual inversion files
-        if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
-            filename = os.path.basename(ckpt_path)
-            token_str = '.'.join(filename.split('.')[:-1])  # filename excluding extension
-            if len(ckpt["string_to_token"]) > 1:
-                print(f">> {ckpt_path} has >1 embedding, only the first will be used")
-
-            string_to_param_dict = ckpt['string_to_param']
-            embedding = list(string_to_param_dict.values())[0]
-            self.add_textual_inversion(token_str, embedding, full_precision)
-
-        # Handle .bin textual inversion files from Huggingface Concepts
-        # https://huggingface.co/sd-concepts-library
-        else:
-            for token_str in list(ckpt.keys()):
-                embedding = ckpt[token_str]
-                self.add_textual_inversion(token_str, embedding, full_precision)
-
-    def add_textual_inversion(self, token_str, embedding) -> int:
-        """
-        Add a textual inversion to be recognised.
-        :param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
-        :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
-        :return: The token id for the added embedding, either existing or newly-added.
-        """
-        if token_str in [ti.trigger_string for ti in self.textual_inversions]:
-            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
-            return
-        if len(embedding.shape) == 1:
-            embedding = embedding.unsqueeze(0)
-        elif len(embedding.shape) > 2:
-            raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
-
-        existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
-        if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
-            num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
-            current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
-            current_token_count = current_embeddings.num_embeddings
-            new_token_count = current_token_count + num_tokens_added
-            self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
-
-        token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
-        self.textual_inversions.append(TextualInversion(
-            trigger_string=token_str,
-            token_id=token_id,
-            embedding=embedding
-        ))
-        return token_id
-
-    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
-        try:
-            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
-            return ti is not None
-        except StopIteration:
-            return False
-
-    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
-
-
-    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-
-    def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
-        """
-        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
-
-        :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
-        :param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
-        :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
-                 long - caller is responsible for truncating it if necessary and prepending/appending eos and bos token ids.
-        """
-        if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
-            raise ValueError("prompt_token_ids must not start with bos_token_id")
-        if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
-            raise ValueError("prompt_token_ids must not end with eos_token_id")
-        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        prompt_token_ids = prompt_token_ids[:]
-        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
-            if token_id in textual_inversion_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                for pad_idx in range(1, textual_inversion.embedding_vector_length):
-                    prompt_token_ids.insert(i + 1, self.clip_embedder.tokenizer.pad_token_id)
-
-        return prompt_token_ids
-
-    def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
-        """
-        For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
-        row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
-        subsequent rows in `prompt_embeddings` as well.
-
-        :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
-               >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
-        :param `prompt_embeddings`: Prompt embeddings tensor with rows aligning to the token ids in
-               `prompt_token_ids` (i.e., also already expanded).
-        :return: The `prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
-        """
-        if prompt_embeddings.shape[0] != self.clip_embedder.max_length:  # typically 77
-            raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
-        if len(prompt_token_ids) > self.clip_embedder.max_length:
-            raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
-        if len(prompt_token_ids) < self.clip_embedder.max_length:
-            raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
-        if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
-            raise ValueError("prompt_token_ids must start with the bos token id and end with the eos token id")
-
-        textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
-        pad_token_id = self.clip_embedder.tokenizer.pad_token_id
-        overwritten_prompt_embeddings = prompt_embeddings.clone()
-        for i, token_id in enumerate(prompt_token_ids):
-            if token_id == pad_token_id:
-                continue
-            if token_id in textual_inversion_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
-                end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length - 1)
-                count_to_overwrite = end_index - i
-                for j in range(0, count_to_overwrite):
-                    # only overwrite the textual inversion token id or the padding token id
-                    if prompt_token_ids[i + j] != pad_token_id and prompt_token_ids[i + j] != token_id:
-                        break
-                    overwritten_prompt_embeddings[i + j] = textual_inversion.embedding[j]
-
-        return overwritten_prompt_embeddings
-
 
 class EmbeddingManager(nn.Module):
     def __init__(
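
The docstrings of the removed manager describe a two-pass flow: expand the prompt's token ids so that multi-vector embeddings get extra slots, then overwrite the matching rows of the CLIP conditioning. A minimal sketch of that call order, assuming a FrozenCLIPEmbedder-style `clip_embedder` with `tokenizer`, `transformer` and `max_length` attributes and an illustrative embedding path (none of this appears in the commit itself):

    import torch

    manager = TextualInversionManager(clip_embedder)           # hypothetical embedder instance
    manager.load_textual_inversion('embeddings/my-style.pt')   # illustrative path

    tokenizer = clip_embedder.tokenizer
    # expand_textual_inversion_token_ids() expects ids without bos/eos markers
    ids = tokenizer('a painting in the style of my-style', add_special_tokens=False)['input_ids']
    ids = manager.expand_textual_inversion_token_ids(ids)

    # The caller truncates, re-adds bos/eos, and pads out to max_length; SD's CLIP
    # tokenizer pads with its end-of-text token, which keeps the trailing eos that
    # overwrite_textual_inversion_embeddings() checks for.
    ids = [tokenizer.bos_token_id] + ids[: clip_embedder.max_length - 2] + [tokenizer.eos_token_id]
    ids += [tokenizer.pad_token_id] * (clip_embedder.max_length - len(ids))

    with torch.no_grad():
        # assumes clip_embedder.transformer is the CLIPTextModel the removed code resizes
        embeddings = clip_embedder.transformer(input_ids=torch.tensor([ids])).last_hidden_state[0]
    embeddings = manager.overwrite_textual_inversion_embeddings(ids, embeddings)
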
@@ -197,8 +46,7 @@ def __init__(
         super().__init__()
 
         self.embedder = embedder
-        self.concepts_library = Concepts()
-        self.concepts_loaded = dict()
+        self.concepts_library = HuggingFaceConceptsLibrary()
 
         self.string_to_token_dict = {}
         self.string_to_param_dict = nn.ParameterDict()
@@ -349,22 +197,6 @@ def save(self, ckpt_path):
             ckpt_path,
         )
 
-    def load_concepts(self, concepts: list[str], full=True):
-        bin_files = list()
-        for concept_name in concepts:
-            if concept_name in self.concepts_loaded:
-                continue
-            else:
-                bin_file = self.concepts_library.get_concept_model_path(concept_name)
-                if not bin_file:
-                    continue
-                bin_files.append(bin_file)
-                self.concepts_loaded[concept_name] = True
-        self.load(bin_files, full)
-
-    def list_terms(self) -> list[str]:
-        return self.concepts_loaded.keys()
-
     def load(self, ckpt_paths, full=True):
         if len(ckpt_paths) == 0:
             return
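
For reference, the removed `load_concepts()` was a thin wrapper: resolve each concept name to a local `.bin` file via the concepts library, then hand the paths to `load()`. A rough caller-side equivalent, assuming the renamed `HuggingFaceConceptsLibrary` keeps the `get_concept_model_path()` method the removed code relied on (the library side of the rename is not shown in this diff):

    from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary

    def fetch_concept_bins(concept_names: list[str]) -> list[str]:
        """Resolve sd-concepts-library names to local .bin paths, skipping any that cannot be fetched."""
        library = HuggingFaceConceptsLibrary()
        bin_files = []
        for name in concept_names:
            path = library.get_concept_model_path(name)  # assumed to download and cache the concept
            if path:
                bin_files.append(path)
        return bin_files

    # e.g. embedding_manager.load(fetch_concept_bins(['midjourney-style']), full=True)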