🚀 Feature
Allow Trainer.fit() to be called multiple times during iteration-based training, with each call continuing from the batch the trainer reached last time. For example, say I want to train for 10 steps every time I call fit: the first fit call should train on batches 0-9, the second on batches 10-19, and so on.
Motivation
There are many reasons why training in terms of steps may suit a problem better (some of them are outlined in #7629). However, there is no easy way to train in terms of steps and continue from the batch that was previously reached. Continuing the example above, creating a new trainer or modifying max_steps on the original trainer simply trains on the same first 10 batches again, which is clearly not desirable.
I have a naive, cumbersome workaround: between every fit call, I create a new trainer with max_steps set to n + 10, and I add logic to the training_step method that skips the batch (returns None) for the first n steps, where n is the cumulative number of steps reached by the previous fit calls. Besides being cumbersome, this is inefficient, because every batch up to the one we want still has to be loaded.
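To make the cost of this workaround concrete, here is a Lightning-free simulation of it. `CountingLoader` and `naive_fit` are hypothetical names used only for illustration; `naive_fit` mimics the batch counting (a step cap of `start_step + steps_per_fit`, with earlier batches skipped) and does not model Lightning's actual `global_step` or `training_step` semantics.

```python
from itertools import islice


class CountingLoader:
    """Hypothetical stand-in for a DataLoader that counts how many batches it loads."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.loads = 0

    def __iter__(self):
        # Every new iteration (i.e. every new fit call) restarts at batch 0.
        for i in range(self.num_batches):
            self.loads += 1  # simulate the cost of fetching batch i
            yield i


def naive_fit(loader, start_step, steps_per_fit):
    """Mimic max_steps = start_step + steps_per_fit, skipping the first start_step batches."""
    trained = []
    max_steps = start_step + steps_per_fit
    for step, batch in enumerate(islice(loader, max_steps)):
        if step < start_step:
            continue  # the workaround: training_step returns None for already-seen batches
        trained.append(batch)
    return trained


loader = CountingLoader(100)
first = naive_fit(loader, start_step=0, steps_per_fit=10)    # batches 0-9
second = naive_fit(loader, start_step=10, steps_per_fit=10)  # batches 10-19
print(loader.loads)  # 30: batches 0-9 were loaded twice to train on 20 batches
```

Each additional fit call re-loads every batch before its starting point, so the wasted loading grows with the cumulative step count.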
Pitch
I want to be able to call Trainer.fit() multiple times, with each call training on the next n batches of the training dataloader, without having to write this logic in training_step myself. Ideally, this also shouldn't require loading every batch up to the one we want: the trainer should remember the index it reached in the dataset/dataloader and simply continue from there. Related to #11425
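A minimal plain-Python sketch of the requested behavior, assuming the trainer can keep a single dataloader iterator alive between fit calls. `ResumableFitter` and its `fit(steps)` method are hypothetical and are not part of the Lightning API; the sketch does not cover saving or restoring the iterator position across process restarts, and it handles at most one epoch boundary per call.

```python
from itertools import islice


class ResumableFitter:
    """Hypothetical sketch: keeps one iterator alive across fit() calls so each
    call resumes at the next unseen batch instead of restarting at batch 0."""

    def __init__(self, loader):
        self.loader = loader     # any re-iterable, e.g. a DataLoader
        self._it = iter(loader)  # persisted between fit() calls
        self.global_step = 0

    def fit(self, steps):
        # Pull only the next `steps` batches; nothing before them is re-loaded.
        trained = list(islice(self._it, steps))
        if len(trained) < steps:
            # The epoch ended mid-call: start a new pass and fill the remainder.
            self._it = iter(self.loader)
            trained += islice(self._it, steps - len(trained))
        self.global_step += len(trained)
        return trained


fitter = ResumableFitter(range(25))
print(fitter.fit(10))  # batches 0-9
print(fitter.fit(10))  # batches 10-19
print(fitter.fit(10))  # batches 20-24, then 0-4 from a new epoch
```

The key design choice is that the iterator, not a step offset, carries the position, so resuming costs nothing beyond fetching the batches that are actually trained on.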
cc @Borda @carmocca @justusschock @awaelchli @ninginthecloud