 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

 import math
+import warnings
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms steps; defaults to `False`.
+        set_alpha_to_one (`bool`, default `False`):
+            each diffusion step uses the value of the alphas product at that step and at the previous one. For the
+            final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to
+            `1`, otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
+            Stable Diffusion.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays

     """
@@ -89,8 +98,10 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        tensor_format: str = "pt",
         skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        steps_offset: int = 0,
+        tensor_format: str = "pt",
     ):
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
@@ -108,6 +119,8 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

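+        # "previous" alpha product for the very last step: fixed to 1.0, or alphas_cumprod[0] when set_alpha_to_one=False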
+        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -122,31 +135,38 @@ def __init__(
         # setable values
         self.num_inference_steps = None
         self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-        self._offset = 0
         self.prk_timesteps = None
         self.plms_timesteps = None
         self.timesteps = None

         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)

-    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

         Args:
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+
+        offset = self.config.steps_offset
+
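+        # a deprecated call-time `offset` kwarg still takes precedence over the configured value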
+        if "offset" in kwargs:
+            warnings.warn(
+                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+                " Please pass `steps_offset` to `__init__` instead."
+            )
+
+            offset = kwargs["offset"]
+
         self.num_inference_steps = num_inference_steps
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
-        self._offset = offset
-        self._timesteps = np.array([t + self._offset for t in self._timesteps])
+        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+        self._timesteps += offset
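+        # e.g. num_train_timesteps=1000, num_inference_steps=50 and steps_offset=1
+        # give step_ratio=20 and timesteps 1, 21, 41, ..., 981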

         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
@@ -231,7 +251,7 @@ def step_prk(
         )

         diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
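+        # the clamp to `prk_timesteps[-1]` is dropped; a negative prev_timestep is now handled in `_get_prev_sample`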
+        prev_timestep = timestep - diff_to_prev
         timestep = self.prk_timesteps[self.counter // 4 * 4]

         if self.counter % 4 == 0:
@@ -293,7 +313,7 @@ def step_plms(
293313 "for more information."
294314 )
295315
296- prev_timestep = max ( timestep - self .config .num_train_timesteps // self .num_inference_steps , 0 )
316+ prev_timestep = timestep - self .config .num_train_timesteps // self .num_inference_steps
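+        # can go negative on the final step; `_get_prev_sample` then falls back to `final_alpha_cumprod`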

         if self.counter != 1:
             self.ets.append(model_output)
@@ -323,7 +343,7 @@ def step_plms(

         return SchedulerOutput(prev_sample=prev_sample)

-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -336,8 +356,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

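Taken together, the patch moves the inference-step offset out of `set_timesteps` and into the scheduler's config, and replaces the `_offset` index arithmetic with an explicit `final_alpha_cumprod` fallback. Below is a minimal usage sketch, assuming the class is importable as in the `diffusers` package; the constructor values are illustrative, mirroring the Stable Diffusion settings named in the docstring:

```python
from diffusers import PNDMScheduler  # assumed import path for the scheduler shown above

# The offset is now configured once at construction time rather than per call.
scheduler = PNDMScheduler(
    num_train_timesteps=1000,
    skip_prk_steps=True,     # Stable Diffusion skips the Runge-Kutta warm-up steps
    set_alpha_to_one=False,  # the final step reuses alphas_cumprod[0] as the "previous" alpha product
    steps_offset=1,          # every inference timestep is shifted up by one, as in Stable Diffusion
)

# No `offset` argument anymore; passing one still works but emits a deprecation warning.
scheduler.set_timesteps(num_inference_steps=50)
print(scheduler.timesteps[0], scheduler.timesteps[-1])  # highest and lowest timesteps reflect the offset
```

With these values the last step's `prev_timestep` works out to `1 - 20 = -19`, so `_get_prev_sample` takes the `final_alpha_cumprod` branch, i.e. `alphas_cumprod[0]`, which is exactly the behavior the `steps_offset` docstring describes.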