-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Closed
Labels
bug (Something isn't working), help wanted (Open to be worked on)
Milestone
Description
🐛 Bug
- Both `validation_step` and `validation_epoch_end` are defined through overriding.
- When the `validation_step` function returns a value of type `defaultdict`, `TypeError: first argument must be callable or None` occurs.
- Returning a `List`, `Dict`, or `OrderedDict` does not give such an error.
- The error occurs from `pytorch-lightning>=1.4.6`.
- The error is caused by PR "Move tracking epoch end outputs logic to the `EvaluationEpochLoop`" #9261.
- Left a comment in "Move tracking epoch end outputs logic to the `EvaluationEpochLoop`" #9261 (comment).
To Reproduce
import os
from collections import defaultdict
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
    """Map-style dataset holding ``length`` random vectors of ``size`` features.

    The samples are drawn once with ``torch.randn`` at construction time, so
    indexing the same position repeatedly returns the same tensor.

    Args:
        size: Number of features in each sample (second dimension of ``data``).
        length: Number of samples in the dataset (first dimension of ``data``).
    """

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.len
class BoringModel(LightningModule):
    """Minimal single-layer model used to reproduce the reported ``TypeError``.

    Both ``validation_step`` and ``validation_epoch_end`` are overridden, and
    ``validation_step`` deliberately returns a ``defaultdict`` -- the
    combination the issue reports as failing.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Log and return the summed model output as the training loss."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and return an empty ``defaultdict``."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        # Intentional: returning a defaultdict is what triggers the bug being
        # reported; do not replace it with a plain dict.
        return defaultdict(float)

    def validation_epoch_end(self, outputs):
        """Return the collected validation step outputs unchanged."""
        return outputs

    def test_step(self, batch, batch_idx):
        """Log the summed model output as the test loss; returns nothing."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an SGD optimizer over the layer parameters with lr=0.1."""
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def run():
    """Fit and then test a ``BoringModel`` on random data.

    Training is limited to one epoch with a single training batch and a
    single validation batch, which keeps the reproduction as short as
    possible.
    """
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)
if __name__ == "__main__":
    run()

Expected behavior
File "scripts/train.py", line 103, in fine_tune
trainer.fit(model, dataloaders["train"], dataloaders["dev"])
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
self._run(model)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 922, in _run
self._dispatch()
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _dispatch
self.accelerator.start_training(self)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1000, in run_stage
return self._run_train()
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_train
self._run_sanity_check(self.lightning_module)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1122, in _run_sanity_check
self._evaluation_loop.run()
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 126, in advance
output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/utilities/memory.py", line 44, in recursive_detach
return apply_to_collection(in_dict, torch.Tensor, detach_and_move, to_cpu=to_cpu)
File "/root/.local/share/virtualenvs/zero/lib/python3.8/site-packages/pytorch_lightning/utilities/apply_func.py", line 109, in apply_to_collection
return elem_type(OrderedDict(out))
TypeError: first argument must be callable or None

Environment
* CUDA:
- GPU:
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- available: True
- version: 11.1
* Packages:
- numpy: 1.21.2
- pyTorch_debug: False
- pyTorch_version: 1.9.1+cu111
- pytorch-lightning: 1.4.9
- tqdm: 4.62.3
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.6
- version: #40~20.04.1-Ubuntu SMP Wed Jan 6 10:15:55 UTC 2021

Metadata
Assignees
Labels
bug (Something isn't working), help wanted (Open to be worked on)