From a3d2158764371904957744a4d1ffc1160b2771eb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 19 Jul 2023 10:34:37 +0200 Subject: [PATCH 1/3] data save hparams --- torchmdnet/scripts/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 6b120df86..6d652c66e 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -181,6 +181,7 @@ def main(): # run test set after completing the fit model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + data.save_hyperparameters(model.hparams) trainer = pl.Trainer(logger=_logger) trainer.test(model, data) From cccb9e6cc99fe0f50a8fa612a354a45d2adf1cef Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 19 Jul 2023 11:11:32 +0200 Subject: [PATCH 2/3] correct method --- torchmdnet/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 6d652c66e..8f14febc6 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -181,7 +181,7 @@ def main(): # run test set after completing the fit model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) - data.save_hyperparameters(model.hparams) + data.hparams.update(model.hparams) trainer = pl.Trainer(logger=_logger) trainer.test(model, data) From 684a8f089984de009b9869751b97a79c5b7db5c3 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 23 Aug 2023 17:16:26 +0200 Subject: [PATCH 3/3] new arg to load_ckpt method --- torchmdnet/scripts/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index cd7874a38..f6b63c9b6 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -16,6 +16,7 @@ from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number import torch + def get_args(): # fmt: off parser = argparse.ArgumentParser(description='Training') @@ -179,9 +180,12 @@ def main(): trainer.fit(model, data) # run test set after completing the fit - model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) - data.hparams.update(model.hparams) + model = LNNP.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path, + hparams_file=f"{args.log_dir}/input.yaml", + ) trainer = pl.Trainer(logger=_logger) + trainer.test(model, data)