🚀 Feature
Modify the Trainer API or add a new API to support multi-stage/phase training for continual learning, multitask learning, and transfer learning.
Motivation
I believe the current assumption in PL is that we have one training dataset, and the fit() method should be called once.
This corresponds to the isolated (single-stage) learning paradigm, while a more general case is sequential (multi-stage) learning, which assumes we have multiple tasks to learn.
An example is continual learning. Assume we have 3 tasks, each learning one dataset (D_1, D_2, D_3). Note that we can't simply concatenate the datasets, because we want to measure performance at the end of each task. Also, each task might use a different number of epochs, a different optimizer, etc.
In PyTorch, we can write a function to do this and call it multiple times.
def train_model_on_dataset(model, data_loader, num_epochs):
    # standard training loop (Model, D_1..D_3, and loss_fn are defined elsewhere)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(num_epochs):
        for x, y in data_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

model = Model()
# train on tasks 1, 2, 3 sequentially
train_model_on_dataset(model, D_1, 5)
train_model_on_dataset(model, D_2, 5)
train_model_on_dataset(model, D_3, 5)

So ideally, .fit() should correspond to train_model_on_dataset(), but fit() does not take the number of epochs; that is configured on the Trainer. To do this in PL, we have to create a new Trainer for each task:
trainer = Trainer(max_epochs=5)
trainer.fit(model1, D_1)

trainer = Trainer(max_epochs=5)
trainer.fit(model1, D_2)
...

But again, the second trainer is a completely different object, and its step/epoch counters are reset. For example, when logging metrics, the epoch starts from 0 for D_2, while it should actually be 5.
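For illustration, here is roughly what that bookkeeping looks like today in user code. This is only a minimal sketch: the LitModel, the global_epoch_offset attribute, and the toy data are made up here, and the point is just that carrying a global epoch counter across two fit() calls has to be done by hand.

import torch
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.global_epoch_offset = 0  # bumped by hand between stages

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # log against a hand-maintained global epoch instead of the trainer's own counter
        self.log("global_epoch", float(self.current_epoch + self.global_epoch_offset))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

def make_loader(seed):
    torch.manual_seed(seed)
    return DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)

D_1, D_2 = make_loader(1), make_loader(2)
model = LitModel()

# stage 0
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, D_1)

# stage 1: a fresh Trainer resets current_epoch/global_step to 0,
# so the offset has to be carried over manually
model.global_epoch_offset += 5
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, D_2)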
Pitch
Proposal: Add an orchestrator wrapper for Trainer
Create a MetaTrainer object for multi-stage training. To keep the API consistent with the current functionality, we keep a stage (phase/time) variable; everything the current trainer does is assumed to happen in stage = 0 (single task / single stage / single phase).
class MetaTrainer:
    def __init__(self, trainer: Trainer, ...):
        self.stage = 0
        self.trainers = [trainer]
        # additional book-keeping variables

    def fit(self, ...):
        self.trainers[self.stage].fit(...)

    # go to the next stage/phase
    def tick(self):
        self.stage += 1
        next_trainer = Trainer()  # we can also [optionally] accept a trainer here for more flexibility
        self.trainers.append(next_trainer)

    # additionally, the MetaTrainer can support special callbacks at each stage's start/end
    def on_stage_starts(self):
        ...

    def on_stage_ends(self):
        ...

    # ... re-implement other methods of Trainer (e.g., logging, checkpoint saving) to support the concept of a stage ...

Applications
Transfer Learning
stage_callbacks = {1: [FreezeExceptClassifier()]}  # a sketch of such a callback is given after these examples

# training
meta_trainer = MetaTrainer(trainer, stage_callbacks)
meta_trainer.fit(loader_train)

# fine-tuning
meta_trainer.tick()
meta_trainer.fit(loader_finetune)

Continual (Lifelong) and Multitask Learning
meta_trainer = MetaTrainer(trainer)

# task 1
meta_trainer.fit(loader_1)

# task 2
meta_trainer.tick()
meta_trainer.fit(loader_2)

The difference here is that the meta_trainer object has control over the whole training session and is aware of the current training stage.
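As a rough illustration of the transfer-learning example above, the hypothetical FreezeExceptClassifier stage callback could be an ordinary Lightning Callback that freezes everything except the classifier head; the attribute name "classifier" is an assumption of this sketch, not part of the proposal.

import pytorch_lightning as pl

class FreezeExceptClassifier(pl.Callback):
    """Hypothetical stage callback: freeze all parameters except the classifier head."""

    def on_train_start(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            # assumes the fine-tuning head is an attribute named `classifier`
            param.requires_grad = name.startswith("classifier")

Passed through stage_callbacks = {1: [...]}, the MetaTrainer would attach it only to the Trainer it creates for stage 1, so stage 0 trains the full model and stage 1 fine-tunes just the head.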
Discussion
I may be wrong, and perhaps it should be left to the user to implement a MetaTrainer like this themselves. Maybe there is even a better solution for these situations that I'm not aware of. Anyway, I just wanted to share my ideas here.