Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions ldm/invoke/CLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None):
if use_prefix is not None:
prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess'
opt.prompt = triggers_to_concepts(gen, opt.prompt) # to avoid the problem of non-unique concept triggers
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt) # to avoid the problem of non-unique concept triggers
filename, formatted_dream_prompt = prepare_image_metadata(
opt,
prefix,
Expand Down Expand Up @@ -351,7 +351,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None):

if operation == 'generate':
# load any <embeddings> from the SD concepts library
opt.prompt = concepts_to_triggers(gen, opt.prompt)
opt.prompt = gen.concept_lib().replace_concepts_with_triggers(opt.prompt, lambda concepts: gen.load_concepts(concepts))
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
try:
Expand Down Expand Up @@ -501,19 +501,6 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = '-h'
return command, operation

def concepts_to_triggers(gen, prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
gen.load_concepts(concepts)
return gen.concept_lib().replace_concepts_with_triggers(prompt)

def triggers_to_concepts(gen,prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
return gen.concept_lib().replace_triggers_with_concepts(prompt)

def set_default_output_dir(opt:Args, completer:Completer):
'''
If opt.outdir is relative, we add the root directory to it
Expand Down
21 changes: 17 additions & 4 deletions ldm/invoke/concepts_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import re
import traceback
from typing import Callable
from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals
Expand All @@ -22,8 +23,8 @@ def __init__(self, root=None):
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile('(<[\w\-]+>)')
self.match_concept = re.compile('<([\w\-]+)>')
self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_-

def list_concepts(self)->list:
'''
Expand Down Expand Up @@ -83,15 +84,27 @@ def replace_triggers_with_concepts(self, prompt:str)->str:
better to store the concept name (unique) than the concept trigger
(not necessarily unique!)
'''
triggers = self.match_trigger.findall(prompt)
if not triggers:
return prompt

def do_replace(match)->str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)

def replace_concepts_with_triggers(self, prompt:str)->str:
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
'''
Given a prompt string that contains <concept_name> tags, replace
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.

If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
'''
concepts = self.match_concept.findall(prompt)
if not concepts:
return prompt
load_concepts_callback(concepts)

def do_replace(match)->str:
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)
Expand Down