From d3d6d69dc73e2aa5aba8331cd70e80cc2035f7ab Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 15 Sep 2022 12:49:46 +0800 Subject: [PATCH 01/14] Add callback parameters for Stable Diffusion pipelines Signed-off-by: James R T --- .../pipeline_stable_diffusion.py | 27 ++++++++++++++++++- .../pipeline_stable_diffusion_img2img.py | 27 ++++++++++++++++++- .../pipeline_stable_diffusion_inpaint.py | 27 ++++++++++++++++++- .../pipeline_stable_diffusion_onnx.py | 21 ++++++++++++++- 4 files changed, 98 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f02fa114a8e1..3d64caead2fa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import torch @@ -103,6 +103,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable] = None, + callback_frequency: Optional[int] = None, **kwargs, ): r""" @@ -140,6 +142,12 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_frequency` steps during inference. The function will be + called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. + callback_frequency (`int`, *optional*): + The frequency at which the `callback` function will be called. If `None`, the callback will be called + after every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -171,6 +179,11 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + raise ValueError( + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." 
+ ) + # get prompt text embeddings text_input = self.tokenizer( prompt, @@ -259,6 +272,18 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # call the callback, if provided + if callback is not None: + if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): + # scale and decode the image latents with vae + current_latents = 1 / 0.18215 * latents + image = self.vae.decode(current_latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.numpy_to_pil(image) + callback(i, image) + # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 69e2fc36fadf..70e12566090d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -113,6 +113,8 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable] = None, + callback_frequency: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -150,6 +152,12 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_frequency` steps during inference. The function will be + called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. + callback_frequency (`int`, *optional*): + The frequency at which the `callback` function will be called. If `None`, the callback will be called + after every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -168,6 +176,11 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + raise ValueError( + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." 
+ ) + # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) extra_set_kwargs = {} @@ -271,6 +284,18 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # call the callback, if provided + if callback is not None: + if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): + # scale and decode the image latents with vae + current_latents = 1 / 0.18215 * latents + image = self.vae.decode(current_latents.to(self.vae.dtype)).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.numpy_to_pil(image) + callback(i, image) + # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents.to(self.vae.dtype)).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b529598c70dc..d2da03c4ae2a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -133,6 +133,8 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable] = None, + callback_frequency: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -173,6 +175,12 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_frequency` steps during inference. The function will be + called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. + callback_frequency (`int`, *optional*): + The frequency at which the `callback` function will be called. If `None`, the callback will be called + after every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -191,6 +199,11 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + raise ValueError( + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." 
+ ) + # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) extra_set_kwargs = {} @@ -289,6 +302,18 @@ def __call__( init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) latents = (init_latents_proper * mask) + (latents * (1 - mask)) + # call the callback, if provided + if callback is not None: + if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): + # scale and decode the image latents with vae + current_latents = 1 / 0.18215 * latents + image = self.vae.decode(current_latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.numpy_to_pil(image) + callback(i, image) + # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 7ff3ff22fc21..9a1f347fbf4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np @@ -53,6 +53,8 @@ def __call__( latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + callback: Optional[Callable] = None, + callback_frequency: Optional[int] = None, **kwargs, ): if isinstance(prompt, str): @@ -65,6 +67,11 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + raise ValueError( + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." 
+ ) + # get prompt text embeddings text_input = self.tokenizer( prompt, @@ -145,6 +152,18 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # call the callback, if provided + if callback is not None: + if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): + # scale and decode the image latents with vae + current_latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=current_latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + image = self.numpy_to_pil(image) + callback(i, image) + # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae_decoder(latent_sample=latents)[0] From 04c1a0357739d8ae4b2be4444c4ff0ac6abee0e9 Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 15 Sep 2022 13:14:06 +0800 Subject: [PATCH 02/14] Lint code with `black --preview` Signed-off-by: James R T --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 ++- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 3 ++- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 3 ++- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3d64caead2fa..2ce7f551af43 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -181,7 +181,8 @@ def __call__( if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" + f" {type(callback_frequency)}." ) # get prompt text embeddings diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 70e12566090d..feca3577401b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -178,7 +178,8 @@ def __call__( if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" + f" {type(callback_frequency)}." ) # set timesteps diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d2da03c4ae2a..3f4cc36aae33 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -201,7 +201,8 @@ def __call__( if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." 
+ f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" + f" {type(callback_frequency)}." ) # set timesteps diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 9a1f347fbf4d..b00f60f2baf9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -69,7 +69,8 @@ def __call__( if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type {type(callback_frequency)}." + f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" + f" {type(callback_frequency)}." ) # get prompt text embeddings From 29b35b172be83985038ff54f7798571cdee417b2 Mon Sep 17 00:00:00 2001 From: James R T Date: Fri, 16 Sep 2022 13:22:46 +0800 Subject: [PATCH 03/14] Refactor callback implementation for Stable Diffusion pipelines --- .../pipeline_stable_diffusion.py | 84 +++++++++++++------ .../pipeline_stable_diffusion_img2img.py | 83 ++++++++++++------ .../pipeline_stable_diffusion_inpaint.py | 83 ++++++++++++------ .../pipeline_stable_diffusion_onnx.py | 51 ++++++----- 4 files changed, 201 insertions(+), 100 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2ce7f551af43..12e5140a7241 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,7 +1,8 @@ import inspect import warnings -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union +import numpy as np import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -90,6 +91,43 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + @torch.no_grad() + def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: + r""" + Scale and decode the latent representations into images using the VAE. + + Args: + latents (`torch.FloatTensor`): + Latent representations to decode into images. + + Returns: + `np.ndarray`: Decoded images. + """ + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return image + + @torch.no_grad() + def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: + r""" + Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be + raised and a black image will be returned instead. + + Args: + image (`np.ndarray`): + Images to run the safety checker on. + + Returns: + image (`np.ndarray`): Images that has been processed by the safety checker. + has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. 
+ """ + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + return image, has_nsfw_concept + @torch.no_grad() def __call__( self, @@ -103,8 +141,10 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable] = None, - callback_frequency: Optional[int] = None, + callback: Optional[ + Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] + ] = None, + callback_frequency: Optional[int] = 1, **kwargs, ): r""" @@ -144,10 +184,11 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. - callback_frequency (`int`, *optional*): - The frequency at which the `callback` function will be called. If `None`, the callback will be called - after every step. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, + image: Union[List[PIL.Image.Image], np.ndarray])`. + callback_frequency (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -179,7 +220,9 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + if (callback_frequency is None) or ( + callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + ): raise ValueError( f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" f" {type(callback_frequency)}." 
@@ -274,27 +317,16 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None: - if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): - # scale and decode the image latents with vae - current_latents = 1 / 0.18215 * latents - image = self.vae.decode(current_latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + if callback is not None and i % callback_frequency == 0: + image = self.decode_latents(latents) + image = self.run_safety_checker(image)[0] + if output_type == "pil": image = self.numpy_to_pil(image) - callback(i, image) - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + callback(i, t, latents, image) - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.decode_latents(latents) - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + image, has_nsfw_concept = self.run_safety_checker(image) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index feca3577401b..405751f4f141 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,5 +1,5 @@ import inspect -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -101,6 +101,43 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) + @torch.no_grad() + def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: + r""" + Scale and decode the latent representations into images using the VAE. + + Args: + latents (`torch.FloatTensor`): + Latent representations to decode into images. + + Returns: + `np.ndarray`: Decoded images. + """ + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents.to(self.vae.dtype)).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return image + + @torch.no_grad() + def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: + r""" + Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be + raised and a black image will be returned instead. + + Args: + image (`np.ndarray`): + Images to run the safety checker on. + + Returns: + image (`np.ndarray`): Images that has been processed by the safety checker. + has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. 
+ """ + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + return image, has_nsfw_concept + @torch.no_grad() def __call__( self, @@ -113,8 +150,10 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable] = None, - callback_frequency: Optional[int] = None, + callback: Optional[ + Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] + ] = None, + callback_frequency: Optional[int] = 1, ): r""" Function invoked when calling the pipeline for generation. @@ -154,10 +193,11 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. - callback_frequency (`int`, *optional*): - The frequency at which the `callback` function will be called. If `None`, the callback will be called - after every step. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, + image: Union[List[PIL.Image.Image], np.ndarray])`. + callback_frequency (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -176,7 +216,9 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + if (callback_frequency is None) or ( + callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + ): raise ValueError( f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" f" {type(callback_frequency)}." 
@@ -286,27 +328,16 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None: - if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): - # scale and decode the image latents with vae - current_latents = 1 / 0.18215 * latents - image = self.vae.decode(current_latents.to(self.vae.dtype)).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + if callback is not None and i % callback_frequency == 0: + image = self.decode_latents(latents) + image = self.run_safety_checker(image)[0] + if output_type == "pil": image = self.numpy_to_pil(image) - callback(i, image) - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents.to(self.vae.dtype)).sample + callback(i, t, latents, image) - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.decode_latents(latents) - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + image, has_nsfw_concept = self.run_safety_checker(image) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3f4cc36aae33..8f17f588da30 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,5 +1,5 @@ import inspect -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -120,6 +120,43 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) + @torch.no_grad() + def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: + r""" + Scale and decode the latent representations into images using the VAE. + + Args: + latents (`torch.FloatTensor`): + Latent representations to decode into images. + + Returns: + `np.ndarray`: Decoded images. + """ + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return image + + @torch.no_grad() + def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: + r""" + Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be + raised and a black image will be returned instead. + + Args: + image (`np.ndarray`): + Images to run the safety checker on. + + Returns: + image (`np.ndarray`): Images that has been processed by the safety checker. + has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. 
+ """ + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + return image, has_nsfw_concept + @torch.no_grad() def __call__( self, @@ -133,8 +170,10 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable] = None, - callback_frequency: Optional[int] = None, + callback: Optional[ + Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] + ] = None, + callback_frequency: Optional[int] = 1, ): r""" Function invoked when calling the pipeline for generation. @@ -177,10 +216,11 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, image: List[PIL.Image.Image])`. - callback_frequency (`int`, *optional*): - The frequency at which the `callback` function will be called. If `None`, the callback will be called - after every step. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, + image: Union[List[PIL.Image.Image], np.ndarray])`. + callback_frequency (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -199,7 +239,9 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + if (callback_frequency is None) or ( + callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + ): raise ValueError( f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" f" {type(callback_frequency)}." 
@@ -304,27 +346,16 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None: - if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): - # scale and decode the image latents with vae - current_latents = 1 / 0.18215 * latents - image = self.vae.decode(current_latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + if callback is not None and i % callback_frequency == 0: + image = self.decode_latents(latents) + image = self.run_safety_checker(image)[0] + if output_type == "pil": image = self.numpy_to_pil(image) - callback(i, image) - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + callback(i, t, latents, image) - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = self.decode_latents(latents) - # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + image, has_nsfw_concept = self.run_safety_checker(image) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index b00f60f2baf9..548631f4c924 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -1,7 +1,8 @@ import inspect -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np +import torch from transformers import CLIPFeatureExtractor, CLIPTokenizer @@ -42,6 +43,19 @@ def __init__( feature_extractor=feature_extractor, ) + def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: + latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image + + def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) + return image, has_nsfw_concept + def __call__( self, prompt: Union[str, List[str]], @@ -53,8 +67,10 @@ def __call__( latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable] = None, - callback_frequency: Optional[int] = None, + callback: Optional[ + Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] + ] = None, + callback_frequency: Optional[int] = 1, **kwargs, ): if isinstance(prompt, str): @@ -67,7 +83,9 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_frequency is not None and (callback_frequency <= 0 or not isinstance(callback_frequency, int)): + if (callback_frequency is None) or ( + callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + ): raise ValueError( 
f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" f" {type(callback_frequency)}." @@ -154,27 +172,16 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None: - if (callback_frequency is None) or (callback_frequency is not None and i % callback_frequency == 0): - # scale and decode the image latents with vae - current_latents = 1 / 0.18215 * latents - image = self.vae_decoder(latent_sample=current_latents)[0] - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) + if callback is not None and i % callback_frequency == 0: + image = self.decode_latents(latents) + image = self.run_safety_checker(image)[0] + if output_type == "pil": image = self.numpy_to_pil(image) - callback(i, image) - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae_decoder(latent_sample=latents)[0] + callback(i, t, latents, image) - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) + image = self.decode_latents(latents) - # run safety checker - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) + image, has_nsfw_concept = self.run_safety_checker(image) if output_type == "pil": image = self.numpy_to_pil(image) From 2115bacfe74b1fb260d31253619e297ed6693604 Mon Sep 17 00:00:00 2001 From: James R T Date: Fri, 16 Sep 2022 13:25:05 +0800 Subject: [PATCH 04/14] Fix missing imports Signed-off-by: James R T --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 + .../pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 12e5140a7241..19a2519e7e2a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -5,6 +5,7 @@ import numpy as np import torch +import PIL from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 548631f4c924..c9b82cfd433a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -4,6 +4,7 @@ import numpy as np import torch +import PIL from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...onnx_utils import OnnxRuntimeModel From fb7f465be215de5e75048cb2a5019705f9a36d83 Mon Sep 17 00:00:00 2001 From: James R T Date: Fri, 16 Sep 2022 14:01:00 +0800 Subject: [PATCH 05/14] Fix documentation format Signed-off-by: James R T --- .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++---- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 8 ++++---- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 
19a2519e7e2a..5a8c3a51a32c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -122,8 +122,8 @@ def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]] Images to run the safety checker on. Returns: - image (`np.ndarray`): Images that has been processed by the safety checker. - has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. + `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the + safety checker. The second element is a boolean array indicating whether the images contain NSFW content. """ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) @@ -185,8 +185,8 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, - image: Union[List[PIL.Image.Image], np.ndarray])`. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. callback_frequency (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 405751f4f141..964b7d0d6347 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -131,8 +131,8 @@ def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]] Images to run the safety checker on. Returns: - image (`np.ndarray`): Images that has been processed by the safety checker. - has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. + `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the + safety checker. The second element is a boolean array indicating whether the images contain NSFW content. """ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) @@ -193,8 +193,8 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, - image: Union[List[PIL.Image.Image], np.ndarray])`. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. callback_frequency (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 8f17f588da30..dcad13e48f15 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -150,8 +150,8 @@ def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]] Images to run the safety checker on. Returns: - image (`np.ndarray`): Images that has been processed by the safety checker. - has_nsfw_concept (`List[bool]`): Boolean array indicating whether the images contain NSFW content. + `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the + safety checker. The second element is a boolean array indicating whether the images contain NSFW content. """ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) @@ -216,8 +216,8 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_frequency` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: torch.FloatTensor, - image: Union[List[PIL.Image.Image], np.ndarray])`. + called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. callback_frequency (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. From 068246319c162a3d5b14d0a2cd1db6d35ea8252c Mon Sep 17 00:00:00 2001 From: James R T Date: Sat, 17 Sep 2022 18:41:10 +0800 Subject: [PATCH 06/14] Add kwargs parameter to standardize with other pipelines Signed-off-by: James R T --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 1 + .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 964b7d0d6347..f67f1c6750ae 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -154,6 +154,7 @@ def __call__( Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] ] = None, callback_frequency: Optional[int] = 1, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index dcad13e48f15..d63b83e11d79 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -174,6 +174,7 @@ def __call__( Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] ] = None, callback_frequency: Optional[int] = 1, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. 
From 886a473b806efc3e4ed20717adbbe9b2e762158f Mon Sep 17 00:00:00 2001 From: James R T Date: Fri, 23 Sep 2022 20:23:00 +0800 Subject: [PATCH 07/14] Modify Stable Diffusion pipeline callback parameters Signed-off-by: James R T --- .../pipeline_stable_diffusion.py | 28 +++++++---------- .../pipeline_stable_diffusion_img2img.py | 29 ++++++++---------- .../pipeline_stable_diffusion_inpaint.py | 30 ++++++++----------- .../pipeline_stable_diffusion_onnx.py | 22 +++++--------- 4 files changed, 44 insertions(+), 65 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5c271d72aa69..aebf9bfd27a1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -158,10 +158,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[ - Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] - ] = None, - callback_frequency: Optional[int] = 1, + callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): r""" @@ -200,10 +198,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_frequency` steps during inference. The function will be + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: - torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. - callback_frequency (`int`, *optional*, defaults to 1): + torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -237,12 +235,12 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_frequency is None) or ( - callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" - f" {type(callback_frequency)}." + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
) # get prompt text embeddings @@ -329,12 +327,8 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None and i % callback_frequency == 0: - image = self.decode_latents(latents) - image = self.run_safety_checker(image)[0] - if output_type == "pil": - image = self.numpy_to_pil(image) - callback(i, t, latents, image) + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c4e5fbf6a998..d02421042b6c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -167,10 +167,8 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[ - Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] - ] = None, - callback_frequency: Optional[int] = 1, + callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): r""" @@ -210,10 +208,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_frequency` steps during inference. The function will be + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: - torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. - callback_frequency (`int`, *optional*, defaults to 1): + torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -234,12 +232,12 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_frequency is None) or ( - callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" - f" {type(callback_frequency)}." + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
) # set timesteps @@ -311,6 +309,7 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): t_index = t_start + i @@ -338,12 +337,8 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None and i % callback_frequency == 0: - image = self.decode_latents(latents) - image = self.run_safety_checker(image)[0] - if output_type == "pil": - image = self.numpy_to_pil(image) - callback(i, t, latents, image) + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index e8bb27247663..850be223087a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -187,10 +187,8 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[ - Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] - ] = None, - callback_frequency: Optional[int] = 1, + callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): r""" @@ -234,10 +232,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): - A function that will be called every `callback_frequency` steps during inference. The function will be + A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: - torch.FloatTensor, image: Union[List[PIL.Image.Image], np.ndarray])`. - callback_frequency (`int`, *optional*, defaults to 1): + torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -258,12 +256,12 @@ def __call__( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_frequency is None) or ( - callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" - f" {type(callback_frequency)}." + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
) # set timesteps @@ -347,7 +345,9 @@ def __call__( extra_step_kwargs["eta"] = eta latents = init_latents + t_start = max(num_inference_steps - init_timestep + offset, 0) + for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): t_index = t_start + i # expand the latents if we are doing classifier free guidance @@ -378,12 +378,8 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None and i % callback_frequency == 0: - image = self.decode_latents(latents) - image = self.run_safety_checker(image)[0] - if output_type == "pil": - image = self.numpy_to_pil(image) - callback(i, t, latents, image) + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 72fd1cdc2719..3ae120e190dd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -68,10 +68,8 @@ def __call__( latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[ - Callable[[int, np.ndarray, torch.FloatTensor, Union[List[PIL.Image.Image], np.ndarray]], None] - ] = None, - callback_frequency: Optional[int] = 1, + callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, ): if isinstance(prompt, str): @@ -84,12 +82,12 @@ def __call__( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_frequency is None) or ( - callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( - f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type" - f" {type(callback_frequency)}." + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
) # get prompt text embeddings @@ -168,12 +166,8 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if callback is not None and i % callback_frequency == 0: - image = self.decode_latents(latents) - image = self.run_safety_checker(image)[0] - if output_type == "pil": - image = self.numpy_to_pil(image) - callback(i, t, latents, image) + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) image = self.decode_latents(latents) From 5281a515318acf8d31c5c203d62ad2f5971b828a Mon Sep 17 00:00:00 2001 From: James R T Date: Fri, 23 Sep 2022 20:27:04 +0800 Subject: [PATCH 08/14] Remove useless imports Signed-off-by: James R T --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 - .../pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index aebf9bfd27a1..d8c756f76184 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -5,7 +5,6 @@ import numpy as np import torch -import PIL from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 3ae120e190dd..f5cc9ccb2c21 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -4,7 +4,6 @@ import numpy as np import torch -import PIL from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...onnx_utils import OnnxRuntimeModel From 90dc12ba21a3c43935520cfd7faac80279c4dc6e Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 29 Sep 2022 16:13:31 +0800 Subject: [PATCH 09/14] Change types for timestep and onnx latents --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++-- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 4 ++-- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 ++-- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 4 +++- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f25b6990cc44..dc065618284b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -157,7 +157,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -198,7 +198,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 
callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 5597997f07c9..b619c7f4026e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -167,7 +167,7 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -209,7 +209,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 8a6730f3340a..0d2020645051 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -187,7 +187,7 @@ def __call__( generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -233,7 +233,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: np.ndarray, latents: + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. 
If not specified, the callback will be diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index f5cc9ccb2c21..52c55d7d599b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -67,7 +67,7 @@ def __call__( latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -164,6 +164,8 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = np.array(latents) + # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) From e568e6ce42c662a0c76ec94d5095c7c44a5806e4 Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 29 Sep 2022 16:22:37 +0800 Subject: [PATCH 10/14] Fix docstring style --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +-- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 3 +-- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index dc065618284b..dbc4597597e6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -198,8 +198,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: - torch.FloatTensor)`. + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index b619c7f4026e..de4adff5d125 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -209,8 +209,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: - torch.FloatTensor)`. + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
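
The callback interface documented above is the one this series converges on: the pipeline hands the callback the step index, the scheduler timestep, and the raw latents. A minimal caller-side sketch, assuming the `CompVis/stable-diffusion-v1-4` checkpoint used by the tests later in the series (the `log_latents` name and `latents_history` list are illustrative, not part of the patch):

    import torch
    from diffusers import StableDiffusionPipeline

    # Collect whatever the pipeline passes to the callback for later inspection.
    latents_history = []

    def log_latents(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # Invoked every `callback_steps` denoising steps with the current latents.
        latents_history.append((step, timestep, latents.detach().cpu()))

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.to("cuda")

    pipe(
        "Andromeda galaxy in a bottle",
        num_inference_steps=50,
        callback=log_latents,
        callback_steps=5,  # defaults to 1, i.e. fire on every step
    )

With `callback_steps` left at its default of 1 the callback fires on every step, which is what the text-to-image intermediate-state test later in the series relies on when it counts 51 invocations over a 50-step run.
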
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 0d2020645051..c6d77b2771c4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -233,8 +233,7 @@ def __call__( plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: - torch.FloatTensor)`. + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. From fc3b377466dd97e28bf4f1bbd89652b0f56bda7c Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 29 Sep 2022 16:42:04 +0800 Subject: [PATCH 11/14] Return decode_latents and run_safety_checker back into __call__ --- .../pipeline_stable_diffusion.py | 46 +++---------------- .../pipeline_stable_diffusion_img2img.py | 46 +++---------------- .../pipeline_stable_diffusion_inpaint.py | 46 +++---------------- .../pipeline_stable_diffusion_onnx.py | 22 +++------ 4 files changed, 28 insertions(+), 132 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index dbc4597597e6..5a39951f605f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -107,43 +107,6 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - @torch.no_grad() - def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: - r""" - Scale and decode the latent representations into images using the VAE. - - Args: - latents (`torch.FloatTensor`): - Latent representations to decode into images. - - Returns: - `np.ndarray`: Decoded images. - """ - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - return image - - @torch.no_grad() - def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: - r""" - Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be - raised and a black image will be returned instead. - - Args: - image (`np.ndarray`): - Images to run the safety checker on. - - Returns: - `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the - safety checker. The second element is a boolean array indicating whether the images contain NSFW content. 
- """ - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) - return image, has_nsfw_concept - @torch.no_grad() def __call__( self, @@ -328,9 +291,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - image = self.decode_latents(latents) + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample - image, has_nsfw_concept = self.run_safety_checker(image) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index de4adff5d125..bc32c23f53ed 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -118,43 +118,6 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) - @torch.no_grad() - def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: - r""" - Scale and decode the latent representations into images using the VAE. - - Args: - latents (`torch.FloatTensor`): - Latent representations to decode into images. - - Returns: - `np.ndarray`: Decoded images. - """ - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - return image - - @torch.no_grad() - def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: - r""" - Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be - raised and a black image will be returned instead. - - Args: - image (`np.ndarray`): - Images to run the safety checker on. - - Returns: - `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the - safety checker. The second element is a boolean array indicating whether the images contain NSFW content. 
- """ - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) - return image, has_nsfw_concept - @torch.no_grad() def __call__( self, @@ -339,9 +302,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - image = self.decode_latents(latents) + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample - image, has_nsfw_concept = self.run_safety_checker(image) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index c6d77b2771c4..835a3037a6bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -137,43 +137,6 @@ def disable_attention_slicing(self): # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) - @torch.no_grad() - def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: - r""" - Scale and decode the latent representations into images using the VAE. - - Args: - latents (`torch.FloatTensor`): - Latent representations to decode into images. - - Returns: - `np.ndarray`: Decoded images. - """ - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - return image - - @torch.no_grad() - def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: - r""" - Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be - raised and a black image will be returned instead. - - Args: - image (`np.ndarray`): - Images to run the safety checker on. - - Returns: - `Tuple[np.ndarray, List[bool]]`: The first element contains the images that has been processed by the - safety checker. The second element is a boolean array indicating whether the images contain NSFW content. 
- """ - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) - return image, has_nsfw_concept - @torch.no_grad() def __call__( self, @@ -380,9 +343,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - image = self.decode_latents(latents) + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample - image, has_nsfw_concept = self.run_safety_checker(image) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 52c55d7d599b..a9b4d47eeb95 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -43,19 +43,6 @@ def __init__( feature_extractor=feature_extractor, ) - def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray: - latents = 1 / 0.18215 * latents - image = self.vae_decoder(latent_sample=latents)[0] - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - return image - - def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) - return image, has_nsfw_concept - def __call__( self, prompt: Union[str, List[str]], @@ -170,9 +157,14 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - image = self.decode_latents(latents) + latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) - image, has_nsfw_concept = self.run_safety_checker(image) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) if output_type == "pil": image = self.numpy_to_pil(image) From 21484793abadae7360f3591342311fd918399e8e Mon Sep 17 00:00:00 2001 From: James R T Date: Thu, 29 Sep 2022 16:48:38 +0800 Subject: [PATCH 12/14] Remove unused imports --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +-- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 3 +-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5a39951f605f..718f19232ddf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,8 +1,7 @@ import inspect import warnings -from typing 
import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union -import numpy as np import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index bc32c23f53ed..b5b140b7d7dc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import numpy as np import torch diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 835a3037a6bb..8e8e9d3310f2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import numpy as np import torch diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index a9b4d47eeb95..92043fb32d40 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -1,8 +1,7 @@ import inspect -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import numpy as np -import torch from transformers import CLIPFeatureExtractor, CLIPTokenizer From ddbdec7967f3552b9e18c9e2271fdb5014717801 Mon Sep 17 00:00:00 2001 From: James R T Date: Sat, 1 Oct 2022 19:01:07 +0800 Subject: [PATCH 13/14] Add intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T --- tests/test_pipelines.py | 162 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index dddf42bd03f2..805989896559 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1387,3 +1387,165 @@ def test_stable_diffusion_onnx(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_text2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = np.array(latents) + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-1.2277, -0.3692, -0.2123, -1.3709, -1.4505, -0.6718, -0.3112, -1.2481, -1.0674] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + 
pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + pipe( + prompt=prompt, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 51 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_img2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = np.array(latents) + assert latents.shape == (1, 4, 64, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([0.5486, 0.8705, 1.4053, 1.6771, 2.0729, 0.7256, 1.5693, -0.1298, -1.3520]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + + pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_inpaint_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = np.array(latents) + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.4155, -0.4140, 1.1430, -2.0722, 2.2523, -1.8766, -0.4917, 0.3338, 0.9667] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + + pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 38 + + @slow + def 
test_stable_diffusion_onnx_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.6254, -0.2742, -1.0710, 0.2296, -1.1683, 0.6913, -2.0605, -0.0682, 0.9700] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionOnnxPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CUDAExecutionProvider" + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "Andromeda galaxy in a bottle" + + np.random.seed(0) + pipe(prompt=prompt, num_inference_steps=50, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1) + assert test_callback_fn.has_been_called + assert number_of_steps == 51 From fe05ea2eea78a0789b53e94b86fde6272a49a5a4 Mon Sep 17 00:00:00 2001 From: James R T Date: Sun, 2 Oct 2022 13:30:42 +0800 Subject: [PATCH 14/14] Fix intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T --- tests/test_pipelines.py | 90 +++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index ab46fd12b854..d0d78171378e 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1446,31 +1446,35 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No nonlocal number_of_steps number_of_steps += 1 if step == 0: - latents = np.array(latents) + latents = latents.detach().cpu().numpy() assert latents.shape == (1, 4, 64, 64) latents_slice = latents[0, -3:, -3:, -1] expected_slice = np.array( - [-1.2277, -0.3692, -0.2123, -1.3709, -1.4505, -0.6718, -0.3112, -1.2481, -1.0674] + [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False - pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "Andromeda galaxy in a bottle" generator = torch.Generator(device=torch_device).manual_seed(0) - pipe( - prompt=prompt, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) assert test_callback_fn.has_been_called assert number_of_steps == 51 @@ -1484,10 +1488,10 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No nonlocal number_of_steps number_of_steps += 1 if step == 0: - latents = np.array(latents) + latents = latents.detach().cpu().numpy() assert latents.shape == (1, 4, 64, 96) latents_slice = latents[0, -3:, -3:, -1] - expected_slice = np.array([0.5486, 0.8705, 1.4053, 1.6771, 2.0729, 0.7256, 1.5693, -0.1298, -1.3520]) + expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 
1.9838, 0.0530]) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 test_callback_fn.has_been_called = False @@ -1498,23 +1502,27 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No ) init_image = init_image.resize((768, 512)) - pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "A fantasy landscape, trending on artstation" generator = torch.Generator(device=torch_device).manual_seed(0) - pipe( - prompt=prompt, - init_image=init_image, - strength=0.75, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) assert test_callback_fn.has_been_called assert number_of_steps == 38 @@ -1528,11 +1536,11 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No nonlocal number_of_steps number_of_steps += 1 if step == 0: - latents = np.array(latents) + latents = latents.detach().cpu().numpy() assert latents.shape == (1, 4, 64, 64) latents_slice = latents[0, -3:, -3:, -1] expected_slice = np.array( - [-0.4155, -0.4140, 1.1430, -2.0722, 2.2523, -1.8766, -0.4917, 0.3338, 0.9667] + [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246] ) assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 @@ -1547,24 +1555,28 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" ) - pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16 + ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "A red cat sitting on a park bench" generator = torch.Generator(device=torch_device).manual_seed(0) - pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - num_inference_steps=50, - guidance_scale=7.5, - generator=generator, - callback=test_callback_fn, - callback_steps=1, - ) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) assert test_callback_fn.has_been_called assert number_of_steps == 38 @@ -1587,7 +1599,7 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: test_callback_fn.has_been_called = False pipe = StableDiffusionOnnxPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CUDAExecutionProvider" + "CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="onnx", provider="CPUExecutionProvider" ) pipe.set_progress_bar_config(disable=None)
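
Because the final revisions pass raw latents to the callback rather than decoded images, callers who want intermediate previews decode the latents themselves. A sketch of that pattern, mirroring the scaling and VAE decoding done at the end of `__call__` (the `save_preview` helper and the output filenames are illustrative, not part of this series):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.to("cuda")

    @torch.no_grad()
    def save_preview(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # Same post-processing as the pipeline's final decode: unscale the latents,
        # run the VAE decoder, map to [0, 1], and convert to PIL before saving.
        image = pipe.vae.decode(1 / 0.18215 * latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        pipe.numpy_to_pil(image)[0].save(f"step_{step:03d}.png")

    pipe(
        "A fantasy landscape, trending on artstation",
        num_inference_steps=50,
        callback=save_preview,
        callback_steps=10,  # decode a preview every 10 steps
    )

Keeping the VAE decode out of the pipeline's denoising loop and leaving it to the caller avoids paying the decode cost on steps where no preview is wanted.
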