Warmstart with FSDP2 and weight tying fails #381

@flxst

System Info

modalities version: bd649de (main branch after merging #379)
platform: DGX / Linux
python version: 3.11.11

🐛 Describe the bug

The warmstart tutorial fails when run in either of the following ways:

# alternative 1
cd tutorial/warmstart/scripts
bash pre_train_and_warmstart.sh 6 7

# alternative 2
python tests/tests.py -e -d 6,7

The error message is:

KeyError: 'state.transformer.lm_head.weight.step'
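
To narrow down which entries are absent, one could compare the flattened optimizer-state keys against the keys expected per parameter. The following is only a minimal sketch: the helper `missing_optimizer_state_keys` is hypothetical (not part of modalities), the `state.<parameter name>.<field>` layout is inferred from the error message above, and the `transformer.wte.weight` name and the example keys are made up for illustration.

```python
def missing_optimizer_state_keys(param_names, flat_keys, fields=("step",)):
    """Return expected 'state.<param>.<field>' keys that are absent from flat_keys.

    Hypothetical helper; the key layout is an assumption based on the KeyError.
    """
    available = set(flat_keys)
    missing = []
    for name in param_names:
        for field in fields:
            key = f"state.{name}.{field}"
            if key not in available:
                missing.append(key)
    return missing


# Made-up example data: only one of the two parameter names has optimizer state.
param_names = ["transformer.wte.weight", "transformer.lm_head.weight"]
flat_keys = [
    "state.transformer.wte.weight.step",
    "state.transformer.wte.weight.exp_avg",
]

missing = missing_optimizer_state_keys(param_names, flat_keys)
print(missing)
```

With this made-up input, the only missing key is the one from the error message.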

Warmstarting with the configs in config_files works, and so do the unit tests in tests/end2end_tests/test_fsdp_warmstart.py. All of these use weight_tying = False with either FSDP1 or FSDP2. The warmstart tutorial, in contrast, uses weight_tying = True with FSDP2, and it works once weight_tying is set to False. Warmstarting with weight_tying = True and FSDP1 also works. The problem therefore seems to be specific to the combination of FSDP2 and weight tying.

Overview:

  • FSDP1 & weight_tying = False: works
  • FSDP1 & weight_tying = True: works
  • FSDP2 & weight_tying = False: works
  • FSDP2 & weight_tying = True: fails
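
One possible explanation, which is not confirmed here, is that tied weights appear under two names in the model state dict but only once among the optimizer's parameters. The sketch below shows this in plain PyTorch without FSDP. `TinyLM` is a toy stand-in and not the modalities model, and the attribute name `wte` is an assumption.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Toy stand-in for a language model; not the modalities model."""

    def __init__(self, vocab_size=8, dim=4, weight_tying=True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if weight_tying:
            # Both modules now hold the very same Parameter object.
            self.wte.weight = self.lm_head.weight


model = TinyLM(weight_tying=True)

# The state dict lists the shared tensor under both names ...
state_dict_keys = sorted(model.state_dict().keys())
# ... while named_parameters() removes duplicates by default.
param_names = sorted(name for name, _ in model.named_parameters())

# After one step, the optimizer holds state for a single parameter only.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
logits = model.lm_head(model.wte(torch.tensor([1, 2])))
logits.sum().backward()
optimizer.step()
num_state_entries = len(optimizer.state_dict()["state"])

print(state_dict_keys)
print(param_names)
print(num_state_entries)
```

If the checkpoint loading code expects one optimizer state entry per state-dict name, a lookup for `lm_head.weight` could fail in the way described above. Whether this is what happens in the FSDP2 code path still needs to be verified.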

Labels: bug