From 87d73075eefc6754b60bc82340739c9602ec97cd Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Fri, 28 Oct 2022 21:08:33 -0400 Subject: [PATCH 1/2] [Community Pipelines] lpw_stable_diffusion: Add is_cancelled_callback --- examples/community/lpw_stable_diffusion.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 1b2279de720c..74aed2fec86f 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -498,6 +498,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -560,11 +561,15 @@ def __call__( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -757,8 +762,11 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample From 6e50e3d3c4ee14e3f1198005e6a0f1265de05ebf Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Fri, 28 Oct 2022 23:10:19 -0400 Subject: [PATCH 2/2] [Community pipelines] lpw_stable_diffusion_onnx: Add is_cancelled_callback --- examples/community/lpw_stable_diffusion_onnx.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index 37f03c86f29d..69b942f9ef1a 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -435,6 +435,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -496,11 +497,15 @@ def __call__( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -668,8 +673,11 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0]