Added script for updating old checkpoints and configs. #397
base: main
Conversation
Pull Request Overview
This PR adds a utility script for migrating old model checkpoints and configuration files to a new format. The script handles both configuration file updates and checkpoint state dictionary transformations to maintain compatibility with updated model structures.
Key Changes
- Added comprehensive checkpoint and config migration script with YAML processing
- Implemented state dictionary updates for model weight key transformations
- Added validation functionality to test updated configurations
I'm a bit hesitant about whether automated config updates are the way to go, or whether we should instead provide documentation, e.g., a diff for model_raw, explaining how to update the models.
The reason is that we are currently updating based on existing component names, e.g., checkpointed_model. However, the configs themselves never enforce particular component names, so if the user renames checkpointed_model to something like my_checkpointed_model, the conversion script already fails.
Also, we are deleting some components that are still used if you create the diff between these two configs:
Nevertheless, I think the automated checkpoint update is still useful, and I would place it in a backward_compatibility/ module.
old_model_config = sys.argv[1]
new_model_config = sys.argv[2]
These are paths, not "configs".
I would use pathlib.Path
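For illustration, a minimal sketch of the pathlib-based variant (argument names follow the quoted snippet; the existence check is an optional addition, not part of the PR):

```python
import sys
from pathlib import Path

# Treat the CLI arguments as filesystem paths rather than bare strings.
old_model_config_path = Path(sys.argv[1])
new_model_config_path = Path(sys.argv[2])
new_checkpoint_path = Path(sys.argv[3]) if len(sys.argv) > 3 else None

if not old_model_config_path.is_file():
    raise FileNotFoundError(f"Old config not found: {old_model_config_path}")
```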
config_type = dict[str, "str | config_type"]


def update_model(old_model_config: str, new_model_config: str, new_checkpoint_path: str | None):
The arguments are all paths.
def add_new_keys(config: config_type):
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
    model_config["use_weight_tying"] = False
We also had weight tying before. Why are we hardcoding this to False now?
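If the intent is to preserve whatever the old setup used instead of forcing False, one hedged alternative (the legacy key name weight_tying below is purely hypothetical, used only for illustration):

```python
def add_new_keys(config: config_type) -> None:
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
    # Carry over the previous setting if the old config exposed one;
    # "weight_tying" is a hypothetical legacy key, not a confirmed config field.
    model_config["use_weight_tying"] = bool(model_config.pop("weight_tying", False))
```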
if "evaluation_subscriber" in config and "experiment_id" in config["evaluation_subscriber"]["config"]: | ||
del config["evaluation_subscriber"]["config"]["experiment_id"] | ||
if "settings" in config and "experiment_id" in config["settings"]: | ||
del config["settings"]["experiment_id"] | ||
if ( | ||
"checkpoint_saving" in config | ||
and "checkpoint_saving_execution" in config["checkpoint_saving"]["config"] | ||
and "experiment_id" in config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"] | ||
): | ||
del config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"]["experiment_id"] |
Why are we deleting this?
Also, in the current FSDP2 config, e.g. https://github.com/Modalities/modalities/blob/83c87b9d6d6fbbb228bab31dccf1870b12679775/config_files/training/config_lorem_ipsum_long_fsdp2.yaml, we still have all of this.
def rename_keys(config: config_type):
    model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
We could have the convention that the general model must always be named model_raw.
We are already enforcing it here:
model_raw: PydanticPytorchModuleType
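If we go with that convention, the script could normalize the key once up front instead of branching on "model_raw" vs. "model" everywhere; a sketch (normalize_model_key is a hypothetical helper, not existing code):

```python
def normalize_model_key(config: dict) -> None:
    # Enforce the convention that the general model component is named "model_raw".
    if "model_raw" not in config and "model" in config:
        config["model_raw"] = config.pop("model")
```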
new_model_config = sys.argv[2]
new_checkpoint_path = sys.argv[3] if len(sys.argv) > 3 else None

update_model(old_model_config, new_model_config, new_checkpoint_path)
I would make updating the checkpoint and updating the config two separate functions that are called sequentially here.
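A sketch of how that split could look; update_config and update_checkpoint are hypothetical names wrapping the existing logic, and the lm_head remapping shown is the one quoted further below:

```python
import sys

import torch
import yaml


def update_config(old_config_path: str, new_config_path: str) -> dict:
    # Config-only step: load the old YAML, apply the key updates, write the new YAML.
    with open(old_config_path) as f:
        config = yaml.safe_load(f)
    # add_new_keys(config), rename_keys(config), ... would be called here.
    with open(new_config_path, "w") as f:
        yaml.safe_dump(config, f)
    return config


def update_checkpoint(old_checkpoint_path: str, new_checkpoint_path: str) -> None:
    # Checkpoint-only step: remap state-dict keys and save under the new path.
    state_dict = torch.load(old_checkpoint_path, map_location="cpu")
    if "lm_head.weight" in state_dict:
        # Move the weight to its new key (pop also drops the old entry).
        state_dict["transformer.lm_head.weight"] = state_dict.pop("lm_head.weight")
    torch.save(state_dict, new_checkpoint_path)


if __name__ == "__main__":
    config = update_config(sys.argv[1], sys.argv[2])
    if len(sys.argv) > 3 and "checkpointed_model" in config:
        old_checkpoint_path = config["checkpointed_model"]["config"]["checkpoint_path"]
        update_checkpoint(old_checkpoint_path, sys.argv[3])
```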
old_norm_keys = ["attention_norm", "ffn_norm", "lm_head_norm"]
new_norm_keys = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
for old_key, new_key in zip(old_norm_keys, new_norm_keys):
    rename_config_key(model_config, old_key, new_key)
    rename_config_key(model_config[new_key], "variant_key", "norm_type")
We should delete component_key, no?
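If so, a small sketch of the deletion (drop_component_key is a hypothetical helper; it assumes each renamed norm sub-config is a dict that may carry a component_key entry):

```python
def drop_component_key(model_config: dict, norm_keys: list[str]) -> None:
    # Remove the obsolete "component_key" entry from each renamed norm sub-config.
    for key in norm_keys:
        if key in model_config:
            model_config[key].pop("component_key", None)
```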
if new_checkpoint_path is not None:
    if "checkpointed_model" in config:
        old_path = config["checkpointed_model"]["config"]["checkpoint_path"]
        config["checkpointed_model"]["config"]["checkpoint_path"] = new_checkpoint_path
I checked all configs; where did you see checkpointed_model?
""" | ||
state_dict = torch.load(old_model_path) | ||
if "lm_head.weight" in state_dict: | ||
state_dict["transformer.lm_head.weight"] = state_dict["lm_head.weight"] |
How would this behave if we used weight tying?
Do we store them twice (i.e., embeddings and lm_head) and then internally replace the lm_head with a reference to the embedding weights?
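For context, a defensive sketch of the remapping; whether a separate lm_head.weight is even present under weight tying is exactly the open question here, so the code only touches the key if it exists:

```python
def remap_lm_head(state_dict: dict) -> dict:
    # Move a standalone lm_head weight to its new key if the checkpoint has one.
    # Under weight tying the entry may be missing (or duplicate the embedding
    # weights), in which case the model would re-tie the parameters after loading.
    if "lm_head.weight" in state_dict:
        state_dict["transformer.lm_head.weight"] = state_dict.pop("lm_head.weight")
    return state_dict
```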
What does this PR do?
This PR adds a script for updating old checkpoints and configs.
General Changes
Checklist before submitting final PR
- Run the tests (python tests/tests.py)
- Update CHANGELOG_DEV.md