Skip to content

Commit 87d7307

Browse files
committed
[Community Pipelines] lpw_stable_diffusion: Add is_cancelled_callback
1 parent 6b185b6 commit 87d7307

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/community/lpw_stable_diffusion.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ def __call__(
498498
output_type: Optional[str] = "pil",
499499
return_dict: bool = True,
500500
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
501+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
501502
callback_steps: Optional[int] = 1,
502503
**kwargs,
503504
):
@@ -560,11 +561,15 @@ def __call__(
560561
callback (`Callable`, *optional*):
561562
A function that will be called every `callback_steps` steps during inference. The function will be
562563
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
564+
is_cancelled_callback (`Callable`, *optional*):
565+
A function that will be called every `callback_steps` steps during inference. If the function returns
566+
`True`, the inference will be cancelled.
563567
callback_steps (`int`, *optional*, defaults to 1):
564568
The frequency at which the `callback` function will be called. If not specified, the callback will be
565569
called at every step.
566570
567571
Returns:
572+
`None` if cancelled by `is_cancelled_callback`,
568573
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
569574
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
570575
When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -757,8 +762,11 @@ def __call__(
757762
latents = (init_latents_proper * mask) + (latents * (1 - mask))
758763

759764
# call the callback, if provided
760-
if callback is not None and i % callback_steps == 0:
761-
callback(i, t, latents)
765+
if i % callback_steps == 0:
766+
if callback is not None:
767+
callback(i, t, latents)
768+
if is_cancelled_callback is not None and is_cancelled_callback():
769+
return None
762770

763771
latents = 1 / 0.18215 * latents
764772
image = self.vae.decode(latents).sample

0 commit comments

Comments
 (0)