
Commit f3570d8

inpainting for the normal model [WIP]

This seems to perform well until the LAST STEP, at which point it dissolves into confetti.

1 parent b2664e8
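
For context: the normal (non-inpainting) UNet has no extra conditioning channels for the mask, so this commit applies the mask as guidance instead. At every step, the init-image latents are re-noised to the current timestep and blended with the working latents. A minimal standalone sketch of that blending step (the helper name blend_masked_latents and the shapes are illustrative, not part of the commit):

import torch

def blend_masked_latents(latents, init_latents, mask, noise, scheduler, t):
    """Hypothetical helper mirroring the per-step blend in AddsMaskGuidance below."""
    # Noise the init-image latents up to timestep t so they sit at the same
    # noise level as the working latents.
    noised_init = scheduler.add_noise(init_latents, noise, t)
    # mask == 1 keeps the model's latents (the region being repainted);
    # mask == 0 pins the result back to the init image.
    return torch.lerp(noised_init.to(latents.dtype), latents, mask.to(latents.dtype))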

File tree

1 file changed (+40 −9 lines)


ldm/invoke/generator/diffusers_pipeline.py

Lines changed: 40 additions & 9 deletions
@@ -22,6 +22,7 @@
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -61,10 +62,32 @@ def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings:
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
-        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
         return self.forward(model_input, t, text_embeddings)
 
 
+@dataclass
+class AddsMaskGuidance:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+    _scheduler: SchedulerMixin
+    _noise_func: Callable
+    _debug: Optional[Callable] = None
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        noise = self._noise_func(self.mask_latents)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # if self._debug:
+        #     self._debug(latents, f"t={t[0]} latents")
+        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        if self._debug:
+            self._debug(masked_input, f"t={t[0]} lerped")
+        return self.forward(masked_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
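
Note on the einops.pack call above: the dedicated inpainting UNet expects 9 input channels (4 working latents + 1 mask + 4 masked-image latents), and pack concatenates along the wildcard axis to assemble them. A quick shape check with made-up sizes:

import einops
import torch

latents = torch.randn(2, 4, 64, 64)       # working latents
mask = torch.randn(2, 1, 64, 64)          # resized mask
mask_latents = torch.randn(2, 4, 64, 64)  # VAE-encoded init image

model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
print(model_input.shape)  # torch.Size([2, 9, 64, 64])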
@@ -382,17 +405,18 @@ def inpaint_from_embeddings(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
 
-        if is_inpainting_model(self.unet):
-            if mask.dim() == 3:
-                mask = mask.unsqueeze(0)
-            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
-                .to(device=device, dtype=latents_dtype)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(0)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+            .to(device=device, dtype=latents_dtype)
 
+        if is_inpainting_model(self.unet):
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            # FIXME: need to add guidance that applies mask
-            pass
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
+                                 self.scheduler, noise_func)  # self.debug_latents)
 
         result = None
 
@@ -417,7 +441,7 @@ def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
             init_latents = 0.18215 * init_latents
 
         noise = noise_func(init_latents)
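
On the FIXME above: latent_dist.sample() draws fresh noise with torch.randn internally, so it bypasses the pipeline's seeded noise source. One possible fix (a sketch, not what this commit does) is to reparameterize manually with the seedable noise_func, since diffusers' DiagonalGaussianDistribution exposes mean and std:

# Sketch: reproducible alternative to init_latent_dist.sample()
dist = self.vae.encode(init_image).latent_dist
eps = noise_func(dist.mean)                # noise from the pipeline's seeded source
init_latents = dist.mean + dist.std * eps  # reparameterization trick
init_latents = init_latents.to(dtype=dtype)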
@@ -456,3 +480,10 @@ def _tokenize(self, prompt: Union[str, List[str]]):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
+
+    def debug_latents(self, latents, msg):
+        with torch.inference_mode():
+            from ldm.util import debug_image
+            decoded = self.numpy_to_pil(self.decode_latents(latents))
+            for i, img in enumerate(decoded):
+                debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
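
A rough smoke test for the new class, with AddsMaskGuidance as defined above and dummy tensors throughout (fake_unet stands in for the pipeline's _unet_forward, and a default-constructed DDIMScheduler for self.scheduler):

import torch
from diffusers import DDIMScheduler

def fake_unet(latents, t, text_embeddings):
    # Stand-in for _unet_forward; a real UNet would predict noise here.
    return torch.zeros_like(latents)

latents = torch.randn(1, 4, 64, 64)             # working latents for a 512x512 image
init_image_latents = torch.randn(1, 4, 64, 64)  # VAE-encoded init image
mask = torch.ones(1, 1, 64, 64)                 # 1.0 = region the model may repaint

guidance = AddsMaskGuidance(fake_unet, mask, init_image_latents,
                            DDIMScheduler(), torch.randn_like)
out = guidance(latents, torch.tensor([999]), text_embeddings=None)
print(out.shape)  # torch.Size([1, 4, 64, 64])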
