diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 1b2279de720c..74aed2fec86f 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -498,6 +498,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -560,11 +561,15 @@ def __call__( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -757,8 +762,11 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index 37f03c86f29d..69b942f9ef1a 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -435,6 +435,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -496,11 +497,15 @@ def __call__( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -668,8 +673,11 @@ def __call__( latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0]