❓ Questions and Help
What is your question?
The documentation for manual optimisation is vague and doesn't provide any complete examples of how to use it correctly, so I need to know what Lightning takes care of and what I have to do myself within the training loop.
To be more specific: do I need to zero the gradients and call `step()` on the learning rate scheduler myself?
Given that this is "manual" mode, I'd be fine having to do it (and half expect I will), but what's extremely confusing is that the examples seem to switch between stating or showing that gradients are zeroed manually and not touching them at all.
Take the example in the current optimizer section: it doesn't show anything being zeroed. On the other hand, the Trainer's documentation for manual optimisation shows the gradients being zeroed explicitly.
So which is it?
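For context on why zeroing matters at all: in plain PyTorch, `backward()` adds into each parameter's `.grad` rather than overwriting it, so gradients keep accumulating until something calls `zero_grad()`. The following is a pure-Python toy that mimics that behaviour. `ToyParam`, `ToySGD` and `train` are hypothetical stand-ins, not PyTorch or Lightning API. It shows what happens when nobody zeroes the gradients. It does not settle which documentation example reflects what Lightning does in manual mode.

```python
from typing import List


class ToyParam:
    """Hypothetical scalar stand-in for a parameter tensor. Like PyTorch,
    backward() adds into .grad instead of overwriting it."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0

    def backward(self, grad: float) -> None:
        self.grad += grad  # accumulate, mirroring PyTorch's default behaviour


class ToySGD:
    """Hypothetical minimal SGD over a single ToyParam (not torch.optim.SGD)."""

    def __init__(self, param: ToyParam, lr: float) -> None:
        self.param = param
        self.lr = lr

    def step(self) -> None:
        self.param.value -= self.lr * self.param.grad

    def zero_grad(self) -> None:
        self.param.grad = 0.0


def train(grads: List[float], zero_each_step: bool) -> float:
    """Run one backward + step per gradient, optionally zeroing after each step."""
    param = ToyParam(1.0)
    opt = ToySGD(param, lr=0.1)
    for g in grads:
        param.backward(g)
        opt.step()
        if zero_each_step:
            opt.zero_grad()
    return param.value


with_zeroing = train([1.0, 1.0, 1.0], zero_each_step=True)      # ~0.7: each step uses one gradient
without_zeroing = train([1.0, 1.0, 1.0], zero_each_step=False)  # ~0.4: steps use 1, 2, then 3 gradients
```

Deliberately skipping `zero_grad()` between several `backward()` calls is exactly how gradient accumulation works in plain PyTorch. That is why the placement of `zero_grad()` matters once `accumulate_grad_batches` is involved.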
Furthermore, how do I access and step the learning rate scheduler (or is that not something I need to handle here)?
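Whoever ends up calling `scheduler.step()`, the cadence of those calls matters. Here is a pure-Python illustration that assumes a StepLR-style schedule, using the closed form of `torch.optim.lr_scheduler.StepLR`: `lr = base_lr * gamma ** (steps // step_size)`. The helper `step_lr` and all numbers are hypothetical. The example shows how stepping once per batch instead of once per epoch collapses the learning rate.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, num_steps: int) -> float:
    """Learning rate after `num_steps` scheduler.step() calls of a StepLR-style
    schedule, i.e. decay by `gamma` every `step_size` steps."""
    return base_lr * gamma ** (num_steps // step_size)


# Hypothetical run: 3 epochs of 100 batches each.
epochs, batches_per_epoch = 3, 100

# Intended: one scheduler step per epoch -> 3 steps, a single decay.
lr_per_epoch = step_lr(0.1, gamma=0.5, step_size=2, num_steps=epochs)
# Accidental: one scheduler step per batch -> 300 steps, 150 decays.
lr_per_batch = step_lr(0.1, gamma=0.5, step_size=2, num_steps=epochs * batches_per_epoch)
```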
What have you tried?
For a little context, what I'm trying to do is to port regular PyTorch GAN code into Lightning.
The module dynamically selects whether to train the generator or discriminator at the start of each batch depending on the discriminator's loss.
So if the loss is below some threshold, it backpropagates and optimises the discriminator; otherwise, it does so for the generator.
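The selection rule described above can be written as a small function. This is a hypothetical sketch: the function name, the threshold value and the example losses are assumptions for illustration. The comparison follows the rule as stated, so a loss strictly below the threshold trains the discriminator and anything else trains the generator.

```python
from typing import List


def choose_model(d_loss: float, threshold: float) -> str:
    """Return which network to optimise on this batch: the discriminator when
    its loss is below `threshold`, otherwise the generator."""
    return "discriminator" if d_loss < threshold else "generator"


# Hypothetical discriminator losses for four consecutive batches.
losses: List[float] = [0.7, 0.4, 0.45, 0.9]
plan = [choose_model(loss, threshold=0.5) for loss in losses]
# plan == ["generator", "discriminator", "discriminator", "generator"]
```

The resulting per-batch sequence is irregular, which is what makes a fixed "step every k batches" rule awkward for this kind of GAN.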
In <=1.0.8, I used automatic optimisation with a custom optimizer step function. However, in all honesty, that was quite clunky, and it no longer works with `accumulate_grad_batches` (which we need, as we're working with extremely large 3D data).
Instead, I've now written code that checks whether the current batch is for training the discriminator or the generator and, based on that, runs `self.manual_backward(loss, optimizer); optimizer.step()`.
I'm pleased that it runs, but I can't find any documentation that actually specifies whether this is enough for the learning rate scheduler and gradient accumulation (`accumulate_grad_batches`) to work.
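Gradient accumulation is tricky here because dynamic selection makes the per-batch optimizer irregular. The `accumulate_grad_batches` Trainer flag is documented as accumulating gradients every `k` batches. A fixed batch-index boundary like that can step an optimizer that has received fewer than `k` backward passes. One hypothetical alternative is to count backward passes per optimizer instead. The pure-Python sketch contrasts the two approaches, assuming an accumulation factor of 2. `PerOptimizerAccumulator` is an illustrative name, not a Lightning class.

```python
from collections import defaultdict


class PerOptimizerAccumulator:
    """Hypothetical helper: counts backward passes per optimizer and reports
    when a given optimizer has accumulated `accumulate` of them."""

    def __init__(self, accumulate: int) -> None:
        if accumulate < 1:
            raise ValueError("accumulate must be >= 1")
        self.accumulate = accumulate
        self.counts = defaultdict(int)

    def should_step(self, name: str) -> bool:
        # Record one backward pass for this optimizer.
        self.counts[name] += 1
        if self.counts[name] == self.accumulate:
            # Enough gradients accumulated: reset the counter and signal a step.
            self.counts[name] = 0
            return True
        return False


# Hypothetical sequence of networks chosen on six consecutive batches.
plan = ["generator", "discriminator", "discriminator",
        "generator", "discriminator", "generator"]

acc = PerOptimizerAccumulator(accumulate=2)
per_optimizer = [acc.should_step(name) for name in plan]
# Fixed boundary: step every 2 batches, regardless of which network trained.
fixed_boundary = [(i + 1) % 2 == 0 for i in range(len(plan))]
```

With the fixed boundary, batch index 1 would step the discriminator after a single backward pass. The per-optimizer counter waits until each network has accumulated two passes.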
What's your environment?
- OS: Windows Subsystem for Linux (Ubuntu)
- Packaging: pip installed into conda environment
- Version: 1.1.0
Thanks so much for any help!