Add DPT Flax #17779
Conversation
- added DPT in Flax
- all non-slow tests pass locally
- still some nits to be investigated
…e' object has no attribute 'tolist'`
- BN seems to work now
- equivalency test passes with tol=1e-4, but only with a hack
All the keys match now, but the equivalency test does not pass.
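For context, a minimal sketch of the kind of PyTorch↔Flax equivalence check being discussed — the checkpoint name and tolerance are assumptions, and `FlaxDPTModel` is the class this PR adds:

```python
import numpy as np
import torch
from transformers import DPTModel, FlaxDPTModel  # FlaxDPTModel comes from this PR

# load the same checkpoint in both frameworks ("Intel/dpt-large" is assumed here)
pt_model = DPTModel.from_pretrained("Intel/dpt-large")
fx_model = FlaxDPTModel.from_pretrained("Intel/dpt-large", from_pt=True)

pixel_values = np.random.randn(1, 3, 384, 384).astype(np.float32)
with torch.no_grad():
    pt_out = pt_model(torch.from_numpy(pixel_values)).last_hidden_state.numpy()
fx_out = np.asarray(fx_model(pixel_values).last_hidden_state)

# the tol=1e-4 mentioned in the commit message above
assert np.abs(pt_out - fx_out).max() < 1e-4
```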
```python
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
```
Missing `predicted_depth` documentation.
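For reference, the PyTorch model documents this output roughly as below, so the Flax docstring could mirror it — the `jnp.ndarray` type and wording here are an assumption based on the PyTorch version:

```python
predicted_depth (`jnp.ndarray` of shape `(batch_size, height, width)`):
    Predicted depth for each pixel.
```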
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
```python
class FlaxDPTUpsample(nn.Module):
    scale: int = 2
    method: str = "bilinear"

    def setup(self):
        pass

    def __call__(self, x, output_size=None):
        if output_size is None:
            output_size = x.shape
        output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3])
        # use the configured method rather than a hardcoded "bilinear"
        return jax.image.resize(x, output_size, method=self.method)
```
Should support `align_corners=True`.
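For what it's worth, `jax.image.resize` only implements half-pixel (`align_corners=False`) sampling, but `jax.image.scale_and_translate` can emulate `align_corners=True`. A minimal sketch, not part of this PR, assuming NHWC inputs with spatial dims greater than 1:

```python
import jax
import jax.numpy as jnp


def resize_bilinear_align_corners(x, out_h, out_w):
    # align_corners=True maps output index o to input coord o * (in - 1) / (out - 1);
    # matching that against JAX's sampling rule ((o + 0.5 - t) / s - 0.5) gives
    # the scale and translation below.
    in_h, in_w = x.shape[1], x.shape[2]
    scale = jnp.array([(out_h - 1) / (in_h - 1), (out_w - 1) / (in_w - 1)])
    translation = 0.5 * (1.0 - scale)
    return jax.image.scale_and_translate(
        x,
        shape=(x.shape[0], out_h, out_w, x.shape[3]),
        spatial_dims=(1, 2),
        scale=scale,
        translation=translation,
        method="linear",
        antialias=False,
    )
```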
Would be great to also incorporate the updates of #17731.
```python
def setup(self):
    self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
    self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
```
@NielsRogge I think that in this implementation we directly initialize the modules with the config (contrary to `DPTViTEmbeddings`), if this is what you meant?
I copied these modules from `FlaxViTModel`, which seems to have the right structure as suggested in #17731.
Ok, then it's alright :)
Oh, but the sanity check of the channel size is indeed missing, will add that!
Yeah, we don't have it for other Flax models either right now. Ideally (and also for consistency), we should have it for all vision models.
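A minimal sketch of what such a channel-size check could look like in the Flax patch embeddings — the module follows the snippet above, but the config fields, NHWC layout, and error message are assumptions:

```python
from typing import Any

import flax.linen as nn
import jax.numpy as jnp


class FlaxPatchEmbeddings(nn.Module):
    config: Any  # assumed to carry num_channels, patch_size and hidden_size
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, pixel_values):
        # sanity check: the channel dim of the inputs must match the config (NHWC assumed)
        if pixel_values.shape[-1] != self.config.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values "
                "matches the one set in the configuration."
            )
        # project non-overlapping patches to the hidden size, then flatten
        embeddings = nn.Conv(
            self.config.hidden_size,
            kernel_size=(self.config.patch_size, self.config.patch_size),
            strides=(self.config.patch_size, self.config.patch_size),
            dtype=self.dtype,
        )(pixel_values)
        batch_size, height, width, channels = embeddings.shape
        return jnp.reshape(embeddings, (batch_size, height * width, channels))
```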
The Flax model finally predicts the correct depths for the cats (left is Flax, right is PyTorch)!
As we discussed, it seems that …
@ArthurZucker exactly. I have put a new attribute in the config.
- added a new attribute to the config without breaking backward compatibility
- modified the tests a bit
sanchit-gandhi left a comment
Hey @younesbelkada! Looks pretty good from the Flax side of things! Left a few requests, the overarching one being the use of `# Copied from ...` statements, both for internal Transformers code and code copied externally (e.g. from Haiku). It really helps in knowing which portions of the modelling code are the salient ones to review! But otherwise a very strong effort on the Flax front 💪
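For anyone unfamiliar with the convention, a `# Copied from ...` statement looks roughly like this — the module pairing below is illustrative:

```python
import flax.linen as nn


# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->DPT
class FlaxDPTSelfAttention(nn.Module):
    ...
```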
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Co-authored-by: Sanchit Gandhi <[email protected]>
- removed `copied_from` on non-module objects
- check why `FlaxDPTViTLayerCollection` is not copied from `FlaxViTLayerCollection`
- added correct link for `CopiedFrom`
- added explicit argument for the transposed conv in the model definition
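On the transposed-conv point, a sketch of an explicit definition in Flax — the feature, kernel, stride, and padding values below are illustrative, not the PR's exact ones:

```python
import flax.linen as nn

# making kernel_size, strides and padding explicit avoids relying on defaults
upsample = nn.ConvTranspose(
    features=256,
    kernel_size=(4, 4),
    strides=(4, 4),
    padding="VALID",
)
```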
Thank you very much @sanchit-gandhi for the very detailed review! I did a second round of refactoring while catching up on Flax projects and would love a second round of review (I also left some unresolved comments) 💪 Thanks again 🙏
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

What does this PR do?
I tried to implement DPT (Dense Prediction with Transformers) in Flax during my free time! 🚀
By the way, it is the first segmentation and depth-estimation model implemented in Flax in the library!
Nits/TODOs:
- `BatchNorm` and `Dropout` inside a `Sequential` (see the `BatchNorm` sketch at the end of this description)
- `Sequential` layers

Questions:
- Should we also add the loss in `modeling_dpt.py`? I can probably help on that since I have already implemented the loss for a university project: https://github.com/antocad/FocusOnDepth/blob/master/FOD/Loss.py

cc @NielsRogge @sanchit-gandhi @patil-suraj
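On the `BatchNorm`/`Dropout` TODO above, a standalone sketch (not this PR's code) of how `flax.linen.BatchNorm` is typically driven: the running statistics live in a mutable `batch_stats` collection, and train/eval mode is switched via `use_running_average`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ConvBNBlock(nn.Module):
    features: int = 32

    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Conv(self.features, kernel_size=(3, 3))(x)
        # use batch statistics while training, running averages at inference
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.relu(x)


block = ConvBNBlock()
x = jnp.ones((1, 8, 8, 3))
variables = block.init(jax.random.PRNGKey(0), x)
# training step: the updated running stats are returned as mutable state
y, updated_state = block.apply(variables, x, train=True, mutable=["batch_stats"])
```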