-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples #334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
|
Checked that all slow tests run as expected ✔️ - some tests are failing but they are currently also failing on master |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really comprehensive work, very impressive! We should merge it soon and deal with the breaking changes as it permeates the whole codebase.
We also need to adapt the training examples when we merge.
src/diffusers/models/vae.py
Outdated
| `DiagonalGaussianDistribution` allows for sampling latents from the distribution. | ||
| """ | ||
|
|
||
| latent_dist: torch.FloatTensor = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be of type DiagonalGaussianDistribution as the comment says.
The problem is that DiagonalGaussianDistribution is defined later in the file. Solutions:
- Move this down. It would be out of place with respect to the other output dataclasses.
- Declare
class DiagonalGaussianDistribution: passbefore, and comment it's a "forward" declaration. - Use
objectas the type, with a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using the string version instead as suggested here: https://peps.python.org/pep-0484/#forward-references
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, nice!
patil-suraj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for working on this big PR! In general everything looks good, just left some nit.
My only big comment, same as Pedro's is, should we rename ModelOutput to something else, as this is used for pipelines, and schedulers as well.
src/diffusers/models/vae.py
Outdated
| `DiagonalGaussianDistribution` allows for sampling latents from the distribution. | ||
| """ | ||
|
|
||
| latent_dist: torch.FloatTensor = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
…return tuples (#334) * add outputs for models * add for pipelines * finish schedulers * better naming * adapt tests as well * replace dict access with . access * make schedulers works * finish * correct readme * make bcp compatible * up * small fix * finish * more fixes * more fixes * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/vae.py Co-authored-by: Pedro Cuenca <[email protected]> * Adapt model outputs * Apply more suggestions * finish examples * correct Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
…return tuples (huggingface#334) * add outputs for models * add for pipelines * finish schedulers * better naming * adapt tests as well * replace dict access with . access * make schedulers works * finish * correct readme * make bcp compatible * up * small fix * finish * more fixes * more fixes * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/vae.py Co-authored-by: Pedro Cuenca <[email protected]> * Adapt model outputs * Apply more suggestions * finish examples * correct Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
Model Outputs
This PR allows the user to switch between tuple and Dict/Dataclass outputs for
schedulers,modelsandpipelines.Tuples are now returned whenever the
return_dictflag is set toFalseof the respectivestep,forwardand__call__functions.New recommended API:
Having merged this PR, users are recommended to use schedulers as following:
instead of
,models as following:
instead of
, and pipelines:
instead of
Tests
Extensive comparison tests are added for schedulers, models, and fast pipeline tests
🚨 Deprecation Warnings 🚨
In this PR we're deprecating the
"sample"output keyword for pipelines. Pipelines are modality specific and therefore it is more intuitive to callpipe(...).images[0]🚨🚨🚨 Breaking changes 🚨🚨🚨
This PR introduces breaking changes for the following public-facing methods:
VQModel.encode-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changelatents = model.encode(...)tolatents = model.encode(...)[0]orlatents = model.encode(...).latensVQModel.decode-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changesample = model.decode(...)tosample = model.decode(...)[0]orsample = model.decode(...).sampleVQModel.forward-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changesample = model(...)tosample = model(...)[0]orsample = model(...).sampleAutoencoderKL.encode-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changelatent_dist = model.encode(...)tolatent_dist = model.encode(...)[0]orlatent_dist = model.encode(...).latent_distAutoencoderKL.decode-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changesample = model.decode(...)tosample = model.decode(...)[0]orsample = model.decode(...).sampleAutoencoderKL.forward-> we return a dict/dataclass instead of a single tensor. In the future it's very likely required to return more than just one tensor. Please make sure to changesample = model(...)tosample = model(...)[0]orsample = model(...).sampleTODOS once PR is merged