Closed
Labels
bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
Description
🐛 Bug
When using the BackboneFinetuning callback together with LearningRateMonitor(logging_interval="epoch"), a KeyError is raised once the backbone is unfrozen.
Please reproduce using the BoringModel
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(True))
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(self.backbone(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        # Only the head parameters are passed to the optimizer here; BackboneFinetuning
        # later unfreezes the backbone and adds its parameters to this optimizer.
        opt = torch.optim.Adam(self.layer.parameters(), lr=0.1)
        return [opt], [torch.optim.lr_scheduler.StepLR(opt, step_size=1)]


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=6,
        weights_summary=None,
        callbacks=[
            BackboneFinetuning(unfreeze_backbone_at_epoch=4),
            LearningRateMonitor(logging_interval="epoch"),
        ],
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()
Error:
lr_monitor.py", line 160, in _extract_lr
self.lrs[name].append(lr)
KeyError: 'lr-Adam/pg1'
Exception ignored in: <function tqdm.__del__ at 0x7ff5e83d7d30>
Traceback (most recent call last):
File "/Users/ethan/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1090, in __del__
File "/Users/ethan/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1303, in close
File "/Users/ethan/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1481, in display
File "/Users/ethan/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1093, in __repr__
File "/Users/ethan/miniconda3/lib/python3.8/site-packages/tqdm/std.py", line 1443, in format_dict
TypeError: cannot unpack non-iterable NoneType object
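For context, here is a minimal plain-PyTorch sketch of what appears to be happening (this is not the actual LearningRateMonitor code, and the group_names helper is hypothetical): BackboneFinetuning adds the backbone parameters to the optimizer as a new param group when it unfreezes them, so a monitor whose lrs dict was keyed on the param-group names present at train start suddenly encounters names it never registered.

import torch

# Minimal sketch (plain PyTorch, not the Lightning implementation).
backbone = torch.nn.Linear(32, 32)
head = torch.nn.Linear(32, 2)
opt = torch.optim.Adam(head.parameters(), lr=0.1)

def group_names(optimizer):
    # Hypothetical naming scheme: one group -> "lr-Adam",
    # several groups -> "lr-Adam/pg1", "lr-Adam/pg2", ...
    n = len(optimizer.param_groups)
    return ["lr-Adam"] if n == 1 else [f"lr-Adam/pg{i + 1}" for i in range(n)]

# "on_train_start": the monitor creates one key per param group that exists now.
lrs = {name: [] for name in group_names(opt)}   # {"lr-Adam": []}

# "unfreeze_backbone_at_epoch=4": the finetuning callback adds a new param group.
opt.add_param_group({"params": backbone.parameters(), "lr": 1e-3})

# The next epoch-level extraction now produces names the dict has never seen.
try:
    for name, group in zip(group_names(opt), opt.param_groups):
        lrs[name].append(group["lr"])
except KeyError as err:
    print("KeyError:", err)                     # KeyError: 'lr-Adam/pg1'

If that is indeed the cause, the monitor presumably needs to re-discover the param-group names (or register new keys) whenever the optimizer's param groups change.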
Expected behavior
No error; the learning rate monitor should keep logging after the backbone's parameter group is added to the optimizer.
Environment
PyTorch Lightning: master branch