[Scheduler design] The pragmatic approach #719
Changes from all commits: 159e15c, 74ae717, fa9667f, c06af2b, 9325ca4, 46ceb10, f004100, 7a4abb0
```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
```
```diff
@@ -102,11 +102,36 @@ def __init__(
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = self.sigmas.max()

         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
+        self.is_scale_input_called = False

+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample

     def get_lms_coefficient(self, order, t, current_order):
         """
```
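To see how these pieces fit together, here is a minimal sketch of the calling convention the new method implies. It is not the actual `StableDiffusionPipeline` code; `model` is a hypothetical denoiser standing in for a UNet:

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=50)

# Start from noise scaled to the initial sigma.
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:  # float timesteps, not plain ints
    model_input = scheduler.scale_model_input(sample, t)  # divides by (sigma**2 + 1) ** 0.5
    noise_pred = model(model_input, t)  # hypothetical denoising model
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```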
```diff
@@ -154,7 +179,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: int,
+        timestep: Union[float, torch.FloatTensor],
         sample: torch.FloatTensor,
         order: int = 4,
         return_dict: bool = True,
```
```diff
@@ -165,7 +190,7 @@ def step(
         Args:
             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`int`): current discrete timestep in the diffusion chain.
+            timestep (`float`): current timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
```
```diff
@@ -177,7 +202,21 @@ def step(
             When returning a tuple, the first element is the sample tensor.

         """
-        sigma = self.sigmas[timestep]
+        if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
+            warnings.warn(
+                f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
+                "Make sure to pass one of the `scheduler.timesteps`"
+            )
+        if not self.is_scale_input_called:
+            warnings.warn(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
```
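The boolean-mask lookup above may look unusual; a self-contained sketch with made-up schedule values shows what it does and why `.item()` demands an exact, unique match:

```python
import torch

# Hypothetical schedule of descending float timesteps.
timesteps = torch.tensor([999.0, 749.25, 499.5, 249.75, 0.0])

timestep = 499.5  # must be a value taken from the schedule itself
step_index = (timesteps == timestep).nonzero().item()  # -> 2
# `.item()` raises if the equality mask matches zero entries or more
# than one, hence the warning to pass one of `scheduler.timesteps`.
```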
```diff
@@ -189,8 +228,8 @@ def step(
             self.derivatives.pop(0)

         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(step_index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
```
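The `min(step_index + 1, order)` clamp exists because at step index `k` only `k + 1` derivatives have been recorded yet; a toy illustration of how the effective multistep order ramps up:

```python
requested_order = 4
for step_index in range(6):
    # Effective order grows with the derivative history, then saturates.
    print(step_index, min(step_index + 1, requested_order))
# 0 1 / 1 2 / 2 3 / 3 4 / 4 4 / 5 4
```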
```diff
@@ -206,12 +245,14 @@ def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         sigmas = self.sigmas.to(original_samples.device)
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
```
Member (Author), on the `step_indices` line: Really dislike what we have to do here, but unfortunately there's no good vectorized alternative to search for multiple indices and keep the order the same.

Contributor: don't think it's that bad honestly
```diff
-        sigma = sigmas[timesteps].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
```
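To make the discussion above concrete, a small sketch (made-up values) of the per-element lookup in `add_noise`, showing that the list comprehension preserves the order, and any duplicates, of the incoming batch:

```python
import torch

# Hypothetical float schedule and a batch of timesteps in arbitrary order.
schedule_timesteps = torch.tensor([999.0, 666.0, 333.0, 0.0])
timesteps = torch.tensor([333.0, 999.0, 333.0])

# One scalar lookup per batch element keeps order and duplicates intact.
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
print(step_indices)  # [2, 0, 2]
```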
Review comment, on the `scale_model_input` warning: This will pop up in existing community pipelines but won't break them like an exception would. The legacy pipelines can continue using the manual scaling code 👍