@@ -202,11 +202,6 @@ def step(
202202 When returning a tuple, the first element is the sample tensor.
203203
204204 """
205- if not isinstance (timestep , float ) and not isinstance (timestep , torch .FloatTensor ):
206- warnings .warn (
207- f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not { type (timestep )} . "
208- "Make sure to pass one of the `scheduler.timesteps`"
209- )
210205 if not self .is_scale_input_called :
211206 warnings .warn (
212207 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
@@ -215,7 +210,18 @@ def step(
215210
216211 if isinstance (timestep , torch .Tensor ):
217212 timestep = timestep .to (self .timesteps .device )
218- step_index = (self .timesteps == timestep ).nonzero ().item ()
213+ if (
214+ isinstance (timestep , int )
215+ or isinstance (timestep , torch .IntTensor )
216+ or isinstance (timestep , torch .LongTensor )
217+ ):
218+ warnings .warn (
219+ "Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
220+ " 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
221+ )
222+ step_index = timestep
223+ else :
224+ step_index = (self .timesteps == timestep ).nonzero ().item ()
219225 sigma = self .sigmas [step_index ]
220226
221227 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -250,7 +256,14 @@ def add_noise(
250256 sigmas = self .sigmas .to (original_samples .device )
251257 schedule_timesteps = self .timesteps .to (original_samples .device )
252258 timesteps = timesteps .to (original_samples .device )
253- step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
259+ if isinstance (timesteps , torch .IntTensor ) or isinstance (timesteps , torch .LongTensor ):
260+ warnings .warn (
261+ "Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
262+ " version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
263+ )
264+ step_indices = timesteps
265+ else :
266+ step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
254267
255268 sigma = sigmas [step_indices ].flatten ()
256269 while len (sigma .shape ) < len (original_samples .shape ):
0 commit comments