Embedding merging #1526
Merged
24 commits
ed3b26f add whole <style token> to vocab for concept library embeddings (lstein)
65c018e add ability to load multiple concept .bin files (lstein)
af98ae1 make --log_tokenization respect custom tokens (damian0815)
83cf3a2 start working on concept downloading system (lstein)
e68f7ec Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in… (lstein)
f43ba5a Merge branch 'development' into embedding-merging (lstein)
5595b87 preliminary support for dynamic loading and merging of multiple embed… (lstein)
6cb31fc fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes (damian0815)
53d99cc simplify replacement logic and remove cuda assumption (damian0815)
f683892 download list of concepts from hugging face (lstein)
c54131e remove misleading customization of '*' placeholder (damian0815)
1b946e4 Merge branch 'embedding-merging' of github.com:invoke-ai/InvokeAI int… (damian0815)
8033790 address all the issues raised by damian0815 in review of PR #1526 (lstein)
27966fb Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in… (lstein)
16d8d7c actually resize the token_embeddings (damian0815)
5e448e1 Merge branch 'development' into embedding-merging (lstein)
87c6b5d multiple improvements to the concept loader based on code reviews (lstein)
60bc394 Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in… (lstein)
0ef9f67 Merge branch 'development' into embedding-merging (lstein)
0843bbc autocomplete terms end with ">" now (lstein)
8961a73 Merge branch 'embedding-merging' of github.com:/invoke-ai/InvokeAI in… (lstein)
cedcd95 fix startup error and network unreachable (lstein)
49aa405 fix misformatted error string (lstein)
0666efe Merge branch 'development' into embedding-merging (lstein)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| """ | ||
| Query and install embeddings from the HuggingFace SD Concepts Library | ||
| at https://huggingface.co/sd-concepts-library. | ||
| | ||
| The interface is through the Concepts() object. | ||
| """ | ||
| import os | ||
| import re | ||
| from urllib import request, error as ul_error | ||
| from huggingface_hub import HfFolder, hf_hub_url, ModelFilter, HfApi | ||
| from ldm.invoke.globals import Globals | ||
| | ||
| class Concepts(object): | ||
| def __init__(self, root=None): | ||
| ''' | ||
| Initialize the Concepts object. May optionally pass a root directory. | ||
| ''' | ||
| self.root = root or Globals.root | ||
| self.hf_api = HfApi() | ||
| self.concept_list = None | ||
| self.concepts_loaded = dict() | ||
| self.triggers = dict() # concept name to trigger phrase | ||
| self.concept_names = dict() # trigger phrase to concept name | ||
| self.match_trigger = re.compile(r'(<[\w\-]+>)') # captures the tag including its angle brackets | ||
| self.match_concept = re.compile(r'<([\w\-]+)>') # captures only the name inside the brackets | ||
| | ||
| def list_concepts(self)->list: | ||
| ''' | ||
| Return a list of all the concepts by name, without the 'sd-concepts-library' part. | ||
| ''' | ||
| if self.concept_list is not None: | ||
| return self.concept_list | ||
| try: | ||
| models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/')) | ||
| self.concept_list = [a.id.split('/')[1] for a in models] | ||
| except Exception as e: | ||
| print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.') | ||
| print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.') | ||
| return self.concept_list | ||
| | ||
| def get_concept_model_path(self, concept_name:str)->str: | ||
| ''' | ||
| Returns the path to the 'learned_embeds.bin' file in | ||
| the named concept. Returns None if invalid or cannot | ||
| be downloaded. | ||
| ''' | ||
| return self.get_concept_file(concept_name.lower(),'learned_embeds.bin') | ||
| | ||
| def concept_to_trigger(self, concept_name:str)->str: | ||
| ''' | ||
| Given a concept name returns its trigger by looking in the | ||
| "token_identifier.txt" file. | ||
| ''' | ||
| if concept_name in self.triggers: | ||
| return self.triggers[concept_name] | ||
| file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True) | ||
| if not file: | ||
| return None | ||
| with open(file,'r') as f: | ||
| trigger = f.readline() | ||
| trigger = trigger.strip() | ||
| self.triggers[concept_name] = trigger | ||
| self.concept_names[trigger] = concept_name | ||
| return trigger | ||
| | ||
| def trigger_to_concept(self, trigger:str)->str: | ||
| ''' | ||
| Given a trigger phrase, maps it to the concept library name. | ||
| Only works if concept_to_trigger() has previously been called | ||
| on this library. There needs to be a persistent database for | ||
| this. | ||
| ''' | ||
| concept = self.concept_names.get(trigger,None) | ||
| return f'<{concept}>' if concept else f'{trigger}' | ||
| | ||
| def replace_triggers_with_concepts(self, prompt:str)->str: | ||
| ''' | ||
| Given a prompt string that contains <trigger> tags, replace these | ||
| tags with the concept name, so that the concept names are stored | ||
| in the prompt metadata. The SD concepts library does not prevent | ||
| trigger collisions, so it is better to store the concept name | ||
| (unique) than the concept trigger (not necessarily unique). | ||
| ''' | ||
| def do_replace(match)->str: | ||
| # group(1) already includes the angle brackets; unknown triggers are returned unchanged | ||
| return self.trigger_to_concept(match.group(1)) | ||
| return self.match_trigger.sub(do_replace, prompt) | ||
| | ||
| def replace_concepts_with_triggers(self, prompt:str)->str: | ||
| ''' | ||
| Given a prompt string that contains <concept_name> tags, replace | ||
| these tags with the appropriate trigger. | ||
| ''' | ||
| def do_replace(match)->str: | ||
| return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>' | ||
| return self.match_concept.sub(do_replace, prompt) | ||
| | ||
| def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str: | ||
| if not self.concept_is_downloaded(concept_name) and not local_only: | ||
| self.download_concept(concept_name) | ||
| path = os.path.join(self._concept_path(concept_name), file_name) | ||
| return path if os.path.exists(path) else None | ||
| | ||
| def concept_is_downloaded(self, concept_name)->bool: | ||
| concept_directory = self._concept_path(concept_name) | ||
| return os.path.exists(concept_directory) | ||
| | ||
| def download_concept(self,concept_name)->bool: | ||
| repo_id = self._concept_id(concept_name) | ||
| dest = self._concept_path(concept_name) | ||
| | ||
| access_token = HfFolder.get_token() | ||
| header = [("Authorization", f'Bearer {access_token}')] if access_token else [] | ||
| opener = request.build_opener() | ||
| opener.addheaders = header | ||
| request.install_opener(opener) | ||
| | ||
| os.makedirs(dest, exist_ok=True) | ||
| succeeded = True | ||
| | ||
| bytes = 0 | ||
| def tally_download_size(chunk, size, total): | ||
| nonlocal bytes | ||
| if chunk==0: | ||
| bytes += total | ||
| | ||
| print(f'>> Downloading {repo_id}...',end='') | ||
| try: | ||
| for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'): | ||
| url = hf_hub_url(repo_id, file) | ||
| request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size) | ||
| except ul_error.HTTPError as e: | ||
| if e.code==404: | ||
| print('This concept is not known to the Hugging Face library. Generation will continue without the concept.') | ||
| else: | ||
| print(f'Failed to download {concept_name}/{file} ({str(e)}). Generation will continue without the concept.') | ||
| os.rmdir(dest) | ||
| return False | ||
| except ul_error.URLError as e: | ||
| print(f'ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept.') | ||
| os.rmdir(dest) | ||
| return False | ||
| print('...{:.2f}Kb'.format(bytes/1024)) | ||
| return succeeded | ||
| | ||
| def _concept_id(self, concept_name:str)->str: | ||
| return f'sd-concepts-library/{concept_name}' | ||
| | ||
| def _concept_path(self, concept_name:str)->str: | ||
| return os.path.join(self.root,'models','sd-concepts-library',concept_name) |
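
The two prompt-rewriting methods in the diff above, replace_concepts_with_triggers() and replace_triggers_with_concepts(), can be tried out without any network access. The following is a minimal sketch, not the code from this PR: the TRIGGERS mapping and the names "my-concept" and "<my-style>" are hypothetical, standing in for what the Concepts class reads from each concept's token_identifier.txt file.

```python
import re

# Hypothetical mapping, for illustration only. In the PR the Concepts class
# fills the equivalent dictionaries from token_identifier.txt files.
TRIGGERS = {"my-concept": "<my-style>"}                # concept name -> trigger
CONCEPT_NAMES = {v: k for k, v in TRIGGERS.items()}    # trigger -> concept name

MATCH_TRIGGER = re.compile(r"(<[\w\-]+>)")  # captures the tag with its brackets
MATCH_CONCEPT = re.compile(r"<([\w\-]+)>")  # captures only the name inside


def replace_concepts_with_triggers(prompt: str) -> str:
    """Replace each <concept_name> tag with its trigger, if one is known."""
    def do_replace(match: re.Match) -> str:
        # Unknown names are left exactly as written in the prompt.
        return TRIGGERS.get(match.group(1), match.group(0))
    return MATCH_CONCEPT.sub(do_replace, prompt)


def replace_triggers_with_concepts(prompt: str) -> str:
    """Replace each known <trigger> tag with its <concept_name> tag."""
    def do_replace(match: re.Match) -> str:
        concept = CONCEPT_NAMES.get(match.group(1))
        return f"<{concept}>" if concept else match.group(1)
    return MATCH_TRIGGER.sub(do_replace, prompt)


if __name__ == "__main__":
    stored = "a cat in <my-concept> style"
    expanded = replace_concepts_with_triggers(stored)
    print(expanded)
    print(replace_triggers_with_concepts(expanded))
```

Passing a function to re.sub keeps unknown tags untouched, so a prompt that mentions a concept that was never downloaded survives the round trip unchanged.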