Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions ldm/invoke/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,18 @@ def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]

init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)

# Color correct
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
if np_init_rgb_pixels_masked.size > 0:
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)

# Color correct
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
else:
matched_result = Image.fromarray(np_image, mode='RGB')

# Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0:
Expand Down
8 changes: 5 additions & 3 deletions ldm/invoke/restoration/realesrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ldm.invoke.globals import Globals
from PIL import Image

from PIL.Image import Image as ImageType

class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
Expand Down Expand Up @@ -41,7 +41,7 @@ def load_esrgan_bg_upsampler(self):

return bg_upsampler

def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
Expand All @@ -62,7 +62,9 @@ def process(self, image, strength: float, seed: str = None, upsampler_scale: int
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)

# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")

# REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]

Expand Down