[Scheduler] Move predict epsilon to init #1155
The first file touched is the DDPM pipeline (`DDPMPipeline`):

@@ -18,7 +18,9 @@
 import torch

+from ...configuration_utils import FrozenDict
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDPMPipeline(DiffusionPipeline):
@@ -45,7 +47,6 @@ def __call__(
         num_inference_steps: int = 1000,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        predict_epsilon: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
@@ -69,6 +70,16 @@
             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
             generated images.
         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+
+        if predict_epsilon is not None:
+            new_config = dict(self.scheduler.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self.scheduler._internal_dict = FrozenDict(new_config)

         # Sample gaussian noise to begin loop
         image = torch.randn(

Contributor (author) comment on the config-override lines:
@pcuenca note that now you should change this into: Here and everywhere else
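For context, here is a minimal before/after sketch of the migration this shim enables. It is illustrative only: the model id, the `scheduler=` override passed to `from_pretrained`, and the `.images` access are assumptions about the surrounding library API, not part of this diff.

```python
from diffusers import DDPMPipeline, DDPMScheduler

model_id = "google/ddpm-cifar10-32"  # illustrative checkpoint

# Deprecated path: passing the flag per call still works until 0.10.0,
# but only triggers the shim above, which rewrites the scheduler's config.
pipe = DDPMPipeline.from_pretrained(model_id)
image = pipe(predict_epsilon=True).images[0]

# Preferred path: declare the prediction mode once, when the scheduler is built,
# exactly as the deprecation message suggests.
scheduler = DDPMScheduler.from_config(model_id, predict_epsilon=True)
pipe = DDPMPipeline.from_pretrained(model_id, scheduler=scheduler)
image = pipe().images[0]
```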
The second file touched is the DDPM scheduler (`DDPMScheduler`):

@@ -21,8 +21,8 @@
 import numpy as np
 import torch

-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.

     """
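To make the new config flag concrete, here is a small sketch of what the two modes imply on the training side; the helper below is purely illustrative and not part of this PR or of diffusers.

```python
import torch


def training_target(clean_sample: torch.Tensor, noise: torch.Tensor, predict_epsilon: bool) -> torch.Tensor:
    # Illustrative: the regression target a denoising model would be trained against.
    if predict_epsilon:
        return noise          # model learns to predict the added noise (epsilon)
    return clean_sample       # model learns to predict the sample (x_0) directly
```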
@@ -121,6 +123,7 @@ def __init__(
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
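A quick sketch of the new constructor argument: because `__init__` is registered to the config, the flag becomes part of the scheduler's config object (argument values below are illustrative defaults, not from this diff).

```python
from diffusers import DDPMScheduler

# The prediction mode now lives in the scheduler config rather than in step().
eps_scheduler = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=True)
x0_scheduler = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=False)

assert eps_scheduler.config.predict_epsilon is True
assert x0_scheduler.config.predict_epsilon is False
```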
@@ -221,9 +224,9 @@ def step(
         model_output: torch.FloatTensor,
         timestep: int,
         sample: torch.FloatTensor,
-        predict_epsilon=True,
         generator=None,
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -234,8 +237,6 @@ def step(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -245,6 +246,16 @@
             returning a tuple, the first element is the sample tensor.

         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
         t = timestep

         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
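The pattern above, pulling a legacy keyword out of `**kwargs`, warning, and folding the value back into the frozen config, can be illustrated standalone. The helper below is a simplified stand-in for `diffusers.utils.deprecate`, not its actual implementation.

```python
import warnings
from typing import Any, Dict, Optional


def pop_deprecated(name: str, removed_in: str, message: str, kwargs: Dict[str, Any]) -> Optional[Any]:
    """Simplified stand-in: warn and return the legacy value if the caller still passes it."""
    if name not in kwargs:
        return None
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {removed_in}. {message}",
        FutureWarning,
    )
    return kwargs.pop(name)


# Mirrors the shim in `step()`: a caller using the old keyword gets a warning, and the
# value is written back into the scheduler's config, so the rest of the code only ever
# consults `self.config.predict_epsilon`.
legacy_kwargs = {"predict_epsilon": False}
value = pop_deprecated("predict_epsilon", "0.10.0", "Set it on the scheduler config instead.", legacy_kwargs)
print(value, legacy_kwargs)  # False {}
```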
@@ -260,7 +271,7 @@

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         else:
             pred_original_sample = model_output
Review comment: Let's try to align naming all over the codebase.
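As a standalone restatement of the `predict_epsilon` branch in `step()` above (function and variable names are illustrative; the arithmetic follows formula (15) of the DDPM paper as used in the diff):

```python
import torch


def predicted_original_sample(
    sample: torch.Tensor,
    model_output: torch.Tensor,
    alpha_prod_t: torch.Tensor,
    predict_epsilon: bool,
) -> torch.Tensor:
    """Recover "predicted x_0" from the model output under either parameterization."""
    beta_prod_t = 1 - alpha_prod_t
    if predict_epsilon:
        # model_output is epsilon: invert x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        return (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    # model_output is already an estimate of x_0
    return model_output


# Tiny numeric check with a known noising step.
x0 = torch.randn(4)
eps = torch.randn(4)
alpha_bar = torch.tensor(0.9)
x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
assert torch.allclose(predicted_original_sample(x_t, eps, alpha_bar, True), x0, atol=1e-5)
```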