Improve vision models #17731
Conversation
The documentation is not available anymore as the PR was closed or merged.
(force-pushed from dcd728c to 33b720a)
sgugger left a comment
Nice cleanup! Thanks for working on it!
Review thread on src/transformers/models/data2vec/modeling_tf_data2vec_vision.py (outdated, resolved)
amyeroberts left a comment
Nice! Thanks for making all these changes 🧹🧹🧹
Just some small comments about tests, but otherwise LGTM :)
Commit message from a pull request that referenced this PR:

* Initial TF DeiT implementation
* Fix copies naming issues
* Fix up + docs
* Properly name main layer
* Name layers properly
* Initial TF DeiT implementation
* Fix copies naming issues
* Fix up + docs
* Properly name main layer
* Name layers properly
* Fixup
* Fix import
* Fix import
* Fix import
* Fix weight loading for tests whilst not on hub
* Add doc tests and remove to_2tuple
* Add back to_2tuple: removing to_2tuple results in many downstream changes needed because of the copies checks
* Incorporate updates in Improve vision models #17731 PR
* Don't hard code num_channels
* Copy PyTorch DeiT embeddings and remove pytorch operations with mask
* Fix patch embeddings & tidy up
* Update PixelShuffle to move logic into class layer
* Update doc strings - remove PT references
* Use NHWC format in internal layers
* Fix up
* Use linear activation layer
* Remove unused import
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>

* Move dataclass to top of file
* Remove from_pt now weights on hub
* Fixup

Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Amy Roberts <[email protected]>
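For context on the to_2tuple helper these commit messages mention (removed, then added back because of the repo's copies checks), here is a minimal sketch of what such a helper typically looks like in vision model code; this is an assumed shape, not necessarily the exact transformers implementation:

```python
# Sketch of a to_2tuple-style helper (assumed, not copied from transformers):
# it normalizes a scalar size such as image_size=224 into the pair (224, 224),
# while leaving an explicit (height, width) pair untouched, so downstream
# patch-embedding code can treat square and rectangular inputs uniformly.
import collections.abc


def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return (x, x)
```

Removing such a helper means inlining the tuple handling at every call site, which is why the copies checks forced it back in.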
What does this PR do?
This PR improves the vision models by:
- removing the to_2tuple utility
- relying on config.num_channels instead of a hardcoded channel count
- using config.num_channels for xxxForMaskedImageModeling models (fixes #17727, "SimMIM output num_channels should not be hardcoded")
- using config.num_channels in Flax models (ViT, BEiT)

To do:
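As an illustration of the num_channels changes listed above, here is a hedged sketch (class and attribute names are hypothetical, not the actual transformers code) of the pattern: the patch projection reads its input channel count from config.num_channels instead of hardcoding 3, so a model can be trained on, say, single-channel images:

```python
# Hypothetical sketch of the config.num_channels pattern described in this PR;
# PatchEmbeddings here is illustrative, not the real transformers class.
import torch
from torch import nn


class PatchEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        patch_size = (config.patch_size, config.patch_size)
        # Before this kind of change: nn.Conv2d(3, config.hidden_size, ...)
        self.projection = nn.Conv2d(
            config.num_channels,  # read from the config, not hardcoded to 3
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # pixel_values: (batch_size, num_channels, height, width)
        embeddings = self.projection(pixel_values)    # (b, hidden, h/p, w/p)
        return embeddings.flatten(2).transpose(1, 2)  # (b, num_patches, hidden)
```

The same idea applies to the xxxForMaskedImageModeling fix from #17727: the decoder that reconstructs pixels should emit config.num_channels output channels rather than a hardcoded 3.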