Conversation

kachayev (Collaborator)

Types of changes

The code sample from the documentation page "Wasserstein 2 Minibatch GAN" has been updated to zero out the gradient after each batch.
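
For context, the fixed update pattern looks like the sketch below. This is a minimal, self-contained illustration with a placeholder model, optimizer, and loss, not the exact code from the example: the point is that without `zero_grad()`, `backward()` keeps accumulating gradients across batches.

```python
# Minimal sketch of the fixed training-loop pattern (placeholder model/data,
# not the actual code from the "Wasserstein 2 Minibatch GAN" example).
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
data = torch.randn(100, 2)

for i in range(0, len(data), 10):        # iterate over minibatches of size 10
    batch = data[i:i + 10]
    optimizer.zero_grad()                # clear gradients left over from the previous batch
    loss = model(batch).pow(2).mean()    # stand-in for the minibatch Wasserstein-2 loss
    loss.backward()                      # gradients accumulate; without zero_grad() they stack up
    optimizer.step()
```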

Motivation and context / Related issue

The issue was described in #464. I also reviewed other sections in the documentation where the PyTorch optimizer is used, and found that all other code samples properly call zero_grad().

I also experimented with other optimizers and settings, such as SGD, but the difference in wall time was negligible, so I'm not sure the change is worth making for performance reasons alone. @rflamary, let me know what you think.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@kachayev kachayev changed the title [MRG] Fix gradient update for "Wasserstein 2 Minibatch GAN" example [DOC] Fix gradient update for "Wasserstein 2 Minibatch GAN" example Apr 29, 2023
@rflamary rflamary merged commit 8a7035b into PythonOT:master May 2, 2023