
For versions >0.8.2 learning rate is zero for last epoch (potentially a logging bug) #2480

@HHousen


🐛 Bug

Version 0.8.2 and above changed the behavior of either my learning rate scheduler or the WandbLogger logger. I am using a linear warmup and decay scheduler, but ever since version 0.8.2 the learning rate graph produced by the LearningRateLogger looks like this:

[image1: learning rate graph from LearningRateLogger, version 0.8.2 and above]

The period where the learning rate is zero corresponds to the last epoch of training as you can see below:

[image2: logged epoch over training steps, version 0.8.2 and above]

This graph raises another issue: the first epoch appears to take twice as many steps as the second and third epochs. I specified max_epochs=3. During training, each epoch takes the same amount of wall-clock time, so this looks like a logging issue.

Note that the graphs above come from a model whose training was stopped early, so the last epoch is slightly shorter than the second-to-last one. That is not the issue.

Neither issue (the zero learning rate nor the twice-as-long first epoch) occurs in version 0.8.1; there, both graphs look as they should.

These issues could be artifacts of the logger, or they might really occur and simply be logged correctly. After looking through the changelog, I suspect they are caused by "Changed epoch indexing from 0 instead of 1" (#2289). I may also be relying on 1-based epoch indexing somewhere in my code, but I do not believe this is the case.

To Reproduce

Reproducing this problem may be difficult because I cannot share the script and data I used. I used the WandbLogger logger and the LearningRateLogger callback, and I trained with 1400 warmup steps and accumulate_grad_batches set to 2.
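To make the step arithmetic concrete, here is a minimal sketch that repeats the t_total and warmup calculations from the code sample. It uses a hypothetical 6000 batches per epoch, because the report does not state the dataset size. The sketch compares the number of scheduler steps the lambda expects (t_total) with the number that would occur under two stepping assumptions:

- scheduler.step() is called once per batch.
- scheduler.step() is called once per optimizer step.

Which of these Lightning actually does with accumulate_grad_batches=2 is an assumption here, not something this report establishes. The helper name schedule_step_counts is hypothetical.

```python
def schedule_step_counts(batches_per_epoch, max_epochs, accumulate_grad_batches, warmup_steps):
    """Hypothetical helper reproducing the step arithmetic from the code sample."""
    # Same expression as t_total in the code sample.
    t_total = int(batches_per_epoch * max_epochs // accumulate_grad_batches)
    # The warmup length passed to the lambda is multiplied by the accumulation factor.
    num_warmup_steps = warmup_steps * accumulate_grad_batches
    total_batches = batches_per_epoch * max_epochs
    optimizer_steps = total_batches // accumulate_grad_batches
    return {
        "t_total": t_total,
        "num_warmup_steps": num_warmup_steps,
        # Scheduler steps whose index is >= t_total, i.e. where the multiplier is 0,
        # under each assumption about how often scheduler.step() is called.
        "zero_steps_if_per_batch": max(0, total_batches - t_total),
        "zero_steps_if_per_optimizer_step": max(0, optimizer_steps - t_total),
    }


counts = schedule_step_counts(
    batches_per_epoch=6000,  # hypothetical; not stated in the report
    max_epochs=3,
    accumulate_grad_batches=2,
    warmup_steps=1400,
)
print(counts)
```

Under these assumptions, stepping once per batch would leave the multiplier at 0 for the second half of training. Stepping once per optimizer step would reach 0 exactly at the end.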

I can provide additional code samples or any other information you may need.

Code sample

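from functools import partial

from torch.optim.lr_scheduler import LambdaLR


def lr_lambda_func(current_step, num_warmup_steps, num_training_steps):
    # Linear warmup: the multiplier rises from 0 to 1 over num_warmup_steps.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    # Linear decay: the multiplier falls to 0 at num_training_steps and stays there.
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )


# LightningModule method; `optimizer` is created earlier in it (not shown in this report).
def configure_optimizers(self):
    # Intended number of optimizer steps: batches per epoch * epochs / accumulation.
    t_total = int(len(self.train_dataloader_object) * self.hparams.max_epochs // self.hparams.accumulate_grad_batches)

    lr_lambda = partial(
        lr_lambda_func,
        # The warmup length is multiplied by accumulate_grad_batches,
        # whereas t_total above is divided by it.
        num_warmup_steps=self.hparams.warmup_steps
        * self.hparams.accumulate_grad_batches,
        num_training_steps=t_total,
    )

    scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
    # "interval": "step" asks Lightning to step the scheduler every training step, not every epoch.
    scheduler_dict = {"scheduler": scheduler, "interval": "step"}
    return [optimizer], [scheduler_dict]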
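To show why any extra scheduler steps appear as a flat zero, here is a minimal sketch that evaluates lr_lambda_func from the code sample on its own. It uses hypothetical values: 100 warmup steps and 1000 training steps. torch.optim.lr_scheduler.LambdaLR sets each parameter group's learning rate to its initial learning rate multiplied by lr_lambda(n), where n is the number of scheduler.step() calls so far. Once n reaches num_training_steps, the max(0.0, ...) clamp therefore keeps the learning rate at exactly 0. The helper fraction_at_zero is hypothetical. The sketch only illustrates the shape of the schedule; it does not show why the scheduler would be stepped more often than expected.

```python
from functools import partial


def lr_lambda_func(current_step, num_warmup_steps, num_training_steps):
    # Same schedule as the code sample: linear warmup, then linear decay clamped at 0.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0,
        float(num_training_steps - current_step)
        / float(max(1, num_training_steps - num_warmup_steps)),
    )


def fraction_at_zero(lr_lambda, num_warmup_steps, steps_taken):
    """Hypothetical helper: fraction of steps_taken scheduler steps whose multiplier is 0."""
    zero = sum(
        1 for step in range(num_warmup_steps, steps_taken) if lr_lambda(step) == 0.0
    )
    return zero / steps_taken


# Hypothetical schedule: 100 warmup steps, decaying to 0 at step 1000.
lr_lambda = partial(lr_lambda_func, num_warmup_steps=100, num_training_steps=1000)

for step in (0, 50, 100, 550, 1000, 1500):
    print(step, lr_lambda(step))

# If the scheduler were stepped 1500 times instead of 1000, the last third would be flat at 0.
print(fraction_at_zero(lr_lambda, 100, 1500))
```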

Expected behavior

In version 0.8.2 and later, the learning rate should warm up and decay the same way it does in 0.8.1 and earlier. Each epoch should also span the same number of steps.
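One way to check the "same number of steps per epoch" expectation against logged data is to count steps per epoch. The sketch below makes two assumptions:

- The logged (global_step, epoch) pairs have been exported as a list of tuples. This is a hypothetical format, not a specific WandbLogger API.
- The final epoch can be shorter, because the run was stopped early, so it is ignored when comparing lengths.

The helper names steps_per_epoch and full_epochs_equal are hypothetical.

```python
from collections import defaultdict


def steps_per_epoch(records):
    """Count distinct logged global steps per epoch from (global_step, epoch) pairs."""
    steps = defaultdict(set)
    for global_step, epoch in records:
        steps[epoch].add(global_step)
    return {epoch: len(s) for epoch, s in sorted(steps.items())}


def full_epochs_equal(counts):
    """True if every epoch except the last (which may be cut short) has the same length."""
    epochs = sorted(counts)
    full = [counts[e] for e in epochs[:-1]]
    return len(set(full)) <= 1


# Hypothetical logs: 4 steps per epoch, versus a first epoch that appears twice as long.
expected = [(s, s // 4) for s in range(12)]
doubled = [(s, 0) for s in range(8)] + [(s, 1 + (s - 8) // 4) for s in range(8, 16)]

print(steps_per_epoch(expected), full_epochs_equal(steps_per_epoch(expected)))
print(steps_per_epoch(doubled), full_epochs_equal(steps_per_epoch(doubled)))
```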

The graphs below show the expected behavior. They come from a different model, so they are not directly comparable, but their shape is as expected because that model was trained with pytorch_lightning version 0.8.1.

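[image3: expected learning rate graph, pytorch_lightning 0.8.1]

[image4: expected epoch graph, pytorch_lightning 0.8.1]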

Environment

  • CUDA:
    • GPU:
      • Tesla P100-PCIE-16GB
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: False
    • pyTorch_version: 1.5.1+cu101
    • pytorch-lightning: 0.8.4
    • tensorboard: 2.2.2
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Wed Feb 19 05:26:34 PST 2020

Labels: bug (Something isn't working), help wanted (Open to be worked on)
