Skip to content

Commit 664a6e9

Browse files
committed
use TextualInversionManager in place of embeddings (wip, doesn't work)
1 parent 023df37 commit 664a6e9

File tree

9 files changed

+291
-229
lines changed

9 files changed

+291
-229
lines changed

ldm/generate.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from omegaconf import OmegaConf
2323

2424
import ldm.invoke.conditioning
25+
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
2526
from ldm.invoke.generator.base import downsampling
2627
from PIL import Image, ImageOps
2728
from torch import nn
@@ -41,7 +42,6 @@
4142
from ldm.invoke.model_cache import ModelCache
4243
from ldm.invoke.seamless import configure_model_padding
4344
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
44-
from ldm.invoke.concepts_lib import Concepts
4545

4646
def fix_func(orig):
4747
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -438,7 +438,7 @@ def process_image(image,seed):
438438
self._set_sampler()
439439

440440
# apply the concepts library to the prompt
441-
prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts))
441+
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
442442

443443
# bit of a hack to change the cached sampler's karras threshold to
444444
# whatever the user asked for
@@ -862,19 +862,22 @@ def set_model(self,model_name):
862862

863863
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
864864
if self.embedding_path is not None:
865-
self.model.embedding_manager.load(
866-
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
867-
)
865+
for root, _, files in os.walk(self.embedding_path):
866+
for name in files:
867+
ti_path = os.path.join(root, name)
868+
self.model.textual_inversion_manager.load_textual_inversion(ti_path)
869+
print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}')
868870

869871
self._set_sampler()
870872
self.model_name = model_name
871873
return self.model
872874

873-
def load_concepts(self,concepts:list[str]):
874-
self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')
875+
def load_huggingface_concepts(self, concepts:list[str]):
876+
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
875877

876-
def concept_lib(self)->Concepts:
877-
return self.model.embedding_manager.concepts_library
878+
@property
879+
def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
880+
return self.model.textual_inversion_manager.hf_concepts_library
878881

