Skip to content
Merged
38 changes: 29 additions & 9 deletions backend/invoke_ai_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from uuid import uuid4
from threading import Event

from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.prompt_parser import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
from ldm.invoke.generator.inpaint import infill_methods

from backend.modules.parameters import parameters_to_command
Expand All @@ -39,7 +41,7 @@


class InvokeAIWebServer:
def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
self.host = args.host
self.port = args.port

Expand Down Expand Up @@ -905,16 +907,13 @@ def image_progress(sample, step):
},
)


if generation_parameters["progress_latents"]:
image = self.generate.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_base64 = "data:image/png;base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
img_base64 = image_to_dataURL(image)
self.socketio.emit(
"intermediateResult",
{
Expand All @@ -932,7 +931,7 @@ def image_progress(sample, step):
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)

def image_done(image, seed, first_seed):
def image_done(image, seed, first_seed, attention_maps_image=None):
if self.canceled.is_set():
raise CanceledException

Expand Down Expand Up @@ -1094,6 +1093,12 @@ def image_done(image, seed, first_seed):
self.socketio.emit("progressUpdate", progress.to_formatted_dict())
eventlet.sleep(0)

parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model, parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
else image_to_dataURL(attention_maps_image)

self.socketio.emit(
"generationResult",
{
Expand All @@ -1106,6 +1111,8 @@ def image_done(image, seed, first_seed):
"height": height,
"boundingBox": original_bounding_box,
"generationMode": generation_parameters["generation_mode"],
"attentionMaps": attention_maps_image_base64_url,
"tokens": tokens,
},
)
eventlet.sleep(0)
Expand All @@ -1117,7 +1124,7 @@ def image_done(image, seed, first_seed):
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
image_callback=image_done,
image_callback=image_done
)

except KeyboardInterrupt:
Expand Down Expand Up @@ -1564,6 +1571,19 @@ def dataURL_to_image(dataURL: str) -> ImageType:
)
return image

"""
Converts an image into a base64 image dataURL.
"""

def image_to_dataURL(image: "ImageType", image_format: str = "PNG") -> str:
    """Convert a PIL image into a base64-encoded image dataURL string.

    Args:
        image: the PIL image to encode (anything exposing
            ``save(buffer, format=...)``).
        image_format: Pillow format name used to encode the payload.
            Defaults to "PNG", preserving the original behavior.

    Returns:
        A ``data:image/<fmt>;base64,<payload>`` string.
    """
    buffered = io.BytesIO()
    image.save(buffered, format=image_format)
    # Mime subtype is the lowercased Pillow format name ("PNG" -> "image/png").
    image_base64 = f"data:image/{image_format.lower()};base64," + base64.b64encode(
        buffered.getvalue()
    ).decode("UTF-8")
    return image_base64



"""
Converts a base64 image dataURL into bytes.
Expand Down
42 changes: 22 additions & 20 deletions ldm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import skimage

from omegaconf import OmegaConf

import ldm.invoke.conditioning
from ldm.invoke.generator.base import downsampling
from PIL import Image, ImageOps
from torch import nn
Expand All @@ -40,7 +42,7 @@
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from ldm.invoke.concepts_lib import Concepts

def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
Expand Down Expand Up @@ -235,7 +237,7 @@ def __init__(
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())

def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
Expand Down Expand Up @@ -329,7 +331,7 @@ def prompt2image(
infill_method = infill_methods[0], # The infill method to use
force_outpaint: bool = False,
enable_image_debugging = False,

**args,
): # eat up additional cruft
"""
Expand Down Expand Up @@ -372,7 +374,7 @@ def prompt2image(
def process_image(image,seed):
image.save(f'images/{seed}.png')

The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
The code used to save images to a directory can be found in ldm/invoke/pngwriter.py.
It contains code to create the requested output directory, select a unique informative
name for each image, and write the prompt into the PNG metadata.
"""
Expand Down Expand Up @@ -455,7 +457,7 @@ def process_image(image,seed):
try:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
skip_normalize_legacy_blend=skip_normalize,
log_tokens =self.log_tokenization
)

Expand Down Expand Up @@ -589,7 +591,7 @@ def apply_postprocessor(
seed = opt.seed or args.seed
if seed is None or seed < 0:
seed = random.randrange(0, np.iinfo(np.uint32).max)

prompt = opt.prompt or args.prompt or ''
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')

Expand All @@ -607,8 +609,8 @@ def apply_postprocessor(
# todo: cross-attention control
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization
skip_normalize_legacy_blend=opt.skip_normalize,
log_tokens =ldm.invoke.conditioning.log_tokenization
)

if tool in ('gfpgan','codeformer','upscale'):
Expand Down Expand Up @@ -641,7 +643,7 @@ def apply_postprocessor(

opt.seed = seed
opt.prompt = prompt

if len(extend_instructions) > 0:
restorer = Outcrop(image,self,)
return restorer.process (
Expand Down Expand Up @@ -683,7 +685,7 @@ def apply_postprocessor(
image_callback = callback,
prefix = prefix
)

elif tool is None:
print(f'* please provide at least one postprocessing option, such as -G or -U')
return None
Expand All @@ -706,13 +708,13 @@ def select_generator(

if embiggen is not None:
return self._make_embiggen()

if inpainting_model_in_use:
return self._make_omnibus()

if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint()

if init_image is not None:
return self._make_img2img()

Expand Down Expand Up @@ -743,7 +745,7 @@ def _make_images(
if self._has_transparency(image):
self._transparency_check_and_warning(image, mask, force_outpaint)
init_mask = self._create_init_mask(image, width, height, fit=fit)

if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
Expand All @@ -759,7 +761,7 @@ def _make_images(

if init_mask and invert_mask:
init_mask = ImageOps.invert(init_mask)

return init_image,init_mask

# lots o' repeated code here! Turn into a make_func()
Expand Down Expand Up @@ -818,7 +820,7 @@ def load_model(self):
self.set_model(self.model_name)

def set_model(self,model_name):
"""
"""
Given the name of a model defined in models.yaml, will load and initialize it
and return the model object. Previously-used models will be cached.
"""
Expand All @@ -830,7 +832,7 @@ def set_model(self,model_name):
if not cache.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return self.model

cache.print_vram_usage()

# have to get rid of all references to model in order
Expand All @@ -839,7 +841,7 @@ def set_model(self,model_name):
self.sampler = None
self.generators = {}
gc.collect()

model_data = cache.get_model(model_name)
if model_data is None: # restore previous
model_data = cache.get_model(self.model_name)
Expand All @@ -852,7 +854,7 @@ def set_model(self,model_name):

# uncache generators so they pick up new models
self.generators = {}

seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
self.model.embedding_manager.load(
Expand Down Expand Up @@ -901,7 +903,7 @@ def upscale_and_reconstruct(self,
image_callback = None,
prefix = None,
):

for r in image_list:
image, seed = r
try:
Expand All @@ -911,7 +913,7 @@ def upscale_and_reconstruct(self,
if self.gfpgan is None:
print('>> GFPGAN not found. Face restoration is disabled.')
else:
image = self.gfpgan.process(image, strength, seed)
image = self.gfpgan.process(image, strength, seed)
if facetool == 'codeformer':
if self.codeformer is None:
print('>> CodeFormer not found. Face restoration is disabled.')
Expand Down
Loading