Skip to content

Commit cf83856

Browse files
jamestiotioPrathik Rao
authored andcommitted
Add callback parameters for Stable Diffusion pipelines (huggingface#521)
* Add callback parameters for Stable Diffusion pipelines Signed-off-by: James R T <[email protected]> * Lint code with `black --preview` Signed-off-by: James R T <[email protected]> * Refactor callback implementation for Stable Diffusion pipelines * Fix missing imports Signed-off-by: James R T <[email protected]> * Fix documentation format Signed-off-by: James R T <[email protected]> * Add kwargs parameter to standardize with other pipelines Signed-off-by: James R T <[email protected]> * Modify Stable Diffusion pipeline callback parameters Signed-off-by: James R T <[email protected]> * Remove useless imports Signed-off-by: James R T <[email protected]> * Change types for timestep and onnx latents * Fix docstring style * Return decode_latents and run_safety_checker back into __call__ * Remove unused imports * Add intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T <[email protected]> * Fix intermediate state tests for Stable Diffusion pipelines Signed-off-by: James R T <[email protected]> Signed-off-by: James R T <[email protected]>
1 parent 819b573 commit cf83856

File tree

5 files changed

+259
-12
lines changed

5 files changed

+259
-12
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import warnings
3-
from typing import List, Optional, Union
3+
from typing import Callable, List, Optional, Union
44

55
import torch
66

@@ -122,6 +122,8 @@ def __call__(
122122
latents: Optional[torch.FloatTensor] = None,
123123
output_type: Optional[str] = "pil",
124124
return_dict: bool = True,
125+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
126+
callback_steps: Optional[int] = 1,
125127
**kwargs,
126128
):
127129
r"""
@@ -159,6 +161,12 @@ def __call__(
159161
return_dict (`bool`, *optional*, defaults to `True`):
160162
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
161163
plain tuple.
164+
callback (`Callable`, *optional*):
165+
A function that will be called every `callback_steps` steps during inference. The function will be
166+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
167+
callback_steps (`int`, *optional*, defaults to 1):
168+
The frequency at which the `callback` function will be called. If not specified, the callback will be
169+
called at every step.
162170
163171
Returns:
164172
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -178,6 +186,14 @@ def __call__(
178186
if height % 8 != 0 or width % 8 != 0:
179187
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
180188

189+
if (callback_steps is None) or (
190+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
191+
):
192+
raise ValueError(
193+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
194+
f" {type(callback_steps)}."
195+
)
196+
181197
# get prompt text embeddings
182198
text_inputs = self.tokenizer(
183199
prompt,
@@ -277,14 +293,16 @@ def __call__(
277293
else:
278294
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
279295

280-
# scale and decode the image latents with vae
296+
# call the callback, if provided
297+
if callback is not None and i % callback_steps == 0:
298+
callback(i, t, latents)
299+
281300
latents = 1 / 0.18215 * latents
282301
image = self.vae.decode(latents).sample
283302

284303
image = (image / 2 + 0.5).clamp(0, 1)
285304
image = image.cpu().permute(0, 2, 3, 1).numpy()
286305

287-
# run safety checker
288306
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
289307
image, has_nsfw_concept = self.safety_checker(
290308
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import warnings
3-
from typing import List, Optional, Union
3+
from typing import Callable, List, Optional, Union
44

55
import numpy as np
66
import torch
@@ -133,6 +133,9 @@ def __call__(
133133
generator: Optional[torch.Generator] = None,
134134
output_type: Optional[str] = "pil",
135135
return_dict: bool = True,
136+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
137+
callback_steps: Optional[int] = 1,
138+
**kwargs,
136139
):
137140
r"""
138141
Function invoked when calling the pipeline for generation.
@@ -170,6 +173,12 @@ def __call__(
170173
return_dict (`bool`, *optional*, defaults to `True`):
171174
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
172175
plain tuple.
176+
callback (`Callable`, *optional*):
177+
A function that will be called every `callback_steps` steps during inference. The function will be
178+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
179+
callback_steps (`int`, *optional*, defaults to 1):
180+
The frequency at which the `callback` function will be called. If not specified, the callback will be
181+
called at every step.
173182
174183
Returns:
175184
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -188,6 +197,14 @@ def __call__(
188197
if strength < 0 or strength > 1:
189198
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
190199

200+
if (callback_steps is None) or (
201+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
202+
):
203+
raise ValueError(
204+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
205+
f" {type(callback_steps)}."
206+
)
207+
191208
# set timesteps
192209
self.scheduler.set_timesteps(num_inference_steps)
193210

@@ -265,6 +282,7 @@ def __call__(
265282
latents = init_latents
266283

267284
t_start = max(num_inference_steps - init_timestep + offset, 0)
285+
268286
# Some schedulers like PNDM have timesteps as arrays
269287
# It's more optimzed to move all timesteps to correct device beforehand
270288
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@@ -295,14 +313,16 @@ def __call__(
295313
else:
296314
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
297315

298-
# scale and decode the image latents with vae
316+
# call the callback, if provided
317+
if callback is not None and i % callback_steps == 0:
318+
callback(i, t, latents)
319+
299320
latents = 1 / 0.18215 * latents
300321
image = self.vae.decode(latents).sample
301322

302323
image = (image / 2 + 0.5).clamp(0, 1)
303324
image = image.cpu().permute(0, 2, 3, 1).numpy()
304325

305-
# run safety checker
306326
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
307327
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
308328

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import warnings
3-
from typing import List, Optional, Union
3+
from typing import Callable, List, Optional, Union
44

55
import numpy as np
66
import torch
@@ -149,6 +149,9 @@ def __call__(
149149
generator: Optional[torch.Generator] = None,
150150
output_type: Optional[str] = "pil",
151151
return_dict: bool = True,
152+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
153+
callback_steps: Optional[int] = 1,
154+
**kwargs,
152155
):
153156
r"""
154157
Function invoked when calling the pipeline for generation.
@@ -190,6 +193,12 @@ def __call__(
190193
return_dict (`bool`, *optional*, defaults to `True`):
191194
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
192195
plain tuple.
196+
callback (`Callable`, *optional*):
197+
A function that will be called every `callback_steps` steps during inference. The function will be
198+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
199+
callback_steps (`int`, *optional*, defaults to 1):
200+
The frequency at which the `callback` function will be called. If not specified, the callback will be
201+
called at every step.
193202
194203
Returns:
195204
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -208,6 +217,14 @@ def __call__(
208217
if strength < 0 or strength > 1:
209218
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
210219

220+
if (callback_steps is None) or (
221+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
222+
):
223+
raise ValueError(
224+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
225+
f" {type(callback_steps)}."
226+
)
227+
211228
# set timesteps
212229
self.scheduler.set_timesteps(num_inference_steps)
213230

@@ -297,7 +314,9 @@ def __call__(
297314
extra_step_kwargs["eta"] = eta
298315

299316
latents = init_latents
317+
300318
t_start = max(num_inference_steps - init_timestep + offset, 0)
319+
301320
# Some schedulers like PNDM have timesteps as arrays
302321
# It's more optimzed to move all timesteps to correct device beforehand
303322
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@@ -331,14 +350,16 @@ def __call__(
331350

332351
latents = (init_latents_proper * mask) + (latents * (1 - mask))
333352

334-
# scale and decode the image latents with vae
353+
# call the callback, if provided
354+
if callback is not None and i % callback_steps == 0:
355+
callback(i, t, latents)
356+
335357
latents = 1 / 0.18215 * latents
336358
image = self.vae.decode(latents).sample
337359

338360
image = (image / 2 + 0.5).clamp(0, 1)
339361
image = image.cpu().permute(0, 2, 3, 1).numpy()
340362

341-
# run safety checker
342363
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
343364
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
344365

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import List, Optional, Union
2+
from typing import Callable, List, Optional, Union
33

44
import numpy as np
55

@@ -56,6 +56,8 @@ def __call__(
5656
latents: Optional[np.ndarray] = None,
5757
output_type: Optional[str] = "pil",
5858
return_dict: bool = True,
59+
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
60+
callback_steps: Optional[int] = 1,
5961
**kwargs,
6062
):
6163
if isinstance(prompt, str):
@@ -68,6 +70,14 @@ def __call__(
6870
if height % 8 != 0 or width % 8 != 0:
6971
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
7072

73+
if (callback_steps is None) or (
74+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
75+
):
76+
raise ValueError(
77+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
78+
f" {type(callback_steps)}."
79+
)
80+
7181
# get prompt text embeddings
7282
text_inputs = self.tokenizer(
7383
prompt,
@@ -151,14 +161,18 @@ def __call__(
151161
else:
152162
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
153163

154-
# scale and decode the image latents with vae
164+
latents = np.array(latents)
165+
166+
# call the callback, if provided
167+
if callback is not None and i % callback_steps == 0:
168+
callback(i, t, latents)
169+
155170
latents = 1 / 0.18215 * latents
156171
image = self.vae_decoder(latent_sample=latents)[0]
157172

158173
image = np.clip(image / 2 + 0.5, 0, 1)
159174
image = image.transpose((0, 2, 3, 1))
160175

161-
# run safety checker
162176
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
163177
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
164178

0 commit comments

Comments
 (0)