@PhilJd PhilJd commented Apr 11, 2019

This PR ports the decoupled weight decay optimizers (SGDW, AdamW) to TensorFlow 2.0, minus all v1 tests, since tensorflow_addons depends on tf-2.0 anyway.
Note that I factored out the testing code that is duplicated across most optimizer tests (including in base TensorFlow) into optimizer_test_base. Once this PR is merged, I'd adapt the LazyAdam optimizer to inherit from OptimizerTestBase.

Sorry for the long delay!
Cheers,
Phil :)

Closes #24.
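For context, a minimal sketch of the decoupled update rule these optimizers implement (per Loshchilov & Hutter), contrasted with classic L2 regularization; the helper names are illustrative, not the tensorflow_addons API:

```python
def sgdw_step(w, grad, lr, wd):
    # Decoupled weight decay (SGDW): the decay term is applied directly
    # to the weight and is NOT scaled by the learning rate.
    return w - lr * grad - wd * w

def sgd_l2_step(w, grad, lr, wd):
    # Classic L2 regularization folds the decay into the gradient,
    # so the decay ends up multiplied by the learning rate.
    return w - lr * (grad + wd * w)

w, g = 1.0, 0.5
decoupled = sgdw_step(w, g, lr=0.1, wd=0.01)   # 1 - 0.05 - 0.01   = 0.94
coupled = sgd_l2_step(w, g, lr=0.1, wd=0.01)   # 1 - 0.1 * 0.51    = 0.949
```

With lr != 1 the two rules give different updates, which is the whole point of the decoupling: the regularization strength no longer varies with the learning rate.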

@PhilJd PhilJd requested a review from a team as a code owner April 11, 2019 12:54
@seanpmorgan seanpmorgan changed the title Philjd/decaoupled weight decay Decoupled weight decay Apr 13, 2019
@facaiy facaiy self-assigned this Apr 15, 2019
@facaiy facaiy requested review from a team, facaiy and seanpmorgan April 15, 2019 02:17
facaiy commented Apr 15, 2019

@PhilJd Welcome, thanks for the PR, and your patience. I'll take a look later this week :-)

mfojtak commented Apr 16, 2019

You probably need to add this

from tensorflow_addons.optimizers.weight_decay_optimizers import AdamWOptimizer

into tensorflow_addons/optimizers/__init__.py in order to expose those classes
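The suggestion above is the standard Python re-export pattern. A self-contained sketch of why it works, using throwaway in-memory modules ("tfa_demo" and the placeholder class stand in for the real tensorflow_addons package):

```python
import sys
import types

# Build a throwaway package in memory purely to illustrate the pattern.
pkg = types.ModuleType("tfa_demo")
sub = types.ModuleType("tfa_demo.optimizers")

class AdamWOptimizer:
    """Placeholder for the optimizer class defined in the submodule."""

sub.AdamWOptimizer = AdamWOptimizer
sys.modules["tfa_demo"] = pkg
sys.modules["tfa_demo.optimizers"] = sub

# The equivalent of putting
#     from tfa_demo.optimizers import AdamWOptimizer
# inside tfa_demo/__init__.py:
pkg.AdamWOptimizer = sub.AdamWOptimizer

# Users can now import the class from the package root.
from tfa_demo import AdamWOptimizer as Exposed
assert Exposed is AdamWOptimizer
```

Without the re-export, users would have to reach into the submodule path directly, which is exactly what mfojtak's comment is about.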

PhilJd commented Apr 17, 2019

Thanks! I've added the respective classes and functions to __init__.py :)

@facaiy facaiy left a comment

Thanks for the PR, I'll take another look at the weekend :-)

Review comment on the test code:

        self.evaluate(repeated_index_update_var))


    if __name__ == "__main__":

facaiy: Do you need it?

PhilJd commented Apr 18, 2019

Thanks for the comments @facaiy, I've updated the PR!

@facaiy facaiy left a comment

Nice work! I've left some questions about tf 2.0.

facaiy commented Apr 24, 2019

@PhilJd Phil, don't forget to run make code-format :-)

facaiy commented Apr 25, 2019

@PhilJd Hi, Phil. Could you address all comments? Thank you for the high quality PR, can't wait to merge it :-)

…izer tests, optimizer params are now keywords instead of a dict. Fix code in comments to support tf-2.0, naming errors, line length.
PhilJd commented Apr 25, 2019

@facaiy Thanks a lot for the comments! :)
I've addressed them all, except for the tests. I think someone put testing on the agenda for the next SIG call.
On top of that:

  • I found additional documentation bugs; some code samples were not compatible with tf-2.0.
  • The last commit also fixes "extend_with_weight_decay function doesn't exist?" (tensorflow#26360).
  • I've renamed AdamWOptimizer/SGDWOptimizer to AdamW/SGDW, which follows the Keras naming scheme instead of the tf-1.0 naming scheme.
  • AdamW and SGDW now have a note that the decay also needs to be scheduled; originally only the extend function and the base class had this note.
  • I've updated the paper title to the title of the published version.

By the way, I ran make code-format before, but it seems this doesn't touch comments ;)
Thanks again for taking a detailed look! I also hope we can merge soon :)
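The scheduling note mentioned above can be sketched as follows; this is a hedged illustration of the idea that the decay coefficient should follow the same schedule as the learning rate (function names are made up for illustration, not the tfa API):

```python
import math

def cosine_factor(step, total_steps):
    # Cosine schedule factor, going from 1 at step 0 to 0 at total_steps.
    return 0.5 * (1.0 + math.cos(math.pi * step / total_steps))

base_lr, base_wd = 0.1, 1e-4

def scheduled_params(step, total_steps):
    # Decoupled weight decay should be scaled by the same factor as the
    # learning rate; otherwise the effective regularization strength
    # drifts relative to the update size as training anneals.
    factor = cosine_factor(step, total_steps)
    return base_lr * factor, base_wd * factor

lr, wd = scheduled_params(step=50, total_steps=100)  # halfway: factor ~0.5
```

Passing only a scheduled learning rate while leaving the decay constant would silently change the balance between the two terms over training, which is why the docstring note matters.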

@facaiy facaiy left a comment

very close :-)

@PhilJd PhilJd requested a review from WindQAQ as a code owner April 29, 2019 07:57
PhilJd commented Apr 29, 2019

Oops, I forgot to commit optimizer_test_base ;)
Thanks for the comments, they should be resolved with the latest commit!

facaiy commented Apr 29, 2019

Looks great, thank you, PhilJ! Could you resolve the merge conflict with master branch?

By the way, you can put your name in the contact info list and the code owners file if you'd like to maintain the module you contributed.

facaiy commented Apr 29, 2019

@seanpmorgan @WindQAQ Sean, Tzu-Wei, do you have any concerns about this change?

facaiy previously approved these changes Apr 29, 2019
@seanpmorgan replied:

> @seanpmorgan @WindQAQ Sean, Tzu-Wei, do you have any concerns about this change?

No, looks like a very nice PR. It just needs conflicts resolved and passing tests, IMO.

WindQAQ previously approved these changes Apr 30, 2019
@PhilJd PhilJd dismissed stale reviews from WindQAQ and facaiy via f8b6bdf April 30, 2019 07:12
PhilJd commented Apr 30, 2019

Conflicts with the master are resolved and I've put my name as maintainer into the README ;)

facaiy previously approved these changes Apr 30, 2019
PhilJd commented Apr 30, 2019

Interesting, running make code-format locally doesn't complain.

facaiy commented Apr 30, 2019

@PhilJd Phil, could you run make code-format again? And if you'd like, please add the line below to the code owners file, next to the existing entry:

    /tensorflow_addons/losses/sparsemax*.py @AndreasMadsen
    /tensorflow_addons/optimizers/weight_decay_optimizers*.py @PhilJd

facaiy commented Apr 30, 2019

Please ping me when you get it done, and I'll merge it. Thanks for your patience :-)

PhilJd commented Apr 30, 2019

I've applied the patch from the build server. I wasn't able to find out why code-format doesn't complain locally. I even tried the clang-format version the build task uses (wget https://llvm.org/svn/llvm-project/cfe/trunk/tools/clang-format/git-clang-format), but still no changes...

@facaiy facaiy left a comment

Thanks!


Merging this pull request may close: Implement Weight Decay Optimizers

8 participants