fix DDPMScheduler.set_timesteps() when num_inference_steps not a divisor of num_train_timesteps #1835
Conversation
The documentation is not available anymore as the PR was closed or merged.
patil-suraj left a comment:
Thanks a lot for fixing this, the PR looks good!
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
Looks good!
Good for me as well!
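For context, here is a quick standalone check of what the two quoted lines compute. This is an illustrative sketch only; the concrete values `num_train_timesteps=1000` and `num_inference_steps=3` are assumptions chosen because 3 does not divide 1000.

```python
import numpy as np

num_train_timesteps = 1000  # assumed scheduler config value
num_inference_steps = 3     # deliberately not a divisor of 1000

# Same computation as the quoted diff, written without the scheduler object.
step_ratio = num_train_timesteps // num_inference_steps  # 333
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

print(timesteps)       # [666 333   0]
print(len(timesteps))  # 3, matching num_inference_steps
```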
num_inference_steps (`int`):
    the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
Here, I think it's actually better to throw an error if num_inference_steps > self.config.num_train_timesteps. It's usually better to educate the user instead of silently correcting their input, IMO.
More specifically, it's simply wrong to pass num_inference_steps > self.config.num_train_timesteps, since it's impossible for the model to handle it, so the user should understand this and correct the input.
Suggested change: replace

num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)

with

if num_inference_steps > self.config.num_train_timesteps:
    raise ValueError(f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps.")
Yep. I agree that it's better to throw an error. I'll add the suggested changes.
Hey @Joqsan, thanks a lot for the PR! Could we maybe change the "silent correction" to an error message :-) Apart from that, this PR looks nice!
Thanks for the suggestion! I pushed a commit adding the error message to the ddpm and ddim files. For consistency's sake, I think we should add the error message to the other schedulers too, shouldn't we? If agreed, I can take a look at it. Just not sure whether to include that in this PR or do it in a separate one...
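For illustration, this is roughly how the added guard surfaces to a user. It is a minimal sketch, assuming diffusers' `DDPMScheduler` with 1000 training timesteps and the error text suggested above; the exact wording in the merged code may differ.

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Requesting more inference steps than the model was trained with
# is now rejected with a ValueError instead of being silently clamped.
try:
    scheduler.set_timesteps(num_inference_steps=1500)
except ValueError as err:
    print(err)
```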
It seems like the git history is messed up, sadly :-/ This sometimes happens when the branch is rebased incorrectly; I usually just merge "main" into the branch to avoid it. It's not trivial to fix after the fact, so I'd recommend opening a new PR and closing this one (copy-paste the relevant files from your PR). Sorry about this!
Fix the following issue: after calling `DDPMScheduler.set_timesteps(num_inference_steps)`, the number of timesteps doesn't match `num_inference_steps` if `num_train_timesteps % num_inference_steps != 0`.

Example:
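A minimal sketch reproducing the mismatch; the concrete numbers (1000 training timesteps, 3 inference steps) are illustrative assumptions, not values taken from the original report.

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=3)  # 3 does not divide 1000

# Before this fix the schedule was built with
#   np.arange(0, 1000, 1000 // 3) -> [0, 333, 666, 999]
# i.e. 4 timesteps instead of the requested 3.
print(len(scheduler.timesteps))  # expected: 3
```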