Skip to content

Commit 3dacbb9

Browse files
trained_betas ignored in some schedulers (#635)
* correcting the beta value assignment * updating DDIM and LMSDiscreteFlax schedulers * bringing back the changes that were lost as part of main branch merge
1 parent f10576a commit 3dacbb9

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131

132132
if trained_betas is not None:
133133
self.betas = torch.from_numpy(trained_betas)
134-
if beta_schedule == "linear":
134+
elif beta_schedule == "linear":
135135
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
136136
elif beta_schedule == "scaled_linear":
137137
# this schedule is very specific to the latent diffusion model.

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686

8787
if trained_betas is not None:
8888
self.betas = torch.from_numpy(trained_betas)
89-
if beta_schedule == "linear":
89+
elif beta_schedule == "linear":
9090
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
9191
elif beta_schedule == "scaled_linear":
9292
# this schedule is very specific to the latent diffusion model.

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
):
7575
if trained_betas is not None:
7676
self.betas = jnp.asarray(trained_betas)
77-
if beta_schedule == "linear":
77+
elif beta_schedule == "linear":
7878
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
7979
elif beta_schedule == "scaled_linear":
8080
# this schedule is very specific to the latent diffusion model.

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111

112112
if trained_betas is not None:
113113
self.betas = torch.from_numpy(trained_betas)
114-
if beta_schedule == "linear":
114+
elif beta_schedule == "linear":
115115
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
116116
elif beta_schedule == "scaled_linear":
117117
# this schedule is very specific to the latent diffusion model.

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
):
133133
if trained_betas is not None:
134134
self.betas = jnp.asarray(trained_betas)
135-
if beta_schedule == "linear":
135+
elif beta_schedule == "linear":
136136
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
137137
elif beta_schedule == "scaled_linear":
138138
# this schedule is very specific to the latent diffusion model.

0 commit comments

Comments
 (0)