@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -125,6 +126,7 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.image_processor = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
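
Note on the new dependency: `VaeImageProcessor` (public in `diffusers.image_processor`) centralizes resizing, normalization, and type conversion that the pipeline previously did by hand. A minimal standalone sketch of what the instance registered above does, assuming the public diffusers API:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")
    low_res = PIL.Image.new("RGB", (128, 128))  # dummy input image
    tensor = processor.preprocess(low_res)      # batched NCHW float tensor in [-1, 1]
    print(tensor.shape)                         # torch.Size([1, 3, 128, 128])
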
@@ -432,14 +434,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray` or `list` but is {type(image)}"
             )
 
-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
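
The extended check treats a numpy array like a list or tensor: the number of images is taken from the list length or leading batch dimension and must match the number of prompts. A hedged illustration of that invariant:

    import numpy as np

    prompts = ["a cat", "a dog"]
    images = np.zeros((2, 128, 128, 3), dtype=np.float32)  # NHWC batch of 2
    # mirrors the check above: batch sizes must agree or check_inputs raises
    assert len(images) == len(prompts)
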
@@ -483,7 +486,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
@@ -506,7 +516,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be upscaled.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
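
With the widened annotation, all of the following now type-check as `image`. A usage sketch, assuming this diff targets `StableDiffusionUpscalePipeline` (the pipeline class and checkpoint name here are inferred, not confirmed by the diff):

    import numpy as np
    import PIL.Image
    import torch
    from diffusers import StableDiffusionUpscalePipeline

    pipe = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
    ).to("cuda")
    low_res = PIL.Image.new("RGB", (128, 128))
    as_numpy = np.asarray(low_res).astype(np.float32) / 255.0  # HWC in [0, 1]

    result = pipe(prompt="a white cat", image=low_res).images[0]   # PIL input
    result = pipe(prompt="a white cat", image=as_numpy).images[0]  # np.ndarray input
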
@@ -627,7 +637,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
 
         # 5. set timesteps
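
Swapping the module-level `preprocess()` for the shared processor means torch tensors, PIL images, and numpy arrays (and lists of each) should all come back as one batched NCHW float tensor. A self-contained sketch of that behavior, under the same API assumption as above:

    import numpy as np
    from diffusers.image_processor import VaeImageProcessor

    proc = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")
    np_image = np.random.rand(128, 128, 3).astype(np.float32)  # HWC in [0, 1]
    batch = proc.preprocess(np_image)
    assert batch.ndim == 4 and batch.shape[1] == 3  # batch dim added, channels first
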
@@ -723,25 +733,24 @@ def __call__(
         else:
             latents = latents.float()
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None
 
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)
+
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
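
The rewritten tail collapses the per-`output_type` branches into a single decode followed by `VaeImageProcessor.postprocess`, skipping denormalization for images the safety checker flagged (those were already blanked). A standalone sketch of the `do_denormalize` mechanics, with stand-in tensors rather than real pipeline outputs:

    import torch
    from diffusers.image_processor import VaeImageProcessor

    proc = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")
    decoded = torch.zeros(2, 3, 512, 512)  # stand-in for vae.decode(...) output
    has_nsfw_concept = [False, True]       # stand-in safety-checker result
    do_denormalize = [not flagged for flagged in has_nsfw_concept]
    images = proc.postprocess(decoded, output_type="pil", do_denormalize=do_denormalize)
    print(len(images))  # 2 PIL images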