Skip to content

Commit 61cc41a

Browse files
Fixes for #1604 (#1605)
* Converts ESRGAN image input to RGB
  - Also adds typing for image input.
  - Partially resolves #1604
* Ensure there are unmasked pixels before color matching

Co-authored-by: Kyle Schouviller <[email protected]>
1 parent 40c3ab0 commit 61cc41a

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

ldm/invoke/generator/base.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,18 @@ def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image
141141
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
142142
np_image_masked = np_image[mask_pixels, :]
143143

144-
init_means = np_init_rgb_pixels_masked.mean(axis=0)
145-
init_std = np_init_rgb_pixels_masked.std(axis=0)
146-
gen_means = np_image_masked.mean(axis=0)
147-
gen_std = np_image_masked.std(axis=0)
148-
149-
# Color correct
150-
np_matched_result = np_image.copy()
151-
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
152-
matched_result = Image.fromarray(np_matched_result, mode='RGB')
144+
if np_init_rgb_pixels_masked.size > 0:
145+
init_means = np_init_rgb_pixels_masked.mean(axis=0)
146+
init_std = np_init_rgb_pixels_masked.std(axis=0)
147+
gen_means = np_image_masked.mean(axis=0)
148+
gen_std = np_image_masked.std(axis=0)
149+
150+
# Color correct
151+
np_matched_result = np_image.copy()
152+
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
153+
matched_result = Image.fromarray(np_matched_result, mode='RGB')
154+
else:
155+
matched_result = Image.fromarray(np_image, mode='RGB')
153156

154157
# Blur the mask out (into init image) by specified amount
155158
if mask_blur_radius > 0:

ldm/invoke/restoration/realesrgan.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ldm.invoke.globals import Globals
77
from PIL import Image
8-
8+
from PIL.Image import Image as ImageType
99

1010
class ESRGAN():
1111
def __init__(self, bg_tile_size=400) -> None:
@@ -41,7 +41,7 @@ def load_esrgan_bg_upsampler(self):
4141

4242
return bg_upsampler
4343

44-
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
44+
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2):
4545
with warnings.catch_warnings():
4646
warnings.filterwarnings('ignore', category=DeprecationWarning)
4747
warnings.filterwarnings('ignore', category=UserWarning)
@@ -62,7 +62,9 @@ def process(self, image, strength: float, seed: str = None, upsampler_scale: int
6262
print(
6363
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
6464
)
65-
65+
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
66+
image = image.convert("RGB")
67+
6668
# Real-ESRGAN expects a BGR NumPy array; make array and flip channels
6769
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
6870

0 commit comments

Comments
 (0)