Skip to content

Commit f049e70

Browse files
committed
change code style
1 parent 845a7d3 commit f049e70

File tree

4 files changed

+129
-117
lines changed

4 files changed

+129
-117
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
1616
from ...pipeline_flax_utils import FlaxDiffusionPipeline
17-
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxDPMSolverDiscreteScheduler
17+
from ...schedulers import (
18+
FlaxDDIMScheduler,
19+
FlaxDPMSolverDiscreteScheduler,
20+
FlaxLMSDiscreteScheduler,
21+
FlaxPNDMScheduler,
22+
)
1823
from ...utils import logging
1924
from . import FlaxStableDiffusionPipelineOutput
2025
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4348
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
4449
scheduler ([`SchedulerMixin`]):
4550
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
46-
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverDiscreteScheduler`].
51+
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
52+
[`FlaxDPMSolverDiscreteScheduler`].
4753
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
4854
Classification module that estimates whether generated images could be considered offensive or harmful.
4955
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -57,7 +63,9 @@ def __init__(
5763
text_encoder: FlaxCLIPTextModel,
5864
tokenizer: CLIPTokenizer,
5965
unet: FlaxUNet2DConditionModel,
60-
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler],
66+
scheduler: Union[
67+
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler
68+
],
6169
safety_checker: FlaxStableDiffusionSafetyChecker,
6270
feature_extractor: CLIPFeatureExtractor,
6371
dtype: jnp.dtype = jnp.float32,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from ...pipeline_utils import DiffusionPipeline
1212
from ...schedulers import (
1313
DDIMScheduler,
14+
DPMSolverDiscreteScheduler,
1415
EulerAncestralDiscreteScheduler,
1516
EulerDiscreteScheduler,
1617
LMSDiscreteScheduler,
1718
PNDMScheduler,
18-
DPMSolverDiscreteScheduler,
1919
)
2020
from ...utils import deprecate, logging
2121
from . import StableDiffusionPipelineOutput
@@ -60,7 +60,11 @@ def __init__(
6060
tokenizer: CLIPTokenizer,
6161
unet: UNet2DConditionModel,
6262
scheduler: Union[
63-
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
63+
DDIMScheduler,
64+
PNDMScheduler,
65+
LMSDiscreteScheduler,
66+
EulerDiscreteScheduler,
67+
EulerAncestralDiscreteScheduler,
6468
DPMSolverDiscreteScheduler,
6569
],
6670
safety_checker: StableDiffusionSafetyChecker,

src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
1616

1717
import math
18-
from typing import Optional, Tuple, Union, List
18+
from typing import List, Optional, Tuple, Union
1919

2020
import numpy as np
2121
import torch
@@ -151,7 +151,9 @@ def __init__(
151151
self.num_inference_steps = None
152152
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
153153
self.timesteps = torch.from_numpy(timesteps)
154-
self.model_outputs = [None,] * self.solver_order
154+
self.model_outputs = [
155+
None,
156+
] * self.solver_order
155157
self.lower_order_nums = 0
156158

157159
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -165,16 +167,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
165167
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
166168
"""
167169
self.num_inference_steps = num_inference_steps
168-
timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
170+
timesteps = (
171+
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
172+
.round()[::-1][:-1]
173+
.copy()
174+
.astype(np.int64)
175+
)
169176
self.timesteps = torch.from_numpy(timesteps).to(device)
170-
self.model_outputs = [None,] * self.solver_order
177+
self.model_outputs = [
178+
None,
179+
] * self.solver_order
171180
self.lower_order_nums = 0
172181

173182
def convert_model_output(
174-
self,
175-
model_output: torch.FloatTensor,
176-
timestep: int,
177-
sample: torch.FloatTensor
183+
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
178184
) -> torch.FloatTensor:
179185
"""
180186
TODO
@@ -184,9 +190,11 @@ def convert_model_output(
184190
x0_pred = (sample - sigma_t * model_output) / alpha_t
185191
if self.thresholding:
186192
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
187-
p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487).
193+
p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487).
188194
s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1)
189-
s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[(...,) + (None,)*(x0_pred.ndim - 1)]
195+
s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[
196+
(...,) + (None,) * (x0_pred.ndim - 1)
197+
]
190198
x0_pred = torch.clamp(x0_pred, -s, s) / s
191199
return x0_pred
192200
else:
@@ -207,15 +215,9 @@ def dpm_solver_first_order_update(
207215
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
208216
h = lambda_t - lambda_s
209217
if self.predict_x0:
210-
x_t = (
211-
(sigma_t / sigma_s) * sample
212-
- (alpha_t * (torch.exp(-h) - 1.)) * model_output
213-
)
218+
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
214219
else:
215-
x_t = (
216-
(alpha_t / alpha_s) * sample
217-
- (sigma_t * (torch.exp(h) - 1.)) * model_output
218-
)
220+
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
219221
return x_t
220222

221223
def multistep_dpm_solver_second_order_update(
@@ -235,32 +237,32 @@ def multistep_dpm_solver_second_order_update(
235237
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
236238
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
237239
r0 = h_0 / h
238-
D0, D1 = m0, (1. / r0) * (m0 - m1)
240+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
239241
if self.predict_x0:
240-
if self.solver_type == 'dpm_solver':
242+
if self.solver_type == "dpm_solver":
241243
x_t = (
242244
(sigma_t / sigma_s0) * sample
243-
- (alpha_t * (torch.exp(-h) - 1.)) * D0
244-
- 0.5 * (alpha_t * (torch.exp(-h) - 1.)) * D1
245+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
246+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
245247
)
246-
elif self.solver_type == 'taylor':
248+
elif self.solver_type == "taylor":
247249
x_t = (
248250
(sigma_t / sigma_s0) * sample
249-
- (alpha_t * (torch.exp(-h) - 1.)) * D0
250-
+ (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1
251+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
252+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
251253
)
252254
else:
253-
if self.solver_type == 'dpm_solver':
255+
if self.solver_type == "dpm_solver":
254256
x_t = (
255257
(alpha_t / alpha_s0) * sample
256-
- (sigma_t * (torch.exp(h) - 1.)) * D0
257-
- 0.5 * (sigma_t * (torch.exp(h) - 1.)) * D1
258+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
259+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
258260
)
259-
elif self.solver_type == 'taylor':
261+
elif self.solver_type == "taylor":
260262
x_t = (
261263
(alpha_t / alpha_s0) * sample
262-
- (sigma_t * (torch.exp(h) - 1.)) * D0
263-
- (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1
264+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
265+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
264266
)
265267
return x_t
266268

@@ -276,28 +278,33 @@ def multistep_dpm_solver_third_order_update(
276278
"""
277279
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
278280
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
279-
lambda_t, lambda_s0, lambda_s1, lambda_s2 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]
281+
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
282+
self.lambda_t[t],
283+
self.lambda_t[s0],
284+
self.lambda_t[s1],
285+
self.lambda_t[s2],
286+
)
280287
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
281288
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
282289
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
283290
r0, r1 = h_0 / h, h_1 / h
284291
D0 = m0
285-
D1_0, D1_1 = (1. / r0) * (m0 - m1), (1. / r1) * (m1 - m2)
292+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
286293
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
287-
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
294+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
288295
if self.predict_x0:
289296
x_t = (
290297
(sigma_t / sigma_s0) * sample
291-
- (alpha_t * (torch.exp(-h) - 1.)) * D0
292-
+ (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1
293-
- (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5)) * D2
298+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
299+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
300+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
294301
)
295302
else:
296303
x_t = (
297304
(alpha_t / alpha_s0) * sample
298-
- (sigma_t * (torch.exp(h) - 1.)) * D0
299-
- (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1
300-
- (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5)) * D2
305+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
306+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
307+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
301308
)
302309
return x_t
303310

@@ -336,7 +343,7 @@ def step(
336343
denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final
337344
denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final
338345

339-
model_output = self.convert_model_output(model_output, timestep, sample)
346+
model_output = self.convert_model_output(model_output, timestep, sample)
340347
for i in range(self.solver_order - 1):
341348
self.model_outputs[i] = self.model_outputs[i + 1]
342349
self.model_outputs[-1] = model_output
@@ -345,10 +352,14 @@ def step(
345352
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
346353
elif self.solver_order == 2 or self.lower_order_nums < 2 or denoise_second:
347354
timestep_list = [self.timesteps[step_index - 1], timestep]
348-
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, timestep_list, prev_timestep, sample)
355+
prev_sample = self.multistep_dpm_solver_second_order_update(
356+
self.model_outputs, timestep_list, prev_timestep, sample
357+
)
349358
else:
350359
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
351-
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, timestep_list, prev_timestep, sample)
360+
prev_sample = self.multistep_dpm_solver_third_order_update(
361+
self.model_outputs, timestep_list, prev_timestep, sample
362+
)
352363

353364
if self.lower_order_nums < self.solver_order:
354365
self.lower_order_nums += 1

0 commit comments

Comments
 (0)