# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14-
14+ import warnings
1515from dataclasses import dataclass
1616from typing import Optional , Tuple , Union
1717
@@ -102,11 +102,36 @@ def __init__(
102102 sigmas = np .concatenate ([sigmas [::- 1 ], [0.0 ]]).astype (np .float32 )
103103 self .sigmas = torch .from_numpy (sigmas )
104104
105+ # standard deviation of the initial noise distribution
106+ self .init_noise_sigma = self .sigmas .max ()
107+
105108 # setable values
106109 self .num_inference_steps = None
107110 timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
108111 self .timesteps = torch .from_numpy (timesteps )
109112 self .derivatives = []
113+ self .is_scale_input_called = False
114+
def scale_model_input(
    self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
    """
    Scale the denoising model input by ``(sigma**2 + 1) ** 0.5`` so it matches
    the K-LMS algorithm's expected input magnitude.

    Args:
        sample (`torch.FloatTensor`): input sample
        timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

    Returns:
        `torch.FloatTensor`: scaled input sample
    """
    # Tensor timesteps must live on the same device as the schedule before comparing.
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    # Map the continuous timestep back to its position in the schedule,
    # then look up the matching sigma.
    index = (self.timesteps == timestep).nonzero().item()
    sigma = self.sigmas[index]
    scaled = sample / ((sigma**2 + 1) ** 0.5)
    # Record that scaling happened so `step` can warn callers who skipped it.
    self.is_scale_input_called = True
    return scaled
110135
111136 def get_lms_coefficient (self , order , t , current_order ):
112137 """
@@ -154,7 +179,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
154179 def step (
155180 self ,
156181 model_output : torch .FloatTensor ,
157- timestep : int ,
182+ timestep : Union [ float , torch . FloatTensor ] ,
158183 sample : torch .FloatTensor ,
159184 order : int = 4 ,
160185 return_dict : bool = True ,
@@ -165,7 +190,7 @@ def step(
165190
166191 Args:
167192 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
168- timestep (`int `): current discrete timestep in the diffusion chain.
193+ timestep (`float `): current timestep in the diffusion chain.
169194 sample (`torch.FloatTensor`):
170195 current instance of sample being created by diffusion process.
171196 order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ def step(
177202 When returning a tuple, the first element is the sample tensor.
178203
179204 """
180- sigma = self .sigmas [timestep ]
205+ if not isinstance (timestep , float ) and not isinstance (timestep , torch .FloatTensor ):
206+ warnings .warn (
207+ f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not { type (timestep )} . "
208+ "Make sure to pass one of the `scheduler.timesteps`"
209+ )
210+ if not self .is_scale_input_called :
211+ warnings .warn (
212+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
213+ "See `StableDiffusionPipeline` for a usage example."
214+ )
215+
216+ if isinstance (timestep , torch .Tensor ):
217+ timestep = timestep .to (self .timesteps .device )
218+ step_index = (self .timesteps == timestep ).nonzero ().item ()
219+ sigma = self .sigmas [step_index ]
181220
182221 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
183222 pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ def step(
189228 self .derivatives .pop (0 )
190229
191230 # 3. Compute linear multistep coefficients
192- order = min (timestep + 1 , order )
193- lms_coeffs = [self .get_lms_coefficient (order , timestep , curr_order ) for curr_order in range (order )]
231+ order = min (step_index + 1 , order )
232+ lms_coeffs = [self .get_lms_coefficient (order , step_index , curr_order ) for curr_order in range (order )]
194233
195234 # 4. Compute previous sample based on the derivatives path
196235 prev_sample = sample + sum (
@@ -206,12 +245,14 @@ def add_noise(
206245 self ,
207246 original_samples : torch .FloatTensor ,
208247 noise : torch .FloatTensor ,
209- timesteps : torch .IntTensor ,
248+ timesteps : torch .FloatTensor ,
210249 ) -> torch .FloatTensor :
211250 sigmas = self .sigmas .to (original_samples .device )
251+ schedule_timesteps = self .timesteps .to (original_samples .device )
212252 timesteps = timesteps .to (original_samples .device )
253+ step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
213254
214- sigma = sigmas [timesteps ].flatten ()
255+ sigma = sigmas [step_indices ].flatten ()
215256 while len (sigma .shape ) < len (original_samples .shape ):
216257 sigma = sigma .unsqueeze (- 1 )
217258
0 commit comments