879882
def correct_colors(self,
880883
image_list,

ldm/invoke/CLI.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
1717
from ldm.invoke.image_util import make_grid
1818
from ldm.invoke.log import write_log
19-
from ldm.invoke.concepts_lib import Concepts
19+
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
2020
from omegaconf import OmegaConf
2121
from pathlib import Path
2222
import pyparsing
@@ -133,6 +133,10 @@ def main():
133133
main_loop(gen, opt)
134134
except KeyboardInterrupt:
135135
print("\ngoodbye!")
136+
except Exception:
137+
print(">> An error occurred:")
138+
traceback.print_exc()
139+
136140

137141
# TODO: main_loop() has gotten busy. Needs to be refactored.
138142
def main_loop(gen, opt):
@@ -310,7 +314,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None,
310314
if use_prefix is not None:
311315
prefix = use_prefix
312316
postprocessed = upscaled if upscaled else operation=='postprocess'
313-
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
317+
opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
314318
filename, formatted_dream_prompt = prepare_image_metadata(
315319
opt,
316320
prefix,
@@ -809,7 +813,8 @@ def add_embedding_terms(gen,completer):
809813
Called after setting the model, updates the autocompleter with
810814
any terms loaded by the embedding manager.
811815
'''
812-
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
816+
trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
817+
completer.add_embedding_terms(trigger_strings)
813818

814819
def split_variations(variations_string) -> list:
815820
# shotgun parsing, woo

ldm/invoke/concepts_lib.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
1313
from ldm.invoke.globals import Globals
1414

15-
class Concepts(object):
15+
class HuggingFaceConceptsLibrary(object):
1616
def __init__(self, root=None):
1717
'''
1818
Initialize the Concepts object. May optionally pass a root directory.
@@ -116,11 +116,11 @@ def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin'
116116
self.download_concept(concept_name)
117117
path = os.path.join(self._concept_path(concept_name), file_name)
118118
return path if os.path.exists(path) else None
119-
119+
120120
def concept_is_downloaded(self, concept_name)->bool:
121121
concept_directory = self._concept_path(concept_name)
122122
return os.path.exists(concept_directory)
123-
123+
124124
def download_concept(self,concept_name)->bool:
125125
repo_id = self._concept_id(concept_name)
126126
dest = self._concept_path(concept_name)
@@ -133,7 +133,7 @@ def download_concept(self,concept_name)->bool:
133133

134134
os.makedirs(dest, exist_ok=True)
135135
succeeded = True
136-
136+
137137
bytes = 0
138138
def tally_download_size(chunk, size, total):
139139
nonlocal bytes

ldm/invoke/conditioning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
231231

232232
def _get_tokens_length(model, fragments: list[Fragment]):
233233
fragment_texts = [x.text for x in fragments]
234-
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
234+
tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
235235
return sum([len(x) for x in tokens])
236236

237237

ldm/invoke/readline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313
import atexit
1414
from ldm.invoke.args import Args
15-
from ldm.invoke.concepts_lib import Concepts
15+
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
1616
from ldm.invoke.globals import Globals
1717

1818
# ---------------readline utilities---------------------
@@ -276,7 +276,7 @@ def add_embedding_terms(self, terms:list[str]):
276276

277277
def _concept_completions(self, text, state):
278278
if self.concepts is None:
279-
self.concepts = set(Concepts().list_concepts())
279+
self.concepts = set(HuggingFaceConceptsLibrary().list_concepts())
280280
self.embedding_terms.update(self.concepts)
281281

282282
partial = text[1:] # this removes the leading '<'

ldm/models/diffusion/ddpm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from omegaconf import ListConfig
2323
import urllib
2424

25+
from ldm.modules.textual_inversion_manager import TextualInversionManager
2526
from ldm.util import (
2627
log_txt_as_img,
2728
exists,
@@ -678,6 +679,9 @@ def __init__(
678679
self.embedding_manager = self.instantiate_embedding_manager(
679680
personalization_config, self.cond_stage_model
680681
)
682+
self.textual_inversion_manager = TextualInversionManager(self.cond_stage_model, full_precision=True)
683+
# this circular component dependency is gross and bad, needs to be rethought
684+
self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager)
681685

682686
self.emb_ckpt_counter = 0
683687

ldm/modules/embedding_manager.py

Lines changed: 2 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import sys
88

9-
from ldm.invoke.concepts_lib import Concepts
9+
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
1010
from ldm.data.personalized import per_img_token_list
1111
from transformers import CLIPTokenizer
1212
from functools import partial
@@ -31,157 +31,6 @@ def get_embedding_for_clip_token_id(embedder, token_id):
3131
token_id = torch.tensor(token_id, dtype=torch.int)
3232
return embedder(token_id.unsqueeze(0))[0, 0]
3333

34-
@dataclass
35-
class TextualInversion:
36-
trigger_string: str
37-
token_id: int
38-
embedding: torch.Tensor
39-
40-
@property
41-
def embedding_vector_length(self) -> int:
42-
return self.embedding.shape[0]
43-
44-
class TextualInversionManager():
45-
def __init__(self, clip_embedder):
46-
self.clip_embedder = clip_embedder
47-
default_textual_inversions: list[TextualInversion] = []
48-
self.textual_inversions = default_textual_inversions
49-
50-
def load_textual_inversion(self, ckpt_path, full_precision=True):
51-
52-
scan_result = scan_file_path(ckpt_path)
53-
if scan_result.infected_files == 1:
54-
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
55-
print('### For your safety, InvokeAI will not load this embed.')
56-
return
57-
58-
ckpt = torch.load(ckpt_path, map_location='cpu')
59-
60-
# Handle .pt textual inversion files
61-
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
62-
filename = os.path.basename(ckpt_path)
63-
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
64-
if len(ckpt["string_to_token"]) > 1:
65-
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
66-
67-
string_to_param_dict = ckpt['string_to_param']
68-
embedding = list(string_to_param_dict.values())[0]
69-
self.add_textual_inversion(token_str, embedding, full_precision)
70-
71-
# Handle .bin textual inversion files from Huggingface Concepts
72-
# https://huggingface.co/sd-concepts-library
73-
else:
74-
for token_str in list(ckpt.keys()):
75-
embedding = ckpt[token_str]
76-
self.add_textual_inversion(token_str, embedding, full_precision)
77-
78-
def add_textual_inversion(self, token_str, embedding) -> int:
79-
"""
80-
Add a textual inversion to be recognised.
81-
:param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
82-
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
83-
:return: The token id for the added embedding, either existing or newly-added.
84-
"""
85-
if token_str in [ti.trigger_string for ti in self.textual_inversions]:
86-
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
87-
return
88-
if len(embedding.shape) == 1:
89-
embedding = embedding.unsqueeze(0)
90-
elif len(embedding.shape) > 2:
91-
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
92-
93-
existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
94-
if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
95-
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
96-
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
97-
current_token_count = current_embeddings.num_embeddings
98-
new_token_count = current_token_count + num_tokens_added
99-
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
100-
101-
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
102-
self.textual_inversions.append(TextualInversion(
103-
trigger_string=token_str,
104-
token_id=token_id,
105-
embedding=embedding
106-
))
107-
return token_id
108-
109-
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
110-
try:
111-
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
112-
return ti is not None
113-
except StopIteration:
114-
return False
115-
116-
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
117-
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
118-
119-
120-
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
121-
return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
122-
123-
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
124-
"""
125-
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
126-
127-
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
128-
:param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
129-
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
130-
long - caller is responsible for truncating it if necessary and prepending/appending eos and bos token ids.
131-
"""
132-
if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
133-
raise ValueError("prompt_token_ids must not start with bos_token_id")
134-
if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
135-
raise ValueError("prompt_token_ids must not end with eos_token_id")
136-
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
137-
prompt_token_ids = prompt_token_ids[:]
138-
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
139-
if token_id in textual_inversion_token_ids:
140-
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
141-
for pad_idx in range(1, textual_inversion.embedding_vector_length):
142-
prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
143-
144-
return prompt_token_ids
145-
146-
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
147-
"""
148-
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
149-
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
150-
subsequent rows in `prompt_embeddings` as well.
151-
152-
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
153-
>1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
154-
:param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
155-
`prompt_token_ids` (i.e., also already expanded).
156-
:return: The `prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
157-
"""
158-
if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
159-
raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
160-
if len(prompt_token_ids) > self.clip_embedder.max_length:
161-
raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
162-
if len(prompt_token_ids) < self.clip_embedder.max_length:
163-
raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
164-
if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
165-
raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")
166-
167-
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
168-
pad_token_id = self.clip_embedder.tokenizer.pad_token_id
169-
overwritten_prompt_embeddings = prompt_embeddings.clone()
170-
for i, token_id in enumerate(prompt_token_ids):
171-
if token_id == pad_token_id:
172-
continue
173-
if token_id in textual_inversion_token_ids:
174-
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
175-
end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
176-
count_to_overwrite = end_index - i
177-
for j in range(0, count_to_overwrite):
178-
# only overwrite the textual inversion token id or the padding token id
179-
if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
180-
break
181-
overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
182-
183-
return overwritten_prompt_embeddings
184-
18534

18635
class EmbeddingManager(nn.Module):
18736
def __init__(
@@ -197,8 +46,7 @@ def __init__(
19746
super().__init__()
19847

19948
self.embedder = embedder
200-
self.concepts_library=Concepts()
201-
self.concepts_loaded = dict()
49+
self.concepts_library=HuggingFaceConceptsLibrary()
20250

20351
self.string_to_token_dict = {}
20452
self.string_to_param_dict = nn.ParameterDict()
@@ -349,22 +197,6 @@ def save(self, ckpt_path):
349197
ckpt_path,
350198
)
351199

352-
def load_concepts(self, concepts:list[str], full=True):
353-
bin_files = list()
354-
for concept_name in concepts:
355-
if concept_name in self.concepts_loaded:
356-
continue
357-
else:
358-
bin_file = self.concepts_library.get_concept_model_path(concept_name)
359-
if not bin_file:
360-
continue
361-
bin_files.append(bin_file)
362-
self.concepts_loaded[concept_name]=True
363-
self.load(bin_files, full)
364-
365-
def list_terms(self) -> list[str]:
366-
return self.concepts_loaded.keys()
367-
368200
def load(self, ckpt_paths, full=True):
369201
if len(ckpt_paths) == 0:
370202
return

0 commit comments

Comments
 (0)