
@younesbelkada (Contributor) commented Jun 20, 2022

What does this PR do?

I tried to implement DPT (Dense Prediction Transformer) in Flax during my free time! 🚀
By the way, it is the first segmentation and depth-estimation model implemented in Flax in the library!

Nits/TODOs:

  • Figure out how to properly call BatchNorm and Dropout inside a Sequential (see the sketch below this list)
  • Deal correctly with Sequential layers
  • Run the equivalency tests
  • Write documentation - for now the docstrings are just copy/pasted

Questions:

  • Why is the loss not implemented in modeling_dpt.py? I can probably help with that, since I have already implemented the loss for a university project: https://github.com/antocad/FocusOnDepth/blob/master/FOD/Loss.py (a sketch of a common depth loss also follows below)
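
On the BatchNorm/Dropout point, a minimal sketch of one way to thread the required flags through a hand-rolled sequential container in Flax (the container class and its name are ours, not the PR's code):

```python
from typing import Callable, Sequence

import flax.linen as nn


class SequentialWithFlags(nn.Module):
    # Flax's nn.Sequential cannot forward per-layer keyword arguments,
    # so we dispatch on the layer type ourselves.
    layers: Sequence[Callable]

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        for layer in self.layers:
            if isinstance(layer, nn.BatchNorm):
                # Freeze the running statistics at inference time; during
                # training, apply() must be called with mutable=["batch_stats"].
                x = layer(x, use_running_average=deterministic)
            elif isinstance(layer, nn.Dropout):
                # Dropout is a no-op when deterministic=True; during training,
                # apply() needs rngs={"dropout": key}.
                x = layer(x, deterministic=deterministic)
            else:
                x = layer(x)
        return x
```

On the loss question, a minimal sketch of the scale-invariant log loss (Eigen et al., 2014), a standard choice for monocular depth estimation; it is not necessarily identical to the loss in the linked repository:

```python
import jax.numpy as jnp


def silog_loss(pred_depth, true_depth, lam=0.5, eps=1e-6):
    # Scale-invariant log loss: penalises per-pixel log-depth errors while
    # discounting a global scale shift between prediction and ground truth.
    d = jnp.log(pred_depth + eps) - jnp.log(true_depth + eps)
    return jnp.mean(d**2) - lam * jnp.mean(d) ** 2
```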

cc @NielsRogge @sanchit-gandhi @patil-suraj

younesbelkada and others added 5 commits June 18, 2022 22:34
- added DPT in Flax
- all non-slow tests pass locally
- some nits still have to be investigated
- BN seems to work now
- equivalency test passes with tol=1e-4, but only with a hack
@younesbelkada (Contributor Author)

All the keys match now, but the equivalency test only passes at a tolerance of 1e-4 rather than 1e-5.

Comment on lines 693 to 694
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Collaborator:

Missing predicted depth documentation

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada younesbelkada marked this pull request as ready for review June 20, 2022 10:54
@younesbelkada younesbelkada changed the title from "Add Flax implementation of DPT" to "[WIP] Add Flax implementation of DPT" Jun 20, 2022
Comment on lines 860 to 871
class FlaxDPTUpsample(nn.Module):
    scale: int = 2
    method: str = "bilinear"

    def setup(self):
        # No parameters to create: the module only resizes its input.
        pass

    def __call__(self, x, output_size=None):
        # x is NHWC; scale the two spatial dimensions, keep batch and channels.
        if output_size is None:
            output_size = x.shape
        output_size = (output_size[0], output_size[1] * self.scale, output_size[2] * self.scale, output_size[3])
        return jax.image.resize(x, output_size, method=self.method)
Collaborator:
Should support align_corners = True
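
For context, jax.image.resize has no align_corners flag; its half-pixel sampling corresponds to PyTorch's align_corners=False. A minimal sketch of an align_corners=True bilinear resize for a single 2-D channel, using jax.scipy.ndimage.map_coordinates (the function name is ours; vmap over batch/channel axes as needed):

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def resize_bilinear_align_corners(x, out_h, out_w):
    # With align_corners=True the sampling grid runs exactly from the
    # first to the last pixel centre of the input.
    h, w = x.shape
    rows = jnp.linspace(0.0, h - 1.0, out_h)
    cols = jnp.linspace(0.0, w - 1.0, out_w)
    grid = jnp.meshgrid(rows, cols, indexing="ij")
    return map_coordinates(x, grid, order=1)  # order=1 -> bilinear
```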

- more documentation
- fix nit
@NielsRogge (Contributor)

Would be great to also incorporate the updates of #17731


def setup(self):
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
@younesbelkada (Contributor Author) commented Jun 22, 2022:

@NielsRogge I think that in this implementation we directly initialize the modules with the config (unlike in DPTViTEmbeddings), if this is what you meant?

Contributor Author:

I copied these modules from FlaxViTModel, which seems to have the right structure, as suggested in #17731.

Contributor:

Ok, then it's alright :)

Contributor Author:

Oh, but the sanity check of the channel size is indeed missing; I will add that!
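
For reference, the PyTorch vision models guard the patch-embedding forward with a check along these lines; a Flax version could do the same on the trailing channel axis (a sketch, not the PR's code):

```python
# Inside the patch-embedding __call__, with NHWC inputs:
num_channels = pixel_values.shape[-1]
if num_channels != self.config.num_channels:
    raise ValueError(
        "Make sure that the channel dimension of the pixel values matches "
        "the one set in the configuration."
    )
```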

Contributor:
Yeah, we don't have it for other Flax models right now either. Ideally (and also for consistency), we should have it for all vision models.

- add custom conv transpose2d function
- modify test
@younesbelkada (Contributor Author) commented Jun 25, 2022

The Flax model finally predicts the correct depths for the cats (left is Flax and right is PyTorch)!
[Screenshot 2022-06-25 at 20 12 08: side-by-side depth maps, Flax left, PyTorch right]
It turned out that JAX's transposed convolution does not give the same result as PyTorch's implementation, which uses a gradient-based operation. I fixed it by creating a custom function based on this PR: jax-ml/jax#5772. That PR does not seem likely to be merged soon, so we can probably keep this workaround until it lands in JAX.
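
For context, the standard way to reproduce PyTorch's ConvTranspose2d (which is defined as the gradient of a convolution) in JAX is a stride-1 cross-correlation over a zero-dilated input with a spatially flipped kernel. A minimal sketch, assuming a weight in PyTorch's ConvTranspose2d layout (the function name is ours, not the PR's code):

```python
import jax.numpy as jnp
from jax import lax


def torch_style_conv_transpose2d(x, w, stride=2, padding=0):
    # x: (N, H, W, C_in); w: PyTorch layout (C_in, C_out, kH, kW).
    kh, kw = w.shape[2], w.shape[3]
    # Reorder to HWIO and flip spatially (gradient-of-conv construction).
    rhs = jnp.flip(jnp.transpose(w, (2, 3, 0, 1)), axis=(0, 1))
    pad = [(kh - 1 - padding, kh - 1 - padding),
           (kw - 1 - padding, kw - 1 - padding)]
    return lax.conv_general_dilated(
        x, rhs,
        window_strides=(1, 1),
        padding=pad,
        lhs_dilation=(stride, stride),  # insert stride-1 zeros between pixels
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
```

The output spatial size is (H - 1) * stride + kH - 2 * padding, matching PyTorch's formula.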

@ArthurZucker (Collaborator)

As we discussed, it seems that setting align_corners to False for both models would avoid having to lower the tolerance in one of the cases, right?

@younesbelkada (Contributor Author)

@ArthurZucker exactly. I have added a new attribute to the DPTConfig and modified the original modeling code a bit, but this should not break backward compatibility. Now all tests pass with a tolerance of 1e-5.
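
A minimal sketch of what this could look like for a user; the attribute name is assumed, since the comment only says a new flag was added to DPTConfig:

```python
from transformers import DPTConfig

# Hypothetical flag name: make both frameworks upsample the same way,
# so the PyTorch/Flax equivalency test passes at tol=1e-5.
config = DPTConfig(align_corners=False)
```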

- added new attribute to config without breaking backward compatibility
- modified the tests a bit
@younesbelkada younesbelkada changed the title from "[WIP] Add Flax implementation of DPT" to "Add Flax implementation of DPT" Jun 25, 2022
@sanchit-gandhi (Contributor) left a comment

Hey @younesbelkada! Looks pretty good from the Flax side of things! I left a few requests, the overarching one being the use of # Copied from ... statements, both for internal Transformers code and for code copied externally (e.g. from Haiku). They really help in identifying the salient portions of the modelling code to review! Otherwise, a very strong effort on the Flax front 💪
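
For reference, the Transformers # Copied from convention marks a class as a mechanical copy of another module so the repository's consistency check can enforce it; roughly (the source class shown is illustrative):

```python
# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->DPT
class FlaxDPTSelfAttention(nn.Module):
    ...
```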

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Aug 3, 2022
@younesbelkada younesbelkada reopened this Aug 3, 2022
younesbelkada and others added 7 commits August 10, 2022 15:26
- removed `copied_from` on non module objects
- check why `FlaxDPTViTLayerCollection` is not copied from `FlaxViTLayerCollection`
@younesbelkada younesbelkada changed the title from "Add Flax implementation of DPT" to "Add DPT Flax" Aug 10, 2022
- added correct link for `CopiedFrom`
- Added explicit argument for transposed conv on model def
@younesbelkada (Contributor Author) commented Aug 10, 2022

Thank you very much @sanchit-gandhi for the very detailed review! I did a second round of refactoring while catching up on Flax projects and would love a second round of review (I also left some comments unresolved) 💪 Thanks again 🙏

@github-actions (bot) commented Oct 7, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Oct 16, 2022