diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 5ffda93f1721..f1f7e7f274d3 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -21,10 +21,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
         nsfw_content_detected (`List[bool]`)
             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content.
+        latents (`List[torch.Tensor]`, *optional*, returned when `output_latents=True` and `return_dict=True` are passed)
+            List of `torch.Tensor` of shape `(batch_size, in_channels, height // 8, width // 8)`: the initial noise plus one tensor per diffusion step.
     """
 
     images: Union[List[PIL.Image.Image], np.ndarray]
     nsfw_content_detected: List[bool]
+    latents: Optional[List[torch.Tensor]] = None
 
 
 if is_transformers_available():
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index f02fa114a8e1..92e91423b2c7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -90,10 +90,9 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], torch.Tensor],
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
@@ -103,14 +102,17 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        output_latents: bool = False,
+        enable_grad: bool = False,
         **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
 
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str`, `List[str]` or `torch.Tensor`):
+                The prompt or prompts to guide the image generation. If a `torch.Tensor` is provided, it should be a
+                precomputed text embedding of shape `(sequence_len, embedding_dim)` or `(batch_size, sequence_len, embedding_dim)`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -140,6 +142,11 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+            output_latents (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the latents from all diffusion steps. See `latents` under returned tensors
+                for more details.
+            enable_grad (`bool`, *optional*, defaults to `False`):
+                Whether or not to enable gradient calculation during the diffusion process.
 
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -161,10 +168,26 @@ def __call__(
         device = "cuda" if torch.cuda.is_available() else "cpu"
         self.to(device)
 
+        # remember the caller's grad mode, then enable/disable grad for the diffusion loop
+        was_grad_enabled = torch.is_grad_enabled()
+        torch.set_grad_enabled(enable_grad)
+
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
             batch_size = len(prompt)
+        elif torch.is_tensor(prompt):
+            if len(prompt.shape) == 2:
+                # add the batch dimension
+                prompt = prompt.unsqueeze(0)
+
+            if len(prompt.shape) != 3:
+                raise ValueError(
+                    f"If `prompt` is of type `torch.Tensor`, it is expected to have 2 dimensions "
+                    f"(sequence len, embedding dim) or 3 dimensions (batch size, sequence len, embedding dim), "
+                    f"but found a tensor with shape {prompt.shape}"
+                )
+            batch_size = prompt.shape[0]
         else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `torch.Tensor` but is {type(prompt)}")
 
@@ -172,14 +195,17 @@ def __call__(
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         # get prompt text embeddings
-        text_input = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        if torch.is_tensor(prompt):
+            text_embeddings = prompt.to(self.device)
+        else:
+            text_input = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -187,7 +213,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
+            max_length = text_embeddings.shape[-2]
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
@@ -237,6 +263,7 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        all_latents = [latents] if output_latents else None  # start with the initial noise
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -259,6 +286,14 @@ def __call__(
             else:
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
+            if output_latents:
+                # save the latents from every diffusion step
+                all_latents.append(latents)
+
+        # the `.numpy()` calls in the post-processing below fail on tensors that require
+        # grad, so switch grad off here; gradients stay available through the returned `latents`
+        torch.set_grad_enabled(False)
+
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
@@ -276,4 +311,7 @@ def __call__(
+        # reset to the caller's grad mode before returning
+        torch.set_grad_enabled(was_grad_enabled)
+
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, latents=all_latents)
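For reference, a minimal sketch of how the new `output_latents` flag could be exercised once this patch is applied. The model id is only an example, and the exact number of returned latents depends on the scheduler (the initial noise plus one entry per scheduler timestep):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

out = pipe("a photograph of an astronaut riding a horse", output_latents=True)

# `out.latents` holds the initial noise plus the latents after every scheduler
# step, each of shape (batch_size, in_channels, height // 8, width // 8)
for i, step_latents in enumerate(out.latents):
    print(i, step_latents.shape)  # e.g. torch.Size([1, 4, 64, 64])
```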
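The `torch.Tensor` prompt path makes it possible to drive the pipeline with precomputed or manipulated text embeddings, e.g. to interpolate between two prompts. A sketch under the same assumptions; the `embed` helper is illustrative and simply mirrors the tokenizer/text-encoder calls the pipeline performs internally for `str` prompts:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")


def embed(text: str) -> torch.Tensor:
    # same tokenizer/text-encoder path the pipeline uses for `str` prompts
    tokens = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(pipe.device))[0]


# blend two prompt embeddings and feed the mix straight to the pipeline;
# the shape is (1, 77, 768), i.e. (batch_size, sequence_len, embedding_dim)
mix = 0.7 * embed("a photo of a cat") + 0.3 * embed("a photo of a dog")
image = pipe(prompt=mix).images[0]
```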
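Finally, a sketch of what `enable_grad=True` enables: with gradients on and `output_latents=True`, a loss computed on the returned latents can be backpropagated into a prompt-embedding tensor. The squared-norm loss below is a stand-in for a real objective (e.g. a CLIP score), and memory grows with `num_inference_steps` since the whole denoising graph is retained:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# encode a starting prompt, then turn the embeddings into a trainable leaf tensor
tokens = pipe.tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    embeddings = pipe.text_encoder(tokens.input_ids.to(pipe.device))[0]
embeddings = embeddings.clone().requires_grad_(True)

out = pipe(
    prompt=embeddings,
    num_inference_steps=8,  # keep the autograd graph small
    output_latents=True,
    enable_grad=True,
)

# toy objective on the final latents; the gradient reaches `embeddings` because
# the diffusion loop ran with grad enabled
loss = out.latents[-1].pow(2).mean()
loss.backward()
print(embeddings.grad.shape)  # torch.Size([1, 77, 768])
```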