Description
🐛 Bug
Context
I noticed that in the unit test case `test_dataloaders_reset_and_attach` in `test_dataloaders.py`, `trainer.fit()` is called twice with different `train_dataloaders` (code pointer):
```python
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

# 1st fit
trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset

# 2nd fit
trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset
```

The test passes under the current implementation because the train/val dataloaders are reattached before `fit_loop.run()` is called (code pointer):
```python
# reload data when needed
model = self.lightning_module
self.reset_train_val_dataloaders(model)

self.fit_loop.trainer = self
self.fit_loop.run()
```

However, the second fit loop never actually runs: the first fit loop stops once either `max_epochs` or `max_steps` is reached, at which point `fit_loop.done` becomes `True`, which in turn makes `fit_loop.skip` `True` (code pointer). Without initializing a new trainer, the second `fit_loop.run()` call is simply skipped (code pointer).
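The skip behavior can be illustrated with a minimal, self-contained sketch. Note this is plain Python, not Lightning's actual classes: `MiniFitLoop` and its attributes only loosely mirror the real `FitLoop`, whose `done` check also covers `max_epochs`, early stopping, and more.

```python
class MiniFitLoop:
    """Toy stand-in for Lightning's FitLoop to illustrate the skip logic."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps
        self.global_step = 0

    @property
    def done(self) -> bool:
        # The real loop also checks max_epochs, early stopping, etc.
        return self.global_step >= self.max_steps

    @property
    def skip(self) -> bool:
        # In Lightning, Loop.skip defaults to the done condition.
        return self.done

    def run(self) -> str:
        if self.skip:
            # A second run() on an already-done loop lands here.
            return "skipped"
        while not self.done:
            self.global_step += 1  # pretend to train one step
        return "ran"


loop = MiniFitLoop(max_steps=1)
print(loop.run())  # first fit: trains until max_steps is reached
print(loop.run())  # second fit: done is still True, so the loop is skipped
```

Because `global_step` persists on the loop object, the second `run()` sees `done == True` immediately and returns without doing any work, which is exactly what happens to the second `trainer.fit()` call.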
Discussion:
Do we want to allow users to call `trainer.fit()` multiple times with different `train_dataloaders`? I understand the need for `trainer.validate()`/`test()`/`predict()`, but the pattern of allowing repeated `trainer.fit()` calls could be a problem. One edge case where calling `trainer.fit()` a second time makes sense is when the first fit is interrupted by an early-stopping condition and training is resumed with different training data. At a minimum, we should document this or add a warning so that users are aware that the fit loop did not actually start.
Pitch
- Update this test case by removing the second fit call
- Document the change
Environment
- PyTorch Lightning Version (e.g., 1.3.0):
- PyTorch Version (e.g., 1.8):
- Python version:
- OS (e.g., Linux):
- CUDA/cuDNN version:
- GPU models and configuration:
- How you installed PyTorch (conda, pip, source):
- If compiling from source, the output of torch.__config__.show():
- Any other relevant information:
Additional context
cc @Borda @rohitgr7 @carmocca @justusschock @ananthsub @ninginthecloud