Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def __call__(
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
Expand Down Expand Up @@ -560,11 +561,15 @@ def __call__(
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
Expand Down Expand Up @@ -757,8 +762,11 @@ def __call__(
latents = (init_latents_proper * mask) + (latents * (1 - mask))

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None

latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
Expand Down
12 changes: 10 additions & 2 deletions examples/community/lpw_stable_diffusion_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def __call__(
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
Expand Down Expand Up @@ -496,11 +497,15 @@ def __call__(
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
Expand Down Expand Up @@ -668,8 +673,11 @@ def __call__(
latents = (init_latents_proper * mask) + (latents * (1 - mask))

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None

latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
Expand Down