Step optimizers at arbitrary intervals #7405

@celsofranssa

Description

🐛 Bug

In my model, I have two optimizers that alternate during the optimization steps, as shown below:

class PLModel(LightningModule):

    def __init__(self, hparams):

        super(PLModel, self).__init__()
        self.hparams = hparams
        
        ...


    def configure_optimizers(self):

        # optimizers
        optimizers = [
            torch.optim.AdamW(self.x1_encoder.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999), eps=1e-08,
                             weight_decay=self.hparams.weight_decay, amsgrad=True),
            torch.optim.AdamW(self.x2_encoder.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999), eps=1e-08,
                             weight_decay=self.hparams.weight_decay, amsgrad=True)
        ]

        # schedulers
        step_size_up = 0.03 * self.num_training_steps

        schedulers = [
            torch.optim.lr_scheduler.CyclicLR(
                optimizers[0],
                mode='triangular2',
                base_lr=self.hparams.base_lr,
                max_lr=self.hparams.max_lr,
                step_size_up=step_size_up,
                cycle_momentum=False),
            torch.optim.lr_scheduler.CyclicLR(
                optimizers[1],
                mode='triangular2',
                base_lr=self.hparams.base_lr,
                max_lr=self.hparams.max_lr,
                step_size_up=step_size_up,
                cycle_momentum=False)
        ]
        return optimizers, schedulers

    # Alternating schedule for optimizer steps
    def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
            on_tpu, using_native_amp, using_lbfgs
    ):
        # update x1_encoder opt
        if optimizer_idx == 0:
            if batch_idx % 2 == 0:
                optimizer.step(closure=optimizer_closure)

        # update x2_encoder opt
        if optimizer_idx == 1:
            if batch_idx % 2 != 0:
                optimizer.step(closure=optimizer_closure)

But I am getting the following error message during the optimizer step:

Epoch 0:   0% 0/4086 [00:07<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 213, in perform_tasks
    fit(hparams)
  File "main.py", line 89, in fit
    trainer.fit(model, datamodule=dm)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 499, in fit
    self.dispatch()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 546, in dispatch
    self.accelerator.start_training(self)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 73, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 114, in start_training
    self._results = trainer.run_train()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 637, in run_train
    self.train_loop.run_training_epoch()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 492, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 654, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 433, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "/content/source/model/PLModel.py", line 82, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 274, in optimizer_step
    self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/native_amp.py", line 78, in pre_optimizer_step
    lambda_closure()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 649, in train_step_and_backward_closure
    split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 755, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py", line 785, in backward
    result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 257, in backward
    self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/native_amp.py", line 59, in backward
    closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 71, in backward
    model.backward(closure_loss, optimizer, opt_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py", line 1251, in backward
    loss.backward(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Could anyone tell me what I'm missing?

To Reproduce

This issue can be reproduced using the BoringModel on Google Colab.
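
The Colab notebook itself is not included here. As a stand-in, below is a minimal, self-contained sketch that follows the same structure (two encoders, two AdamW optimizers, and the alternating optimizer_step override). The module names, dimensions, and Trainer flags are illustrative assumptions, and it has not been verified to raise the identical RuntimeError.

# Hypothetical minimal reproduction; not the original Colab notebook.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ToyAlternatingModel(LightningModule):
    def __init__(self):
        super().__init__()
        # stand-ins for the two encoders in the report
        self.x1_encoder = torch.nn.Linear(32, 8)
        self.x2_encoder = torch.nn.Linear(32, 8)

    def training_step(self, batch, batch_idx, optimizer_idx):
        (x,) = batch
        # toy loss that depends on both encoders
        return torch.nn.functional.mse_loss(self.x1_encoder(x), self.x2_encoder(x))

    def configure_optimizers(self):
        return [
            torch.optim.AdamW(self.x1_encoder.parameters(), lr=1e-3),
            torch.optim.AdamW(self.x2_encoder.parameters(), lr=1e-3),
        ]

    # same alternating pattern as in the report
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        if optimizer_idx == 0 and batch_idx % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        if optimizer_idx == 1 and batch_idx % 2 != 0:
            optimizer.step(closure=optimizer_closure)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    # the traceback shows native AMP on a GPU; fall back to CPU/fp32 if no GPU is available
    trainer = Trainer(max_epochs=1,
                      gpus=1 if torch.cuda.is_available() else 0,
                      precision=16 if torch.cuda.is_available() else 32)
    trainer.fit(ToyAlternatingModel(), data)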

Expected behavior

Without overriding the optimizer_step method, training runs normally. Therefore, stepping the optimizers at arbitrary intervals should, in principle, also work.

Note: with pytorch-lightning==1.1.3, the same optimizer_step override ran smoothly.
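
One commonly suggested variant for skipping optimizer steps is to still invoke the closure in the skipped branch, so that training_step and backward run for that optimizer on every batch. This is an assumption rather than something confirmed against 1.2.10; a sketch of the override adjusted that way:

def optimizer_step(
        self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
        on_tpu, using_native_amp, using_lbfgs
):
    # update x1_encoder opt on even batches
    if optimizer_idx == 0:
        if batch_idx % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # still run training_step + backward, just skip this optimizer's step
            optimizer_closure()

    # update x2_encoder opt on odd batches
    if optimizer_idx == 1:
        if batch_idx % 2 != 0:
            optimizer.step(closure=optimizer_closure)
        else:
            optimizer_closure()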

Environment

* CUDA:
	- GPU:
		- Tesla P100-PCIE-16GB
	- available:         True
	- version:           10.1
* Packages:
	- numpy:             1.19.5
	- pyTorch_debug:     False
	- pyTorch_version:   1.8.1+cu101
	- pytorch-lightning: 1.2.10
	- tqdm:              4.41.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- 
	- processor:         x86_64
	- python:            3.7.10
	- version:           #1 SMP Thu Jul 23 08:00:38 PDT 2020

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 2 (Low priority task), waiting on author (Waiting on user action, correction, or update)