
Trainer.fit() multiple times #9636

@ninginthecloud


🐛 Bug

Context

In the unit test test_dataloaders_reset_and_attach in test_dataloaders.py, I noticed that trainer.fit() is called twice with different train_dataloaders (code pointer):

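    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel, RandomDataset  # Lightning's own test helpers

    def test_dataloaders_reset_and_attach(tmpdir):
        # four distinct dataloaders over random data (assumed setup)
        dataloader_0, dataloader_1, dataloader_2, dataloader_3 = (
            DataLoader(RandomDataset(32, 64)) for _ in range(4)
        )
        model = BoringModel()
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

        # 1st fit
        trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
        assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
        assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset
        # 2nd fit: the new dataloaders are attached, but fit_loop.run() is skipped
        trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
        assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
        assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset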

The original test case succeeds under the current implementation because the train/val dataloaders are reattached before fit_loop.run() is called (code pointer):

    # reload data when needed
    model = self.lightning_module

    self.reset_train_val_dataloaders(model)

    self.fit_loop.trainer = self
    self.fit_loop.run()

However, the second fit_loop never actually runs. The first fit_loop stops properly once either max_epochs or max_steps is reached, at which point fit_loop.done is True, which in turn makes fit_loop.skip True (code pointer). Unless a new Trainer is created, the second fit_loop.run() call is simply skipped (code pointer).
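A minimal, self-contained sketch of this mechanism, assuming a heavily simplified loop: ToyFitLoop is a hypothetical stand-in, not Lightning's FitLoop. It only models the pieces named above: progress (global_step) that persists on the loop object, a done flag derived from it, a skip property that follows done, and a run() that returns early when skip is true.

```python
class ToyFitLoop:
    """Hypothetical stand-in for a fit loop; not Lightning's FitLoop class."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.global_step = 0  # progress is stored on the loop and survives between run() calls

    @property
    def done(self) -> bool:
        # the loop is finished once the step budget has been used up
        return self.global_step >= self.max_steps

    @property
    def skip(self) -> bool:
        # a finished loop is skipped entirely
        return self.done

    def run(self, batches) -> int:
        """Consume batches until done; return how many steps this call ran."""
        if self.skip:
            return 0  # the second fit() ends up here: nothing is trained
        steps = 0
        for _ in batches:
            if self.done:
                break
            self.global_step += 1
            steps += 1
        return steps


loop = ToyFitLoop(max_steps=3)
first = loop.run(range(10))   # runs 3 steps, after which done is True
second = loop.run(range(10))  # new data, but skip is True, so no steps run
print(first, second)  # 3 0
```

Swapping in new batches for the second call changes nothing, because the early return is decided purely by state left over from the first run.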

Discussion:

Do we allow users to call trainer.fit() multiple times with different train_dataloaders? I understand the need to call trainer.validate()/test()/predict() repeatedly, but I think allowing trainer.fit() to be called multiple times could be a problem. One edge case I can think of for calling trainer.fit() more than once is when fitting is interrupted by an early-stopping condition and then resumed with different training data. At the very least, we need to document this or add a warning so that users know the fit_loop did not actually start.
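One possible shape for such a warning, as a minimal sketch under stated assumptions: ToyFitLoop and fit() below are hypothetical stand-ins, not Lightning APIs, and the warning text is illustrative rather than an existing Lightning message. The idea is simply to check whether the loop is already done before running it, and to warn instead of returning silently.

```python
import warnings


class ToyFitLoop:
    """Hypothetical stand-in for a fit loop whose progress survives between calls."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.global_step = 0

    @property
    def done(self) -> bool:
        return self.global_step >= self.max_steps

    def run(self, batches) -> None:
        for _ in batches:
            if self.done:
                break
            self.global_step += 1


def fit(loop: ToyFitLoop, batches) -> None:
    """Hypothetical fit() that warns instead of silently skipping a finished loop."""
    if loop.done:
        warnings.warn(
            "fit_loop is already done (max_steps/max_epochs reached); "
            "this fit() call will not train. Create a new Trainer to fit again.",
            UserWarning,
        )
        return
    loop.run(batches)


loop = ToyFitLoop(max_steps=2)
fit(loop, range(5))  # trains for 2 steps, no warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fit(loop, range(5))  # second call: warns and returns without training
print(len(caught), loop.global_step)  # 1 2
```

Warning rather than raising keeps existing code that calls fit() twice working, while making the skipped run visible to the user.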

Pitch

  • Update this test case by removing the second fit call
  • Document the change


Additional context

cc @Borda @rohitgr7 @carmocca @justusschock @ananthsub @ninginthecloud

Labels: docs (Documentation related), loops (Related to the Loop API)
