
Trainer.fit() remember batch_idx previously reached #11426

@amin-nejad


🚀 Feature

Allow Trainer.fit() to be called multiple times during iteration-based training, with each call continuing from the batch the trainer reached last time. For example, suppose I want to train for 10 steps every time I call fit: the first fit call should train on batches 0-9, the second on batches 10-19, and so on.

Motivation

There are many reasons why training in terms of steps may suit a problem better (some are outlined in #7629). However, there is no easy way to train in terms of steps and continue from the batch that was previously reached. Continuing the example above, creating a new trainer or increasing max_steps on the original trainer simply trains on the same first 10 batches again, which is clearly undesirable.

I have a naive, cumbersome workaround: create a new trainer before every fit call with max_steps set to n + 10, and add logic to the training_step method that skips the batch (returns None) for the first n steps, where n is the cumulative number of steps reached in previous fit calls. Besides being cumbersome, this is inefficient, because every batch is still loaded until we reach the one we want.
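To make the cost of this workaround concrete, here is a minimal pure-Python simulation of it. It does not use Lightning at all: fit_with_skip is a hypothetical stand-in for "new trainer with max_steps = n + 10 plus a training_step that returns None for the first n batches", and the loader is just an iterable of batch indices. The point it shows is that each call restarts the loader from batch 0, so the number of batches loaded grows with every call even though only 10 are trained on.

```python
from itertools import islice
from typing import Callable, Iterable, List, Tuple


def fit_with_skip(
    make_loader: Callable[[], Iterable[int]],
    steps_done: int,
    steps: int,
) -> Tuple[List[int], int]:
    """Hypothetical simulation of the skip-based workaround.

    The loader restarts from batch 0, the first `steps_done` batches are
    skipped (like returning None from training_step), and iteration stops
    after `steps_done + steps` batches (like max_steps = n + 10).
    Returns the batches actually trained on and how many were loaded.
    """
    trained: List[int] = []
    loaded = 0
    # islice caps iteration at the max_steps-style limit without pulling extra items.
    for batch_idx, batch in enumerate(islice(make_loader(), steps_done + steps)):
        loaded += 1  # every batch is fetched, including the skipped ones
        if batch_idx < steps_done:
            continue  # analogous to training_step returning None
        trained.append(batch)
    return trained, loaded


# Three successive "fit" calls of 10 steps each over a 100-batch loader.
steps_done = 0
for _ in range(3):
    trained, loaded = fit_with_skip(lambda: range(100), steps_done, 10)
    steps_done += len(trained)
    print(trained[0], trained[-1], loaded)  # prints "0 9 10", "10 19 20", "20 29 30"
```

The third call trains on only 10 batches but loads 30, which is the inefficiency described above.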

Pitch

I want to be able to call Trainer.fit() multiple times, with each call training on the next n batches of the training dataloader, without having to write this logic into training_step myself. Ideally this should also avoid loading every batch up to the one we want: the trainer should remember the index it reached in the dataset/dataloader and continue from there. Related to #11425.
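As an illustration of the requested behaviour (not of any existing Lightning API), here is a minimal sketch of a loader that remembers its position. ResumableBatches, take, and the state_dict/load_state_dict pair are all hypothetical names chosen for this example. It assumes a map-style dataset that supports len() and indexing, so resuming jumps straight to the stored index instead of re-reading earlier batches, and it wraps around to a new epoch when the data runs out.

```python
from typing import Any, Dict, List, Sequence


class ResumableBatches:
    """Hypothetical wrapper that remembers how far into a map-style
    sequence of batches training has progressed."""

    def __init__(self, batches: Sequence[Any]) -> None:
        if len(batches) == 0:
            raise ValueError("need at least one batch")
        self.batches = batches
        self.position = 0  # index of the next batch to serve
        self.epoch = 0

    def take(self, n: int) -> List[Any]:
        """Return the next n batches, starting a new epoch at the end."""
        out: List[Any] = []
        for _ in range(n):
            if self.position >= len(self.batches):
                self.position = 0
                self.epoch += 1
            # Direct indexing: batches before `position` are never fetched again.
            out.append(self.batches[self.position])
            self.position += 1
        return out

    def state_dict(self) -> Dict[str, int]:
        """Position to persist between fit calls (or in a checkpoint)."""
        return {"position": self.position, "epoch": self.epoch}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.position = state["position"]
        self.epoch = state["epoch"]


loader = ResumableBatches(range(25))
print(loader.take(10))       # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(loader.take(10))       # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
print(loader.take(10))       # [20, 21, 22, 23, 24, 0, 1, 2, 3, 4]
print(loader.state_dict())   # {'position': 5, 'epoch': 1}
```

Storing only an index works here because the order is deterministic; with a shuffled sampler, the per-epoch permutation (or its seed) would also have to be saved for the resumed run to see the same remaining batches.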

cc @Borda @carmocca @justusschock @awaelchli @ninginthecloud
