@@ -15,21 +15,21 @@
 PROGRESSIVE_SCALE = 2000
 
 
-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id
 
-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
-
     token = token[0, 1]
-
-    return token
+    return token.item()
 
 
-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]
 
 @dataclass
 class TextualInversion:
@@ -183,9 +183,6 @@ def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], pr
         return overwritten_prompt_embeddings
 
 
-
-
-
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ def __init__(
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
@@ -232,9 +229,9 @@ def __init__(
         else:  # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280
 
         if per_image_tokens:
@@ -248,9 +245,7 @@ def __init__(
                 init_word_token_id = get_token_id_for_string(initializer_words[idx])
 
                 with torch.no_grad():
-                    init_word_embedding = get_embedding_for_tkn(
-                        init_word_token_id.cpu()
-                    )
+                    init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
 
                 token_params = torch.nn.Parameter(
                     init_word_embedding.unsqueeze(0).repeat(
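For reference, a minimal usage sketch of the renamed helpers after this change (not part of the commit; the `tokenizer` and `embedder` objects are assumed to be the CLIP tokenizer and text encoder already wired up in `EmbeddingManager.__init__`, and "*" is just an example placeholder string):

    # get_clip_token_id_for_string now returns a plain int, and
    # get_embedding_for_clip_token_id wraps it in a tensor itself,
    # so no .cpu()/tensor bookkeeping is needed at the call site.
    token_id = get_clip_token_id_for_string(embedder.tokenizer, "*")
    init_embedding = get_embedding_for_clip_token_id(
        embedder.transformer.text_model.embeddings, token_id
    )  # 1-D embedding vector for that token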