From c952b018c1ed27c4fc932bc0144b3ae564981254 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 Apr 2024 11:32:04 +0200 Subject: [PATCH 1/9] Add options to freeze the representation model and reset the output weights when loading an already trained model for training. --- torchmdnet/module.py | 5 +++++ torchmdnet/scripts/train.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 108a1915e..a34ebf2cd 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -70,6 +70,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): if self.hparams.load_model: self.model = load_model(self.hparams.load_model, args=self.hparams) + if self.hparams.freeze_representation: + for p in self.model.representation_model.parameters(): + p.requires_grad = False + if self.hparams.reset_output_model: + self.model.output_model.reset_parameters() else: self.model = create_model(self.hparams, prior_model, mean, std) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 2e69212b4..dd53c9a57 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,7 +121,8 @@ def get_argparse(): parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - + parser.add_argument('--freeze_representation', type=bool, default=False, help='Freeze the representation model parameters during training. This option is only used if the training is not starting from scratch.') + parser.add_argument('--reset_output_model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. 
This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.') # fmt: on return parser From 160e3c4ac54419b3fbc5d10bf0ef8a92ce11d31a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 Apr 2024 11:46:47 +0200 Subject: [PATCH 2/9] Add the overwrite-representation option, which allows overwriting the representation model weights using a checkpoint. --- torchmdnet/module.py | 13 ++++++++++--- torchmdnet/scripts/train.py | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index a34ebf2cd..2cc7293d8 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -70,14 +70,21 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): if self.hparams.load_model: self.model = load_model(self.hparams.load_model, args=self.hparams) - if self.hparams.freeze_representation: - for p in self.model.representation_model.parameters(): - p.requires_grad = False if self.hparams.reset_output_model: self.model.output_model.reset_parameters() else: self.model = create_model(self.hparams, prior_model, mean, std) + if self.hparams.overwrite_representation: + ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") + self.model.representation_model.load_state_dict( + ckpt["representation_model"] + ) + + if self.hparams.freeze_representation: + for p in self.model.representation_model.parameters(): + p.requires_grad = False + # initialize exponential smoothing self.ema = None self._reset_ema_dict() diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index dd53c9a57..43759a461 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,8 +121,9 @@ def get_argparse(): parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. 
The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - parser.add_argument('--freeze_representation', type=bool, default=False, help='Freeze the representation model parameters during training. This option is only used if the training is not starting from scratch.') + parser.add_argument('--freeze_representation', type=bool, default=False, help='Freeze the representation model parameters during training.') parser.add_argument('--reset_output_model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.') + parser.add_argument('--overwrite_representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') # fmt: on return parser From f54adb33717e7044e0a4e424d7b29af96ef237a0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 Apr 2024 11:53:34 +0200 Subject: [PATCH 3/9] Fix typo --- torchmdnet/scripts/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 43759a461..f6253cbbe 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,9 +121,9 @@ def get_argparse(): parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. 
The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - parser.add_argument('--freeze_representation', type=bool, default=False, help='Freeze the representation model parameters during training.') - parser.add_argument('--reset_output_model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.') - parser.add_argument('--overwrite_representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') + parser.add_argument('--freeze-representation', type=bool, default=False, help='Freeze the representation model parameters during training.') + parser.add_argument('--reset-output-model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. 
This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.') + parser.add_argument('--overwrite-representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') # fmt: on return parser From 142ea3eb3678fe67fd6d108d89fec4d2da6afde9 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 Apr 2024 12:09:31 +0200 Subject: [PATCH 4/9] Use the hparams argument instead of self.hparams --- torchmdnet/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 2cc7293d8..79f379493 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -75,13 +75,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): else: self.model = create_model(self.hparams, prior_model, mean, std) - if self.hparams.overwrite_representation: - ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") + if hparams["overwrite_representation"] is not None: + ckpt = torch.load(hparams["overwrite_representation"], map_location="cpu") self.model.representation_model.load_state_dict( ckpt["representation_model"] ) - if self.hparams.freeze_representation: + if hparams["freeze_representation"]: for p in self.model.representation_model.parameters(): p.requires_grad = False From 97b556b58024d3dfdc316238f93dddc2d77527d9 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 Apr 2024 12:27:03 +0200 Subject: [PATCH 5/9] Fix: set defaults for the new hparams when they are missing --- torchmdnet/module.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 79f379493..434973a8c 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -65,6 +65,12 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False + if "overwrite_representation" not in 
hparams: + hparams["overwrite_representation"] = None + if "freeze_representation" not in hparams: + hparams["freeze_representation"] = False + if "reset_output_model" not in hparams: + hparams["reset_output_model"] = False self.save_hyperparameters(hparams) From 96c29202cf61fad073978ec795ef102769e34319 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 21 May 2024 16:25:14 +0200 Subject: [PATCH 6/9] Fix hparams --- torchmdnet/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 434973a8c..de312893b 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -81,13 +81,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): else: self.model = create_model(self.hparams, prior_model, mean, std) - if hparams["overwrite_representation"] is not None: - ckpt = torch.load(hparams["overwrite_representation"], map_location="cpu") + if self.hparams.overwrite_representation is not None: + ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") self.model.representation_model.load_state_dict( ckpt["representation_model"] ) - if hparams["freeze_representation"]: + if self.hparams.freeze_representation: for p in self.model.representation_model.parameters(): p.requires_grad = False From 9f9d267445396c8d4a7b30a70084a10f278abf2d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 21 May 2024 16:28:43 +0200 Subject: [PATCH 7/9] Extract representation model --- torchmdnet/module.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index de312893b..e40fafb86 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -48,6 +48,16 @@ def forward(self, data): return data +def extract_representation_model(state_dict): + representation_model = {} + prefix = "model.representation_model." 
+ for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix) :] + representation_model[new_key] = value + return representation_model + + class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. @@ -83,9 +93,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): if self.hparams.overwrite_representation is not None: ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") - self.model.representation_model.load_state_dict( - ckpt["representation_model"] - ) + state_dict = extract_representation_model(ckpt["state_dict"]) + self.model.representation_model.load_state_dict(state_dict) if self.hparams.freeze_representation: for p in self.model.representation_model.parameters(): From 5c196ef9a2004114c0be74eba76b3f801baccac9 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 21 May 2024 16:30:04 +0200 Subject: [PATCH 8/9] store_true --- torchmdnet/scripts/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index f6253cbbe..fc40d97a4 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,8 +121,8 @@ def get_argparse(): parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - parser.add_argument('--freeze-representation', type=bool, default=False, help='Freeze the representation model parameters during training.') - parser.add_argument('--reset-output-model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. 
This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.') + parser.add_argument('--freeze-representation', type=bool, default=False, help='Freeze the representation model parameters during training.', action='store_true') + parser.add_argument('--reset-output-model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.', action='store_true') parser.add_argument('--overwrite-representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') # fmt: on return parser From c65007149bf0c1f425bb16b372634d5e906bb357 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 21 May 2024 16:30:48 +0200 Subject: [PATCH 9/9] store_true --- torchmdnet/scripts/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index fc40d97a4..70d8d3ba8 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,8 +121,8 @@ def get_argparse(): parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. 
The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - parser.add_argument('--freeze-representation', type=bool, default=False, help='Freeze the representation model parameters during training.', action='store_true') - parser.add_argument('--reset-output-model', type=bool, default=False, help='Reset the parameters (randomize) of the output models before stating a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.', action='store_true') + parser.add_argument('--freeze-representation', help='Freeze the representation model parameters during training.', action='store_true') + parser.add_argument('--reset-output-model', help='Reset the parameters (randomize) of the output models before stating a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.', action='store_true') parser.add_argument('--overwrite-representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') # fmt: on return parser