@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
     `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
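
The copied helper implements Algorithm 1 of the paper: it shifts the square-rooted cumulative product of alphas so that its last entry becomes exactly zero, rescales so the first entry keeps its old value, and converts back to betas. A minimal sketch of the effect (not part of the patch), assuming the scaled-linear schedule used by Stable Diffusion (beta_start=0.00085, beta_end=0.012, 1000 steps) and importing the existing DDIM helper that this function is copied from:

import torch
from diffusers.schedulers.scheduling_ddim import rescale_zero_terminal_snr

# Scaled-linear beta schedule (illustrative values, matching Stable Diffusion defaults)
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float32) ** 2

original_terminal = torch.cumprod(1.0 - betas, dim=0)[-1]
rescaled_terminal = torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)[-1]

print(original_terminal)  # small but nonzero (~5e-3): the "fully noised" latent still leaks signal
print(rescaled_terminal)  # tensor(0.): zero terminal SNR, the last timestep is pure noise
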
@@ -144,6 +181,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
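
Since the new flag is registered like any other scheduler argument, it can be switched on through `from_config`. A minimal usage sketch, assuming an SD-style checkpoint; the paper pairs zero-SNR rescaling with `timestep_spacing="trailing"` (and recommends v-prediction models, since epsilon-prediction is ill-defined at the zero-SNR timestep):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")  # illustrative checkpoint
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,   # option added in this diff
    timestep_spacing="trailing",   # recommended alongside zero terminal SNR in the paper
)
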
@@ -173,6 +214,7 @@ def __init__(
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +233,17 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         # Currently we only support VP-type noise schedule
         self.alpha_t = torch.sqrt(self.alphas_cumprod)
         self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
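
With a truly zero terminal alpha-bar, the log-SNR `lambda_t = log(alpha_t) - log(sigma_t)` and the `sqrt((1 - alpha_bar) / alpha_bar)` sigma conversion used by this scheduler both blow up at the final training timestep (the first inference step), hence the clamp to `2**-24`, the smallest positive fp16 subnormal. A small numerical sketch of what the clamp buys:

import torch

terminal = torch.tensor(2.0**-24)            # 5.96e-08, smallest positive fp16 subnormal
alpha_t = terminal.sqrt()                    # ~2.44e-04
sigma_t = (1.0 - terminal).sqrt()            # ~1.0
print(alpha_t.log() - sigma_t.log())         # lambda_t ~ -8.3, finite
print(((1.0 - terminal) / terminal).sqrt())  # sigma ~ 4096, large but finite (inf if terminal were exactly 0)
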
@@ -895,9 +946,12 @@ def step(
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output

+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
         if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
             noise = randn_tensor(
-                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
             )
         else:
             noise = None
@@ -912,6 +966,9 @@ def step(
         if self.lower_order_nums < self.config.solver_order:
             self.lower_order_nums += 1

+        # Cast sample back to expected dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
         # upon completion increase step index by one
         self._step_index += 1

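
The `step` changes follow a common mixed-precision pattern: upcast the sample (and the SDE noise) to float32 for the multistep update, then cast the result back to the model's dtype before returning it. A minimal sketch of that pattern, with a placeholder update standing in for the real DPM-Solver formulas:

import torch

def solver_step(sample: torch.Tensor, model_output: torch.Tensor) -> torch.Tensor:
    # Upcast to avoid fp16 precision loss during the update
    sample_fp32 = sample.to(torch.float32)
    output_fp32 = model_output.to(torch.float32)
    prev_sample = sample_fp32 - 0.1 * output_fp32  # placeholder update, not the real solver math
    # Cast back to the dtype the pipeline expects
    return prev_sample.to(model_output.dtype)

sample = torch.randn(1, 4, 64, 64, dtype=torch.float16)
model_output = torch.randn_like(sample)
print(solver_step(sample, model_output).dtype)  # torch.float16
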