
Continual/Multitask/Transfer Learning in PyTorch Lightning #5314

@imirzadeh

🚀 Feature

Modify the Trainer API or add a new API to support multi-stage/phase training for continual learning, multitask learning, and transfer learning.

Motivation

I believe the current assumption in PL is that we have one training dataset, and the fit() method should be called once.

This corresponds to the isolated (single-stage) learning paradigm, while a more general case is sequential (multi-stage) learning, which assumes we have multiple tasks to learn.

An example is continual learning. Assume we have 3 tasks, each of which is to learn a dataset (D_1, D_2, D_3). Note that we can't simply concatenate the datasets, because we want to measure performance after each task ends. Also, each task might use a different number of epochs, a different optimizer, etc.

In PyTorch, we can write a function to do this and call it multiple times.

def train_model_on_dataset(model, data_loader, num_epochs):
    # standard training loop: forward pass, loss, backward pass, optimizer.step(), ...
    ...

model = Model()
# train on tasks 1, 2, 3 sequentially
train_model_on_dataset(model, D_1, 5)
train_model_on_dataset(model, D_2, 5)
train_model_on_dataset(model, D_3, 5)

Ideally, .fit() would correspond to train_model_on_dataset(), but fit() cannot take the number of epochs; that is configured on the Trainer. So to do this in PL, we have to create new trainers:

trainer = Trainer(max_epochs=5)
trainer.fit(model, D_1)
trainer = Trainer(max_epochs=5)   # a brand-new Trainer for the next task
trainer.fit(model, D_2)
...

But the second trainer is a completely different object, and its step/epoch counters are reset. For example, when logging metrics, the epoch starts from 0 for D_2 when it should really be 5.
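
Concretely, continuing the counters across tasks then becomes the user's problem. A minimal sketch of that manual bookkeeping (the epoch_offset / global_step_offset variables are purely user-side state that nothing in Lightning consumes; model and D_1..D_3 are the placeholders from above):

from pytorch_lightning import Trainer

epoch_offset = 0
global_step_offset = 0
for loader in (D_1, D_2, D_3):
    trainer = Trainer(max_epochs=5)            # fresh object, counters start at 0
    trainer.fit(model, loader)
    # carry the counters forward by hand if we want continuous numbering in logs
    epoch_offset += trainer.current_epoch
    global_step_offset += trainer.global_step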

Pitch

Proposal: Add an orchestrator wrapper for Trainer

Create a MetaTrainer object for multi-stage training. To keep the API consistent with the current functionality, we keep a stage (phase/time) variable; everything the current Trainer does is assumed to happen at stage = 0 (single task / single stage / single phase).

class MetaTrainer:
    def __init__(self, trainer: Trainer, ...):
        self.stage = 0
        self.trainers = [trainer]
        # additional book-keeping variables (e.g., per-stage epoch/step offsets)

    def fit(self, ...):
        self.trainers[self.stage].fit(...)

    # go to the next stage/phase
    def tick(self, next_trainer=None):
        self.stage += 1
        # we can also [optionally] pass in a pre-configured Trainer for more flexibility
        self.trainers.append(next_trainer or Trainer())

    # additionally, the MetaTrainer can support special callbacks at each stage's start/end
    def on_stage_starts(self):
        ...

    def on_stage_ends(self):
        ...

    # ... re-implement other methods of Trainer (e.g., logging, checkpoint saving)
    # so they are aware of the current stage ...

Applications

Transfer Learning

stage_callbacks = {1: [FreezeExceptClassifier()]}
# training
meta_trainer = MetaTrainer(trainer, stage_callbacks)
meta_trainer.fit(loader_train)

# fine-tuning
meta_trainer.tick()
meta_trainer.fit(loader_finetune)
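
FreezeExceptClassifier is not an existing Lightning callback; it is the hypothetical stage callback named above. A hedged sketch of one possible shape, assuming the MetaTrainer calls on_stage_start(model) / on_stage_end(model) on its stage callbacks and that the model exposes a classifier head:

class FreezeExceptClassifier:
    # hypothetical stage callback: freeze the backbone, keep the classifier trainable

    def on_stage_start(self, model):
        for param in model.parameters():
            param.requires_grad = False
        for param in model.classifier.parameters():   # assumes a `classifier` attribute
            param.requires_grad = True

    def on_stage_end(self, model):
        # unfreeze everything once the fine-tuning stage ends
        for param in model.parameters():
            param.requires_grad = True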

Continual (Lifelong) and Multitask Learning

meta_trainer = MetaTrainer(trainer)
# task 1
meta_trainer.fit(loader_1)
# task 2
meta_trainer.tick()
meta_trainer.fit(loader_2)

The difference here is that the meta_trainer object has control over the whole training session and is aware of the current training stage.
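
To connect this back to the motivation (measuring performance after each task ends), here is a hedged sketch of how forgetting could be tracked with the hypothetical MetaTrainer. It follows the calling convention of the snippets above, uses the existing Trainer.test() entry point for evaluation, and treats model, loader_k, and test_k as placeholders; reaching into meta_trainer.trainers[-1] is just one way to get at the current stage's underlying Trainer:

meta_trainer = MetaTrainer(trainer)
seen_test_loaders = []

tasks = [(loader_1, test_1), (loader_2, test_2), (loader_3, test_3)]
for task_id, (train_loader, test_loader) in enumerate(tasks):
    if task_id > 0:
        meta_trainer.tick()             # move to the next stage
    meta_trainer.fit(train_loader)      # train on the current task
    seen_test_loaders.append(test_loader)
    # after finishing each task, evaluate on every task seen so far
    for seen_loader in seen_test_loaders:
        meta_trainer.trainers[-1].test(model, seen_loader)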

Discussion

I may be wrong, and perhaps a MetaTrainer like this is something users should implement themselves. Maybe there's even a better solution for these situations that I'm not aware of. Anyway, I just wanted to share my ideas here.

Additional context

Related issues:
#2758
#2006
