
Conversation

BlueCrescent
Collaborator

What does this PR do?

This PR adds a script for updating old checkpoints and configs.

General Changes

  • Added the script.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@BlueCrescent BlueCrescent requested a review from Copilot August 25, 2025 11:53
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds a utility script for migrating old model checkpoints and configuration files to a new format. The script handles both configuration file updates and checkpoint state dictionary transformations to maintain compatibility with updated model structures.

Key Changes

  • Added comprehensive checkpoint and config migration script with YAML processing
  • Implemented state dictionary updates for model weight key transformations
  • Added validation functionality to test updated configurations


Member

@le1nux le1nux left a comment

I'm a bit hesitant about whether the automated config updates are the way to go, or whether we should instead provide documentation, e.g., a diff for model_raw, explaining how to update the models.
The reason is that we are now updating based on existing component names, e.g., checkpointed_model. However, the configs themselves never enforce particular component names, so if the user renames checkpointed_model to something like my_checkpointed_model, the conversion script already fails.
Also, we are deleting some components that are still in use, as you can see if you diff these two configs:

https://github.com/Modalities/modalities/blob/83c87b9d6d6fbbb228bab31dccf1870b12679775/config_files/training/config_lorem_ipsum_long_fsdp1.yaml

https://github.com/Modalities/modalities/blob/83c87b9d6d6fbbb228bab31dccf1870b12679775/config_files/training/config_lorem_ipsum_long_fsdp2.yaml

Nevertheless, I think that the automated checkpoint update is still useful and I would place it in a backward_compatibility/ module.
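A minimal illustration of the naming concern (the config dict below is hypothetical):

# The script looks the component up by its name, so a user-chosen name breaks it:
config = {"my_checkpointed_model": {"config": {"checkpoint_path": "old.bin"}}}
config["checkpointed_model"]["config"]["checkpoint_path"]  # raises KeyError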

Comment on lines +152 to +153
old_model_config = sys.argv[1]
new_model_config = sys.argv[2]
Member

These are paths, not "configs".

Member

I would use pathlib.Path
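For example (a sketch only; the argparse setup and argument names are assumptions, update_model is the script's existing entry point):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Update old Modalities configs and checkpoints.")
parser.add_argument("old_model_config_path", type=Path)
parser.add_argument("new_model_config_path", type=Path)
parser.add_argument("new_checkpoint_path", type=Path, nargs="?", default=None)
args = parser.parse_args()

update_model(args.old_model_config_path, args.new_model_config_path, args.new_checkpoint_path)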

config_type = dict[str, "str | config_type"]


def update_model(old_model_config: str, new_model_config: str, new_checkpoint_path: str | None):
Member

The arguments are all paths.
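That is, something like this (the renamed parameters are a suggestion, not what the PR currently has):

from pathlib import Path

def update_model(old_model_config_path: Path, new_model_config_path: Path, new_checkpoint_path: Path | None) -> None:
    ...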


def add_new_keys(config: config_type):
model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
model_config["use_weight_tying"] = False
Member

We also had weight tying before. Why are we hardcoding this to False now?

Comment on lines +92 to +101
if "evaluation_subscriber" in config and "experiment_id" in config["evaluation_subscriber"]["config"]:
del config["evaluation_subscriber"]["config"]["experiment_id"]
if "settings" in config and "experiment_id" in config["settings"]:
del config["settings"]["experiment_id"]
if (
"checkpoint_saving" in config
and "checkpoint_saving_execution" in config["checkpoint_saving"]["config"]
and "experiment_id" in config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"]
):
del config["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"]["experiment_id"]
Member

Why are we deleting this?


def rename_keys(config: config_type):
model_config = config["model_raw" if "model_raw" in config else "model"]["config"]
Member

We could have the convention that the general model must always be named model_raw.
We are already enforcing it here:

model_raw: PydanticPytorchModuleType
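If that convention is adopted, the lookup in the script could drop the fallback (a sketch, not part of this PR):

def rename_keys(config: config_type):
    # assumes the general model is always registered under "model_raw"
    model_config = config["model_raw"]["config"]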

new_model_config = sys.argv[2]
new_checkpoint_path = sys.argv[3] if len(sys.argv) > 3 else None

update_model(old_model_config, new_model_config, new_checkpoint_path)
Member

I would make updating the checkpoint and updating the config two separate functions that get called sequentially here.
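A minimal sketch of that split (the function names update_config and update_checkpoint are hypothetical):

import sys
from pathlib import Path

def main() -> None:
    old_config_path = Path(sys.argv[1])
    new_config_path = Path(sys.argv[2])
    new_checkpoint_path = Path(sys.argv[3]) if len(sys.argv) > 3 else None

    # hypothetical helpers: one rewrites the YAML config, the other the checkpoint state dict
    update_config(old_config_path, new_config_path, new_checkpoint_path)
    if new_checkpoint_path is not None:
        update_checkpoint(old_config_path, new_checkpoint_path)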

Comment on lines +63 to +67
old_norm_keys = ["attention_norm", "ffn_norm", "lm_head_norm"]
new_norm_keys = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
for old_key, new_key in zip(old_norm_keys, new_norm_keys):
rename_config_key(model_config, old_key, new_key)
rename_config_key(model_config[new_key], "variant_key", "norm_type")
Member

We should delete component_key, no?
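Something along these lines would cover it (a sketch on top of the snippet above; whether component_key has to go depends on the new nested norm config schema):

for old_key, new_key in zip(old_norm_keys, new_norm_keys):
    rename_config_key(model_config, old_key, new_key)
    rename_config_key(model_config[new_key], "variant_key", "norm_type")
    # drop component_key if the new nested norm config no longer expects it
    model_config[new_key].pop("component_key", None)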

if new_checkpoint_path is not None:
if "checkpointed_model" in config:
old_path = config["checkpointed_model"]["config"]["checkpoint_path"]
config["checkpointed_model"]["config"]["checkpoint_path"] = new_checkpoint_path
Member

I checked all configs; where did you see checkpointed_model?

"""
state_dict = torch.load(old_model_path)
if "lm_head.weight" in state_dict:
state_dict["transformer.lm_head.weight"] = state_dict["lm_head.weight"]
Member

How would this behave if we used weight tying?
Do we store the weights twice (i.e., embeddings and lm_head) and then internally replace the lm_head with a reference to the embedding weights?
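One way to check this on a concrete checkpoint (a sketch; the embedding key name and the checkpoint path are assumptions):

import torch

state_dict = torch.load("old_checkpoint.bin", map_location="cpu")
lm_head = state_dict.get("lm_head.weight")
embedding = state_dict.get("transformer.wte.weight")  # assumed embedding key name
if lm_head is not None and embedding is not None:
    # tied weights saved from the same tensor still share storage after loading
    print("tied in checkpoint:", lm_head.data_ptr() == embedding.data_ptr())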
