Skip to content

Commit 170af08

Browse files
samediipatrickvonplatenpatil-suraj
authored
Easily understandable error if inference steps not set before using scheduler (#263) (#264)
* Helpful exception if inference steps not set in schedulers (#263) * Apply suggestions from codereview by patrickvonplaten * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Suraj Patil <[email protected]>
1 parent 76985bc commit 170af08

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ def step(
117117
use_clipped_model_output: bool = False,
118118
generator=None,
119119
):
120+
if self.num_inference_steps is None:
121+
raise ValueError(
122+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
123+
)
124+
120125
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
121126
# Ideally, read DDIM paper in-detail understanding
122127

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ def step_prk(
145145
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
146146
solution to the differential equation.
147147
"""
148+
if self.num_inference_steps is None:
149+
raise ValueError(
150+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
151+
)
152+
148153
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
149154
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
150155
timestep = self.prk_timesteps[self.counter // 4 * 4]
@@ -179,6 +184,11 @@ def step_plms(
179184
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
180185
times to approximate the solution.
181186
"""
187+
if self.num_inference_steps is None:
188+
raise ValueError(
189+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
190+
)
191+
182192
if not self.config.skip_prk_steps and len(self.ets) < 3:
183193
raise ValueError(
184194
f"{self.__class__} can only be run AFTER scheduler has been run "

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def step_pred(
120120
self.set_seed(seed)
121121
# TODO(Patrick) non-PyTorch
122122

123+
if self.timesteps is None:
124+
raise ValueError(
125+
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
126+
)
127+
123128
timestep = timestep * torch.ones(
124129
sample.shape[0], device=sample.device
125130
) # torch.repeat_interleave(timestep, sample.shape[0])
@@ -155,6 +160,11 @@ def step_correct(
155160
if seed is not None:
156161
self.set_seed(seed)
157162

163+
if self.timesteps is None:
164+
raise ValueError(
165+
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
166+
)
167+
158168
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
159169
# sample noise for correction
160170
noise = self.randn_like(sample)

src/diffusers/schedulers/scheduling_sde_vp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def set_timesteps(self, num_inference_steps):
3535
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
3636

3737
def step_pred(self, score, x, t):
38+
if self.timesteps is None:
39+
raise ValueError(
40+
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
41+
)
42+
3843
# TODO(Patrick) better comments + non-PyTorch
3944
# postprocess model score
4045
log_mean_coeff = (

0 commit comments

Comments
 (0)