@@ -498,6 +498,7 @@ def __call__(
498498 output_type : Optional [str ] = "pil" ,
499499 return_dict : bool = True ,
500500 callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
501+ is_cancelled_callback : Optional [Callable [[], bool ]] = None ,
501502 callback_steps : Optional [int ] = 1 ,
502503 ** kwargs ,
503504 ):
@@ -560,11 +561,15 @@ def __call__(
560561 callback (`Callable`, *optional*):
561562 A function that will be called every `callback_steps` steps during inference. The function will be
562563 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
564+ is_cancelled_callback (`Callable`, *optional*):
565+ A function that will be called every `callback_steps` steps during inference. If the function returns
566+ `True`, the inference will be cancelled.
563567 callback_steps (`int`, *optional*, defaults to 1):
564568 The frequency at which the `callback` function will be called. If not specified, the callback will be
565569 called at every step.
566570
567571 Returns:
572+ `None` if cancelled by `is_cancelled_callback`,
568573 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
569574 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
570575 When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -757,8 +762,11 @@ def __call__(
757762 latents = (init_latents_proper * mask ) + (latents * (1 - mask ))
758763
759764 # call the callback, if provided
760- if callback is not None and i % callback_steps == 0 :
761- callback (i , t , latents )
765+ if i % callback_steps == 0 :
766+ if callback is not None :
767+ callback (i , t , latents )
768+ if is_cancelled_callback is not None and is_cancelled_callback ():
769+ return None
762770
763771 latents = 1 / 0.18215 * latents
764772 image = self .vae .decode (latents ).sample
0 commit comments