@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
9292 return torch .tensor (betas , dtype = torch .float32 )
9393
9494
95+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its final timestep has zero signal-to-noise ratio.

    Based on Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt(alpha_bar), where alpha_bar is the running product of (1 - beta).
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve. No in-place arithmetic follows, so these
    # stay valid without an explicit copy.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift the curve down so the last timestep lands exactly on zero ...
    shifted = sqrt_alpha_bar - last
    # ... then stretch it so the first timestep returns to its original value.
    rescaled = shifted * (first / (first - last))

    # sqrt(alpha_bar) -> alpha_bar -> per-step alphas -> betas.
    alpha_bar = rescaled**2
    # Undo the cumulative product: each alpha is the ratio of consecutive
    # alpha_bar entries; the first alpha equals alpha_bar[0] itself.
    step_ratios = alpha_bar[1:] / alpha_bar[:-1]
    step_alphas = torch.cat([alpha_bar[0:1], step_ratios])

    return 1 - step_alphas
130+
131+
95132class EulerDiscreteScheduler (SchedulerMixin , ConfigMixin ):
96133 """
97134 Euler scheduler.
@@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
128165 An offset added to the inference steps. You can use a combination of `offset=1` and
129166 `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
130167 Diffusion.
168+ rescale_betas_zero_snr (`bool`, defaults to `False`):
169+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
170+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
171+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
131172 """
132173
133174 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -149,6 +190,7 @@ def __init__(
149190 timestep_spacing : str = "linspace" ,
150191 timestep_type : str = "discrete" , # can be "discrete" or "continuous"
151192 steps_offset : int = 0 ,
193+ rescale_betas_zero_snr : bool = False ,
152194 ):
153195 if trained_betas is not None :
154196 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -163,9 +205,17 @@ def __init__(
163205 else :
164206 raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
165207
208+ if rescale_betas_zero_snr :
209+ self .betas = rescale_zero_terminal_snr (self .betas )
210+
166211 self .alphas = 1.0 - self .betas
167212 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
168213
214+ if rescale_betas_zero_snr :
215+ # Close to 0 without being 0 so first sigma is not inf
216+ # FP16 smallest positive subnormal works well here
217+ self .alphas_cumprod [- 1 ] = 2 ** - 24
218+
169219 sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
170220 timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
171221
@@ -420,6 +470,9 @@ def step(
420470 if self .step_index is None :
421471 self ._init_step_index (timestep )
422472
473+ # Upcast to avoid precision issues when computing prev_sample
474+ sample = sample .to (torch .float32 )
475+
423476 sigma = self .sigmas [self .step_index ]
424477
425478 gamma = min (s_churn / (len (self .sigmas ) - 1 ), 2 ** 0.5 - 1 ) if s_tmin <= sigma <= s_tmax else 0.0
@@ -456,6 +509,9 @@ def step(
456509
457510 prev_sample = sample + derivative * dt
458511
512+ # Cast sample back to model compatible dtype
513+ prev_sample = prev_sample .to (model_output .dtype )
514+
459515 # upon completion increase step index by one
460516 self ._step_index += 1
461517
0 commit comments