From ac4a762f99367ab4dc7380fe9c4043b4b75c0045 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 28 Nov 2022 20:42:14 +0100 Subject: [PATCH 1/2] Fix #1599 by relaxing the `match_trigger` regex Also simplify logic and reduce duplication. --- ldm/invoke/CLI.py | 17 ++--------------- ldm/invoke/concepts_lib.py | 21 +++++++++++++++++---- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index d5687602d9e..6e0f174a34a 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -306,7 +306,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): if use_prefix is not None: prefix = use_prefix postprocessed = upscaled if upscaled else operation=='postprocess' - opt.prompt = triggers_to_concepts(gen, opt.prompt) # to avoid the problem of non-unique concept triggers + opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt) # to avoid the problem of non-unique concept triggers filename, formatted_dream_prompt = prepare_image_metadata( opt, prefix, @@ -351,7 +351,7 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None): if operation == 'generate': # load any from the SD concepts library - opt.prompt = concepts_to_triggers(gen, opt.prompt) + opt.prompt = gen.concept_lib().replace_concepts_with_triggers(opt.prompt, lambda concepts: gen.load_concepts(concepts)) catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts opt.last_operation='generate' try: @@ -501,19 +501,6 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple: command = '-h' return command, operation -def concepts_to_triggers(gen, prompt:str)->str: - concepts = re.findall('<([^>]+)>',prompt) - if not concepts: - return prompt - gen.load_concepts(concepts) - return gen.concept_lib().replace_concepts_with_triggers(prompt) - -def triggers_to_concepts(gen,prompt:str)->str: - concepts = re.findall('<([^>]+)>',prompt) - if not concepts: - return prompt - return gen.concept_lib().replace_triggers_with_concepts(prompt) - def set_default_output_dir(opt:Args, completer:Completer): ''' If opt.outdir is relative, we add the root directory to it diff --git a/ldm/invoke/concepts_lib.py b/ldm/invoke/concepts_lib.py index 6d1100d8b32..e365dee3138 100644 --- a/ldm/invoke/concepts_lib.py +++ b/ldm/invoke/concepts_lib.py @@ -7,6 +7,7 @@ import os import re import traceback +from typing import Callable from urllib import request, error as ul_error from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi from ldm.invoke.globals import Globals @@ -22,8 +23,8 @@ def __init__(self, root=None): self.concepts_loaded = dict() self.triggers = dict() # concept name to trigger phrase self.concept_names = dict() # trigger phrase to concept name - self.match_trigger = re.compile('(<[\w\-]+>)') - self.match_concept = re.compile('<([\w\-]+)>') + self.match_trigger = re.compile('(<[^>]+>)') # trigger is less restrictive than HF concept name + self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_- def list_concepts(self)->list: ''' @@ -83,15 +84,27 @@ def replace_triggers_with_concepts(self, prompt:str)->str: better to store the concept name (unique) than the concept trigger (not necessarily unique!) ''' + triggers = self.match_trigger.findall(prompt) + if not triggers: + return prompt + def do_replace(match)->str: return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>' return self.match_trigger.sub(do_replace, prompt) - def replace_concepts_with_triggers(self, prompt:str)->str: + def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str: ''' - Given a prompt string that contains tags, replace + Given a prompt string that contains `` tags, replace these tags with the appropriate trigger. + + If any `` tags are found, `load_concepts_callback()` is called with a list + of `concepts_name` strings. ''' + concepts = self.match_concept.findall(prompt) + if not concepts: + return prompt + load_concepts_callback(concepts) + def do_replace(match)->str: return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>' return self.match_concept.sub(do_replace, prompt) From 45cba285541bd72dfc2600b054a43e0d3f645657 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 28 Nov 2022 21:51:15 +0100 Subject: [PATCH 2/2] restrict trigger regex again (but not so far) --- ldm/invoke/concepts_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/concepts_lib.py b/ldm/invoke/concepts_lib.py index e365dee3138..3b6d9e9bad6 100644 --- a/ldm/invoke/concepts_lib.py +++ b/ldm/invoke/concepts_lib.py @@ -23,7 +23,7 @@ def __init__(self, root=None): self.concepts_loaded = dict() self.triggers = dict() # concept name to trigger phrase self.concept_names = dict() # trigger phrase to concept name - self.match_trigger = re.compile('(<[^>]+>)') # trigger is less restrictive than HF concept name + self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_- def list_concepts(self)->list: