Skip to content

Commit bdbcaa9

Browse files
authored
lpw_stable_diffusion: Add is_cancelled_callback (#1053)
* [Community Pipelines] lpw_stable_diffusion: Add is_cancelled_callback * [Community pipelines] lpw_stable_diffusion_onnx: Add is_cancelled_callback
1 parent 8ee2191 commit bdbcaa9

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

examples/community/lpw_stable_diffusion.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ def __call__(
498498
output_type: Optional[str] = "pil",
499499
return_dict: bool = True,
500500
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
501+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
501502
callback_steps: Optional[int] = 1,
502503
**kwargs,
503504
):
@@ -560,11 +561,15 @@ def __call__(
560561
callback (`Callable`, *optional*):
561562
A function that will be called every `callback_steps` steps during inference. The function will be
562563
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
564+
is_cancelled_callback (`Callable`, *optional*):
565+
A function that will be called every `callback_steps` steps during inference. If the function returns
566+
`True`, the inference will be cancelled.
563567
callback_steps (`int`, *optional*, defaults to 1):
564568
The frequency at which the `callback` function will be called. If not specified, the callback will be
565569
called at every step.
566570
567571
Returns:
572+
`None` if cancelled by `is_cancelled_callback`,
568573
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
569574
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
570575
When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -757,8 +762,11 @@ def __call__(
757762
latents = (init_latents_proper * mask) + (latents * (1 - mask))
758763

759764
# call the callback, if provided
760-
if callback is not None and i % callback_steps == 0:
761-
callback(i, t, latents)
765+
if i % callback_steps == 0:
766+
if callback is not None:
767+
callback(i, t, latents)
768+
if is_cancelled_callback is not None and is_cancelled_callback():
769+
return None
762770

763771
latents = 1 / 0.18215 * latents
764772
image = self.vae.decode(latents).sample

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def __call__(
435435
output_type: Optional[str] = "pil",
436436
return_dict: bool = True,
437437
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
438+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
438439
callback_steps: Optional[int] = 1,
439440
**kwargs,
440441
):
@@ -496,11 +497,15 @@ def __call__(
496497
callback (`Callable`, *optional*):
497498
A function that will be called every `callback_steps` steps during inference. The function will be
498499
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
500+
is_cancelled_callback (`Callable`, *optional*):
501+
A function that will be called every `callback_steps` steps during inference. If the function returns
502+
`True`, the inference will be cancelled.
499503
callback_steps (`int`, *optional*, defaults to 1):
500504
The frequency at which the `callback` function will be called. If not specified, the callback will be
501505
called at every step.
502506
503507
Returns:
508+
`None` if cancelled by `is_cancelled_callback`,
504509
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
505510
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
506511
When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -668,8 +673,11 @@ def __call__(
668673
latents = (init_latents_proper * mask) + (latents * (1 - mask))
669674

670675
# call the callback, if provided
671-
if callback is not None and i % callback_steps == 0:
672-
callback(i, t, latents)
676+
if i % callback_steps == 0:
677+
if callback is not None:
678+
callback(i, t, latents)
679+
if is_cancelled_callback is not None and is_cancelled_callback():
680+
return None
673681

674682
latents = 1 / 0.18215 * latents
675683
# image = self.vae_decoder(latent_sample=latents)[0]

0 commit comments

Comments
 (0)