 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -61,10 +62,32 @@ def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings:
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
-        # model_input = torch.cat([latents, mask, mask_latents], dim=1)
         return self.forward(model_input, t, text_embeddings)


+@dataclass
+class AddsMaskGuidance:
+    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    mask: torch.FloatTensor
+    mask_latents: torch.FloatTensor
+    _scheduler: SchedulerMixin
+    _noise_func: Callable
+    _debug: Optional[Callable] = None
+
+    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+        batch_size = latents.size(0)
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # Re-noise the original-image latents to the current timestep so they
+        # can be blended with the in-progress latents on equal terms.
+        noise = self._noise_func(self.mask_latents)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # if self._debug:
+        #     self._debug(latents, f"t={t[0]} latents")
+        # mask == 1 keeps the denoiser's latents; mask == 0 pins the region to
+        # the re-noised original, so only the masked area gets repainted.
+        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        if self._debug:
+            self._debug(masked_input, f"t={t[0]} lerped")
+        return self.forward(masked_input, t, text_embeddings)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
     """
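Note on the new AddsMaskGuidance callback: for standard (non-inpainting) UNets it keeps the unmasked region pinned to the source image by re-noising the original-image latents to the current timestep and blending them with the working latents, using the mask as the lerp weight. A minimal sketch of that blend step (shapes and values are illustrative, not taken from the pipeline):

import torch

latents = torch.randn(2, 4, 64, 64)          # denoiser's current working latents
noised_original = torch.randn(2, 4, 64, 64)  # scheduler.add_noise(original_latents, noise, t)
mask = torch.zeros(2, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0                # region the model is allowed to repaint

# torch.lerp(a, b, w) == a + w * (b - a): where mask == 1 the working latents
# pass through; where mask == 0 the pixel snaps back to the re-noised original.
masked_input = torch.lerp(noised_original, latents, mask)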
@@ -382,17 +405,18 @@ def inpaint_from_embeddings(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)

-        if is_inpainting_model(self.unet):
-            if mask.dim() == 3:
-                mask = mask.unsqueeze(0)
-            mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)\
-                .to(device=device, dtype=latents_dtype)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(0)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
+            .to(device=device, dtype=latents_dtype)

+        if is_inpainting_model(self.unet):
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            # FIXME: need to add guidance that applies mask
-            pass
+            self.invokeai_diffuser.model_forward_callback = \
+                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
+                                 self.scheduler, noise_func)  # self.debug_latents)

         result = None

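The branch above chooses the forward callback by model type: inpainting-trained UNets consume the mask and original-image latents as extra input channels (AddsMaskLatents packs them), while plain four-channel UNets fall back to the new lerp-based AddsMaskGuidance. A plausible sketch of the is_inpainting_model helper referenced here, assuming the common diffusers convention that inpainting checkpoints take 4 + 1 + 4 = 9 input channels:

def is_inpainting_model(unet) -> bool:
    # Assumption: 9 channels = 4 latents + 1 mask + 4 masked-image latents.
    return unet.in_channels == 9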
@@ -417,7 +441,7 @@ def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
             init_latent_dist = self.vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         init_latents = 0.18215 * init_latents

         noise = noise_func(init_latents)
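On the FIXME: DiagonalGaussianDistribution.sample() draws from the global torch RNG, so results vary run to run even when noise_func itself is deterministic. One way to make the noise side reproducible, as a sketch (make_noise_func is a hypothetical helper, not part of this diff), is to close over an explicitly seeded torch.Generator:

def make_noise_func(seed: int, device: str = "cpu"):
    generator = torch.Generator(device=device).manual_seed(seed)

    def noise_func(like: torch.Tensor) -> torch.Tensor:
        # Same shape and dtype as the latents, but drawn from the seeded
        # generator instead of the global RNG.
        return torch.randn(like.shape, generator=generator, device=device, dtype=like.dtype)

    return noise_func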
@@ -456,3 +480,10 @@ def _tokenize(self, prompt: Union[str, List[str]]):
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
+
+    def debug_latents(self, latents, msg):
+        with torch.inference_mode():
+            from ldm.util import debug_image
+            decoded = self.numpy_to_pil(self.decode_latents(latents))
+            for i, img in enumerate(decoded):
+                debug_image(img, f"latents {msg} {i + 1}/{len(decoded)}", debug_status=True)
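debug_latents decodes intermediate latents back to PIL images via the pipeline's decode_latents/numpy_to_pil and hands them to ldm.util.debug_image. The commented-out argument on the AddsMaskGuidance construction above marks where it would plug in; hypothetically:

# Hypothetical wiring: pass the bound method as the _debug hook so each
# step's lerped latents are decoded and dumped for inspection.
AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
                 self.scheduler, noise_func, self.debug_latents)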