From 43d8f9f10f9da96d7c4fd7f1bab0b8a603b9fcf7 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:42:11 +0100 Subject: [PATCH 01/82] Update tensornet.py --- torchmdnet/models/tensornet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e7..2b830ce19 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -239,7 +239,8 @@ def forward( # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q if q is None: q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - else: + # if not atom-wise, make atom-wise (pq is already atom-wise) + if z.shape != q.shape: q = q[batch] zp = z if self.static_shapes: From 9ccc2c0bc98c38d0f641b7e785a03d8ef8c53502 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:44:00 +0100 Subject: [PATCH 02/82] Update model.py --- torchmdnet/models/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f901..a8788dc64 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -378,6 +378,10 @@ def forward( assert z.dim() == 1 and z.dtype == torch.long batch = torch.zeros_like(z) if batch is None else batch + # trick to incorporate SPICE pqs + # set charge: true in yaml + q = extra_args["pq"] + if self.derivative: pos.requires_grad_(True) # run the potentially wrapped representation model From 25854fb2ce03888108277669e32e2768022514a3 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:46:44 +0100 Subject: [PATCH 03/82] Update model.py --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a8788dc64..015d629f7 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -379,7 +379,7 @@ def forward( batch = torch.zeros_like(z) if batch is None else batch # trick to incorporate SPICE pqs - # set charge: true in yaml + # set charge: true in yaml ((?) currently I do it) q = extra_args["pq"] if self.derivative: From cc6de7a1bcc9c73767fbddb21995d6e599ff2730 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:09:36 +0100 Subject: [PATCH 04/82] move to extra_fields implementation --- torchmdnet/models/model.py | 55 +++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 015d629f7..67c75bf42 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import re -from typing import Optional, List, Tuple, Dict +from typing import Optional, List, Tuple, Dict, Any import torch from torch.autograd import grad from torch import nn, Tensor @@ -38,7 +38,18 @@ def create_model(args, prior_model=None, mean=None, std=None): args["static_shapes"] = False if "vector_cutoff" not in args: args["vector_cutoff"] = False - + + # Here we introduce the extra_fields_args, which is Dict[str, Any] + # These could be used from each model to initialize nn.embedding layers, nn.Parameter, etc. 
+ if "extra_fields" not in args: + extra_fields = None + elif isinstance(args["extra_fields"], str): + extra_fields = {args["extra_fields"]: None} + elif isinstance(args["extra_fields"], list): + extra_fields = {label: None for label in args["extra_fields"]} + else: + extra_fields = args["extra_fields"] + shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -57,6 +68,7 @@ def create_model(args, prior_model=None, mean=None, std=None): else None ), dtype=dtype, + extra_fields=extra_fields, ) # representation network @@ -263,8 +275,8 @@ class TorchMD_Net(nn.Module): Parameters ---------- representation_model : nn.Module - A model that takes as input the atomic numbers, positions, batch indices, and optionally - charges and spins. It must return a tuple of the form (x, v, z, pos, batch), where x + A model that takes as input the atomic numbers, positions, batch indices. + It must return a tuple of the form (x, v, z, pos, batch), where x are the atom features, v are the vector features (if any), z are the atomic numbers, pos are the positions, and batch are the batch indices. See TorchMD_ET for more details. output_model : nn.Module @@ -336,9 +348,8 @@ def forward( pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Tensor]] = None, + extra_args: Optional[Dict[str, Optional[Tensor]]] = None, + extra_fields: Optional[Dict[str, Any]] = None, ) -> Tuple[Tensor, Tensor]: """ Compute the output of the model. @@ -368,9 +379,8 @@ def forward( The vectors defining the periodic box. This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. If this is omitted, periodic boundary conditions are not applied. - q (Tensor, optional): Atomic charges in the molecule. Shape: (N,). - s (Tensor, optional): Atomic spins in the molecule. Shape: (N,). 
- extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model. + extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model or to the representation model. + extra_fields (Dict[str, Tensor], optional): Extra fields to pass to the representation model. Returns: Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise. @@ -378,15 +388,30 @@ def forward( assert z.dim() == 1 and z.dtype == torch.long batch = torch.zeros_like(z) if batch is None else batch - # trick to incorporate SPICE pqs - # set charge: true in yaml ((?) currently I do it) - q = extra_args["pq"] - if self.derivative: pos.requires_grad_(True) + + # recover the extra_fields_args from the extra_fields + if self.representation_model.extra_fields is None: + extra_fields_args = None + else: + assert extra_args is not None, "Extra fields are required but not provided." + extra_fields_args = {} + for field in extra_fields.keys(): + t = extra_args[field] + if t.shape != z.shape: + # expand molecular label to atom labels + t = t[batch] + extra_fields_args[field] = t + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( - z, pos, batch, box=box, q=q, s=s + z, + pos, + batch, + box=box, + extra_args=extra_args, + extra_fields_args=extra_fields_args, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From 9478a927d82c06067bc0d3dfcf158705e2388a6f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:10:24 +0100 Subject: [PATCH 05/82] remove charge and spin, this will go to extra_fields --- torchmdnet/module.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 108a1915e..da8991718 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -8,7 +8,7 @@ from torch.optim.lr_scheduler 
import ReduceLROnPlateau from torch.nn.functional import local_response_norm, mse_loss, l1_loss from torch import Tensor -from typing import Optional, Dict, Tuple +from typing import Optional, Dict, Tuple, Any from lightning import LightningModule from torchmdnet.models.model import create_model, load_model @@ -61,11 +61,6 @@ class LNNP(LightningModule): def __init__(self, hparams, prior_model=None, mean=None, std=None): super(LNNP, self).__init__() - if "charge" not in hparams: - hparams["charge"] = False - if "spin" not in hparams: - hparams["spin"] = False - self.save_hyperparameters(hparams) if self.hparams.load_model: @@ -119,11 +114,9 @@ def forward( pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor]]: - return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) + return self.model(z, pos, batch=batch, box=box, extra_args=extra_args) def training_step(self, batch, batch_idx): return self.step(batch, [mse_loss], "train") @@ -199,7 +192,7 @@ def step(self, batch, loss_fn_list, stage): batch = self.data_transform(batch) with torch.set_grad_enabled(stage == "train" or self.hparams.derivative): extra_args = batch.to_dict() - for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"): + for a in ("y", "neg_dy", "z", "pos", "batch", "box"): if a in extra_args: del extra_args[a] # TODO: the model doesn't necessarily need to return a derivative once @@ -209,8 +202,6 @@ def step(self, batch, loss_fn_list, stage): batch.pos, batch=batch.batch, box=batch.box if "box" in batch else None, - q=batch.q if self.hparams.charge else None, - s=batch.s if self.hparams.spin else None, extra_args=extra_args, ) if self.hparams.derivative and "y" not in batch: From c7014e7c20676a50fd69bfd3c7038860f3d8ba31 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:11:37 +0100 
Subject: [PATCH 06/82] force hdf5 dataset to common-unique data structure --- torchmdnet/datasets/hdf.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index 3d817a50c..cd5d19f5b 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -57,6 +57,15 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs): self.fields.append( ("partial_charges", "partial_charges", torch.float32) ) + # total charge and spin, will be load as 'q' and 's' respectively to keep the same naming convention + if "total_charge" in group: + self.fields.append( + ("total_charge", "total_charge", torch.float32) + ) + if "spin" in group: + self.fields.append( + ("spin", "spin", torch.float32) + ) assert ("energy" in group) or ( "forces" in group ), "Each group must contain at least energies or forces" From 1f4928cf7e53c231bfbebeabdf6c273eaf6098a1 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:12:54 +0100 Subject: [PATCH 07/82] force ace dataset to common-unique data structure --- torchmdnet/datasets/ace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e8c59785d..4f1f8335c 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -273,7 +273,7 @@ def sample_iter(self, mol_ids=False): # Create a sample args = dict( - z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp + z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=q, partial_charges=pq, dipole_moment=dp ) if mol_ids: args["mol_id"] = mol_id From 7f238dc516e22f8d1235863ca073188b2be347af Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:13:51 +0100 Subject: [PATCH 08/82] remove charge and spin flag from train.py --- torchmdnet/scripts/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index a51cfe45f..ef05f4a10 
100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -77,8 +77,6 @@ def get_argparse(): parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*") # architectural args - parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.') - parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model') From 3250fc4fceef360b072bf1c7c853c548840688d8 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:15:23 +0100 Subject: [PATCH 09/82] add extra_fields to argparse --- torchmdnet/scripts/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index ef05f4a10..224f8c14f 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -60,6 +60,7 @@ def get_argparse(): parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm') parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. 
The dataset must be compatible with Atomref for this to be used.') # dataset specific + parser.add_argument('--extra-fields', default=None, help='Extra fields of the dataset to pass to the model, it could be a list of fields or a dictionary with the field name and addtionals arguments') parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') parser.add_argument('--dataset-arg', default=None, help='Additional dataset arguments. Needs to be a dictionary.') From 736af5410c0ba1c1f793b5056a135db0233dcf4f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:29:51 +0100 Subject: [PATCH 10/82] memdataset to common-unique data-structure --- torchmdnet/datasets/memdataset.py | 90 +++++++++++++++---------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index a97b54be8..bb3576ed6 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -17,9 +17,9 @@ class MemmappedDataset(Dataset): - :obj:`pos`: Positions of the atoms. - :obj:`y`: Energy of the conformation. - :obj:`neg_dy`: Forces on the atoms. - - :obj:`q`: Total charge of the conformation. - - :obj:`pq`: Partial charges of the atoms. - - :obj:`dp`: Dipole moment of the conformation. + - :obj:`total_charge`: Total charge of the conformation. + - :obj:`partial_charges`: Partial charges of the atoms. + - :obj:`dipole_moment`: Dipole moment of the conformation. The data is stored in the following files: @@ -28,9 +28,9 @@ class MemmappedDataset(Dataset): - :obj:`name.pos.mmap`: Positions of all the atoms. - :obj:`name.y.mmap`: Energy of each conformation. - :obj:`name.neg_dy.mmap`: Forces on all the atoms. - - :obj:`name.q.mmap`: Total charge of each conformation. 
- - :obj:`name.pq.mmap`: Partial charges of all the atoms. - - :obj:`name.dp.mmap`: Dipole moment of each conformation. + - :obj:`name.total_charge.mmap`: Total charge of each conformation. + - :obj:`name.partial_charges.mmap`: Partial charges of all the atoms. + - :obj:`name.dipole_moment.mmap`: Dipole moment of each conformation. Args: root (str): Root directory where the dataset should be stored. @@ -45,8 +45,8 @@ class MemmappedDataset(Dataset): indicating whether the data object should be included in the final dataset. properties (tuple of str, optional): The properties to include in the - dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`q`, - :obj:`pq`, and :obj:`dp`. + dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`total_charge`, + :obj:`partial_charges`, and :obj:`dipole_moment`. """ def __init__( @@ -55,7 +55,7 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ): self.name = self.__class__.__name__ self.properties = properties @@ -76,13 +76,13 @@ def __init__( self.neg_dy_mm = np.memmap( fnames["neg_dy"], mode="r", dtype=np.float32, shape=(num_all_atoms, 3) ) - if "q" in self.properties: - self.q_mm = np.memmap(fnames["q"], mode="r", dtype=np.int8) - if "pq" in self.properties: - self.pq_mm = np.memmap(fnames["pq"], mode="r", dtype=np.float32) - if "dp" in self.properties: + if "total_charge" in self.properties: + self.q_mm = np.memmap(fnames["total_charge"], mode="r", dtype=np.int8) + if "partial_charges" in self.properties: + self.pq_mm = np.memmap(fnames["partial_charges"], mode="r", dtype=np.float32) + if "dipole_moment" in self.properties: self.dp_mm = np.memmap( - fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3) + fnames["dipole_moment"], mode="r", dtype=np.float32, shape=(num_all_confs, 3) ) assert self.idx_mm[0] == 0 @@ -151,20 +151,20 @@ def process(self): 
dtype=np.float32, shape=(num_all_atoms, 3), ) - if "q" in self.properties: + if "total_charge" in self.properties: q_mm = np.memmap( - fnames["q"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs + fnames["total_charge"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs ) - if "pq" in self.properties: + if "partial_charges" in self.properties: pq_mm = np.memmap( - fnames["pq"] + ".tmp", + fnames["partial_charges"] + ".tmp", mode="w+", dtype=np.float32, shape=num_all_atoms, ) - if "dp" in self.properties: + if "dipole_moment" in self.properties: dp_mm = np.memmap( - fnames["dp"] + ".tmp", + fnames["dipole_moment"] + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_confs, 3), @@ -182,12 +182,12 @@ def process(self): y_mm[i_conf] = data.y if "neg_dy" in self.properties: neg_dy_mm[i_atom:i_next_atom] = data.neg_dy - if "q" in self.properties: - q_mm[i_conf] = data.q.to(pt.int8) - if "pq" in self.properties: - pq_mm[i_atom:i_next_atom] = data.pq - if "dp" in self.properties: - dp_mm[i_conf] = data.dp + if "total_charge" in self.properties: + q_mm[i_conf] = data.total_charge.to(pt.int8) + if "partial_charges" in self.properties: + pq_mm[i_atom:i_next_atom] = data.partial_charges + if "dipole_moment" in self.properties: + dp_mm[i_conf] = data.dipole_moment i_atom = i_next_atom idx_mm[-1] = num_all_atoms @@ -200,11 +200,11 @@ def process(self): y_mm.flush() if "neg_dy" in self.properties: neg_dy_mm.flush() - if "q" in self.properties: + if "total_charge" in self.properties: q_mm.flush() - if "pq" in self.properties: + if "partial_charges" in self.properties: pq_mm.flush() - if "dp" in self.properties: + if "dipole_moment" in self.properties: dp_mm.flush() os.rename(idx_mm.filename, fnames["idx"]) @@ -214,12 +214,12 @@ def process(self): os.rename(y_mm.filename, fnames["y"]) if "neg_dy" in self.properties: os.rename(neg_dy_mm.filename, fnames["neg_dy"]) - if "q" in self.properties: - os.rename(q_mm.filename, fnames["q"]) - if "pq" in self.properties: - 
os.rename(pq_mm.filename, fnames["pq"]) - if "dp" in self.properties: - os.rename(dp_mm.filename, fnames["dp"]) + if "total_charge" in self.properties: + os.rename(q_mm.filename, fnames["total_charge"]) + if "partial_charges" in self.properties: + os.rename(pq_mm.filename, fnames["partial_charges"]) + if "dipole_moment" in self.properties: + os.rename(dp_mm.filename, fnames["dipole_moment"]) def len(self): return len(self.idx_mm) - 1 @@ -233,9 +233,9 @@ def get(self, idx): - :obj:`pos`: Positions of the atoms. - :obj:`y`: Formation energy of the molecule. - :obj:`neg_dy`: Forces on the atoms. - - :obj:`q`: Total charge of the molecule. - - :obj:`pq`: Partial charges of the atoms. - - :obj:`dp`: Dipole moment of the molecule. + - :obj:`total_charge`: Total charge of the molecule. + - :obj:`partial_charges`: Partial charges of the atoms. + - :obj:`dipole_moment`: Dipole moment of the molecule. Args: idx (int): Index of the data object. @@ -252,10 +252,10 @@ def get(self, idx): props["y"] = pt.tensor(self.y_mm[idx]).view(1, 1) if "neg_dy" in self.properties: props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms]) - if "q" in self.properties: - props["q"] = pt.tensor(self.q_mm[idx], dtype=pt.long) - if "pq" in self.properties: - props["pq"] = pt.tensor(self.pq_mm[atoms]) - if "dp" in self.properties: - props["dp"] = pt.tensor(self.dp_mm[idx]) + if "total_charge" in self.properties: + props["total_charge"] = pt.tensor(self.q_mm[idx], dtype=pt.long) + if "partial_charges" in self.properties: + props["partial_charges"] = pt.tensor(self.pq_mm[atoms]) + if "dipole_moment" in self.properties: + props["dipole_moment"] = pt.tensor(self.dp_mm[idx]) return Data(z=z, pos=pos, **props) From d4950446613adf6995ad25814d32e4052b196fb7 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:30:19 +0100 Subject: [PATCH 11/82] qm9q to common-unique data structure --- torchmdnet/datasets/qm9q.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/torchmdnet/datasets/qm9q.py b/torchmdnet/datasets/qm9q.py index 63a262a30..d79bdfd37 100644 --- a/torchmdnet/datasets/qm9q.py +++ b/torchmdnet/datasets/qm9q.py @@ -150,7 +150,7 @@ def sample_iter(self, mol_ids=False): # Create a sample args = dict( - z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp + z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=q, partial_charges=pq, dipole_moment=dp ) if mol_ids: args["mol_id"] = mol_id From 014a06cdc5b0cb7ab271a70b48c79d41807c55dc Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:37:27 +0100 Subject: [PATCH 12/82] ET to extra_fields --- torchmdnet/models/torchmd_et.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 5ff168d54..955965627 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -102,6 +102,7 @@ def __init__( box_vecs=None, vector_cutoff=False, dtype=torch.float32, + extra_fields=None, ): super(TorchMD_ET, self).__init__() @@ -194,8 +195,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_fields_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) From b081a091c641dd9eb16a9bb79b9422e76f1191ff Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:38:48 +0100 Subject: [PATCH 13/82] add extra_fields documentation to the ET --- torchmdnet/models/torchmd_et.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 
955965627..b96874a8c 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -79,6 +79,9 @@ class TorchMD_ET(nn.Module): (default: :obj:`False`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, + for example extra_labels={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_labels={'total_charge': {embedding_dims: 64}. + default: :obj:`None`) """ From 2910694717ce6d55dda37330e6984de4e00ef92a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:41:20 +0100 Subject: [PATCH 14/82] transformer to extra_fields --- torchmdnet/models/torchmd_t.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index c11efc080..2bdd1c751 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -76,6 +76,9 @@ class TorchMD_T(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, + for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
+ default: :obj:`None`) """ @@ -98,6 +101,7 @@ def __init__( max_num_neighbors=32, dtype=torch.float, box_vecs=None, + extra_fields=None ): super(TorchMD_T, self).__init__() @@ -190,8 +194,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, - q: Optional[Tensor] = None, + extra_fields_args: Optional[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) From c0489afc71fb175f04397c776cf66463032d11f6 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:42:07 +0100 Subject: [PATCH 15/82] small fix in ET documentation --- torchmdnet/models/torchmd_et.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index b96874a8c..a7b56e3d1 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -80,7 +80,7 @@ class TorchMD_ET(nn.Module): check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, - for example extra_labels={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_labels={'total_charge': {embedding_dims: 64}. + for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
default: :obj:`None`) """ From cb2229a06887ec8f1126e28b68a00b28d3084190 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:43:43 +0100 Subject: [PATCH 16/82] graph-network to extra_fields --- torchmdnet/models/torchmd_gn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 31d68ae03..224b44f33 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -86,6 +86,9 @@ class TorchMD_GN(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, + for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
+ default: :obj:`None`) """ @@ -107,6 +110,7 @@ def __init__( aggr="add", dtype=torch.float32, box_vecs=None, + extra_fields=None ): super(TorchMD_GN, self).__init__() @@ -196,8 +200,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, - q: Optional[Tensor] = None, + extra_fields_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) From 2fe05bc21c9da5777a5e77d97b33aff03401d382 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 16:58:31 +0100 Subject: [PATCH 17/82] remove optional tensor for extra args, it's needed by default --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 67c75bf42..158d18cce 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -348,7 +348,7 @@ def forward( pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Optional[Tensor]]] = None, + extra_args: Optional[Dict[str, Tensor]] = None, extra_fields: Optional[Dict[str, Any]] = None, ) -> Tuple[Tensor, Tensor]: """ From 31506723269b54a137efff7f89b1dd4c212b6731 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 17:03:18 +0100 Subject: [PATCH 18/82] remove extra_args from model forward, extra_fields_args it's only needed --- torchmdnet/models/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 158d18cce..e3f512769 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -410,7 +410,6 @@ def forward( pos, batch, box=box, - extra_args=extra_args, extra_fields_args=extra_fields_args, ) # apply the output network From 6055357dd8fba56a95de6e3fdc67b413ed6480fb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 17:03:53 +0100 Subject: [PATCH 19/82] add 
self.extra_fields to architectures --- torchmdnet/models/torchmd_et.py | 1 + torchmdnet/models/torchmd_gn.py | 2 +- torchmdnet/models/torchmd_t.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index a7b56e3d1..c7af20ef8 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -137,6 +137,7 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.dtype = dtype + self.extra_fields = extra_fields act_class = act_class_mapping[activation] diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 224b44f33..29d1e9439 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -140,7 +140,7 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.aggr = aggr - + self.extra_fields = extra_fields act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 2bdd1c751..0ff7685b8 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -128,6 +128,7 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.max_z = max_z + self.extra_fields = extra_fields act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] From e3ebbbec939749d1459203ac844c949b307fe5f8 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 6 Mar 2024 17:20:55 +0100 Subject: [PATCH 20/82] remove all 'q' specific function, to move to more general extra_fields --- torchmdnet/models/tensornet.py | 35 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 2b830ce19..1fb517787 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -3,7 +3,7 @@ 
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import torch -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, Any from torch import Tensor, nn from torchmdnet.models.utils import ( CosineCutoff, @@ -16,7 +16,6 @@ torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True - def vector_to_skewtensor(vector): """Creates a skew-symmetric tensor from a vector.""" batch_size = vector.size(0) @@ -120,6 +119,9 @@ class TensorNet(nn.Module): (default: :obj:`True`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, + for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. + default: :obj:`None`) """ def __init__( @@ -139,6 +141,7 @@ def __init__( check_errors=True, dtype=torch.float32, box_vecs=None, + extra_fields=None, ): super(TensorNet, self).__init__() @@ -163,6 +166,7 @@ def __init__( self.activation = activation self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper + self.extra_fields = extra_fields act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( cutoff_lower, cutoff_upper, num_rbf, trainable_rbf @@ -176,6 +180,7 @@ def __init__( trainable_rbf, max_z, dtype, + extra_fields, ) self.layers = nn.ModuleList() @@ -226,8 +231,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_fields_args: Dict[str, Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) @@ -236,17 +240,10 @@ def forward( 
edge_vec is not None ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom - # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q - if q is None: - q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - # if not atom-wise, make atom-wise (pq is already atom-wise) - if z.shape != q.shape: - q = q[batch] zp = z if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0) # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, z.shape[0]) @@ -259,9 +256,9 @@ def forward( # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. 
# I avoid dividing by zero by setting the weight of self edges and self loops to 1 edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) - X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) + X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_fields_args) for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, q) + X = layer(X, edge_index, edge_weight, edge_attr, extra_fields_args) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -288,6 +285,7 @@ def __init__( trainable_rbf=False, max_z=128, dtype=torch.float32, + extra_fields=None, ): super(TensorEmbedding, self).__init__() @@ -327,7 +325,7 @@ def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() - def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor: + def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_fields_args: Optional[Dict[str, Tensor]]) -> Tensor: Z = self.emb(z) Zij = self.emb2( Z.index_select(0, edge_index.t().reshape(-1)).view( @@ -363,8 +361,9 @@ def forward( edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor, + extra_fields_args: Dict[str, Any] = None, ) -> Tensor: - Zij = self._get_atomic_number_message(z, edge_index) + Zij = self._get_atomic_number_message(z, edge_index, extra_fields_args) Iij, Aij, Sij = self._get_tensor_messages( Zij, edge_weight, edge_vec_norm, edge_attr ) @@ -457,7 +456,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - q: Tensor, + extra_fields_args: Optional[Dict[str, Tensor]] = None, ) -> Tensor: C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -484,7 +483,7 @@ def forward( if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) - I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B)) + I, A, S = 
decompose_tensor((1) * (A + B)) if self.equivariance_invariance_group == "SO(3)": B = torch.matmul(Y, msg) I, A, S = decompose_tensor(2 * B) @@ -494,5 +493,5 @@ def forward( A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) dX = I + A + S - X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2) + X = X + dX + (1) * torch.matrix_power(dX, 2) return X From ea3acd69a66509857fab77a276e0ab53490da18c Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:19:53 +0100 Subject: [PATCH 21/82] change variable name t additional_labels and allow to be only a dict --- torchmdnet/models/model.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index e3f512769..243264e1b 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -40,16 +40,15 @@ def create_model(args, prior_model=None, mean=None, std=None): args["vector_cutoff"] = False # Here we introduce the extra_fields_args, which is Dict[str, Any] - # These could be used from each model to initialize nn.embedding layers, nn.Parameter, etc. - if "extra_fields" not in args: - extra_fields = None - elif isinstance(args["extra_fields"], str): - extra_fields = {args["extra_fields"]: None} - elif isinstance(args["extra_fields"], list): - extra_fields = {label: None for label in args["extra_fields"]} + # This could be used from each model to initialize nn.embedding layers, nn.Parameter, etc. + if "additional_labels" not in args: + additional_labels = None + elif isinstance(args["additional_labels"], dict): + additional_labels = args["additional_labels"] else: - extra_fields = args["extra_fields"] - + additional_labels = None + warnings.warn("Extra fields should be a dictionary. 
Ignoring extra fields.") + shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -68,7 +67,7 @@ def create_model(args, prior_model=None, mean=None, std=None): else None ), dtype=dtype, - extra_fields=extra_fields, + additional_labels=additional_labels, ) # representation network From c34fc527e8e48f9c4b942c8de4e71bf9ec75e89d Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:20:44 +0100 Subject: [PATCH 22/82] remove architectural redundancy --- torchmdnet/models/model.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 243264e1b..1ab302abb 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -348,7 +348,6 @@ def forward( batch: Optional[Tensor] = None, box: Optional[Tensor] = None, extra_args: Optional[Dict[str, Tensor]] = None, - extra_fields: Optional[Dict[str, Any]] = None, ) -> Tuple[Tensor, Tensor]: """ Compute the output of the model. @@ -378,8 +377,7 @@ def forward( The vectors defining the periodic box. This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. If this is omitted, periodic boundary conditions are not applied. - extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model or to the representation model. - extra_fields (Dict[str, Tensor], optional): Extra fields to pass to the representation model. + extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model and to the representation model. Returns: Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise. 
@@ -389,27 +387,14 @@ def forward( if self.derivative: pos.requires_grad_(True) - - # recover the extra_fields_args from the extra_fields - if self.representation_model.extra_fields is None: - extra_fields_args = None - else: - assert extra_args is not None, "Extra fields are required but not provided." - extra_fields_args = {} - for field in extra_fields.keys(): - t = extra_args[field] - if t.shape != z.shape: - # expand molecular label to atom labels - t = t[batch] - extra_fields_args[field] = t - + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( z, pos, batch, box=box, - extra_fields_args=extra_fields_args, + extra_args=extra_args, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From 32f245616109ae962a83990d819972fea5e41037 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:21:30 +0100 Subject: [PATCH 23/82] move to additional_labels verion --- torchmdnet/models/torchmd_et.py | 16 +++++++++------- torchmdnet/models/torchmd_gn.py | 16 ++++++++++------ torchmdnet/models/torchmd_t.py | 15 ++++++++------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index c7af20ef8..f13eee96d 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -79,10 +79,10 @@ class TorchMD_ET(nn.Module): (default: :obj:`False`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, - for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
- default: :obj:`None`) - + additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: + example: + additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + (default: :obj:`None`) """ def __init__( @@ -105,7 +105,7 @@ def __init__( box_vecs=None, vector_cutoff=False, dtype=torch.float32, - extra_fields=None, + additional_labels=None, ): super(TorchMD_ET, self).__init__() @@ -137,7 +137,9 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.dtype = dtype - self.extra_fields = extra_fields + self.additional_labels = additional_labels + self.allowed_additional_labels = None + self.provided_additional_methods = None act_class = act_class_mapping[activation] @@ -199,7 +201,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - extra_fields_args: Optional[Dict[str, Tensor]] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 29d1e9439..53b1fec59 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -86,9 +86,10 @@ class TorchMD_GN(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, - for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
- default: :obj:`None`) + additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: + example: + additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + (default: :obj:`None`) """ @@ -110,7 +111,7 @@ def __init__( aggr="add", dtype=torch.float32, box_vecs=None, - extra_fields=None + additional_labels=None ): super(TorchMD_GN, self).__init__() @@ -140,7 +141,10 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.aggr = aggr - self.extra_fields = extra_fields + self.additional_labels = additional_labels + self.allowed_additional_labels = None + self.provided_additional_methods = None + act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) @@ -200,7 +204,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - extra_fields_args: Optional[Dict[str, Tensor]] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 0ff7685b8..6663469ab 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -76,10 +76,10 @@ class TorchMD_T(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, - for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. 
- default: :obj:`None`) - + additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: + example: + additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + (default: :obj:`None`) """ def __init__( @@ -101,7 +101,7 @@ def __init__( max_num_neighbors=32, dtype=torch.float, box_vecs=None, - extra_fields=None + additional_labels=None ): super(TorchMD_T, self).__init__() @@ -128,8 +128,9 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.max_z = max_z - self.extra_fields = extra_fields - + self.additional_labels = additional_labels + self.allowed_additional_labels = None + self.provided_additional_methods = None act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] From ef4cfbec47ef91284ba3f2a845c2661d0ba87adb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:21:44 +0100 Subject: [PATCH 24/82] use extra_args --- torchmdnet/models/wrappers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/wrappers.py b/torchmdnet/models/wrappers.py index 444805e06..ce987ce01 100644 --- a/torchmdnet/models/wrappers.py +++ b/torchmdnet/models/wrappers.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) from abc import abstractmethod, ABCMeta -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict from torch import nn, Tensor @@ -43,10 +43,9 @@ def forward( z: Tensor, pos: Tensor, batch: Tensor, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: - x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s) + x, v, z, pos, batch = self.model(z, pos, batch=batch, extra_args=extra_args) n_samples = len(batch.unique()) From ebb08c227aa5277e3848d8f9a155946642438492 Mon 
Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:42:00 +0100 Subject: [PATCH 25/82] tnsnet v2 with tensornetQ class as additional method --- torchmdnet/models/tensornet.py | 81 ++++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 1fb517787..7d7a61d19 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -119,9 +119,9 @@ class TensorNet(nn.Module): (default: :obj:`True`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - extra_fields (Dict[str, Any], optional): Extra fields to be passed to the model, the value could be a dict with some extra args to be passed to the model, - for example extra_fields={'total_charge': {initial_value: 0.0, learnable: True}} or maybe extra_fields={'total_charge': {embedding_dims: 64}. - default: :obj:`None`) + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + (default: :obj:`None`) """ def __init__( @@ -141,7 +141,7 @@ def __init__( check_errors=True, dtype=torch.float32, box_vecs=None, - extra_fields=None, + additional_labels=None, ): super(TensorNet, self).__init__() @@ -166,7 +166,17 @@ def __init__( self.activation = activation self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper - self.extra_fields = extra_fields + self.additional_labels = additional_labels + # initialize additional methods as None if not provided + self.additional_methods = None + + if additional_labels is not None: + self.additional_methods = {} + for method_name, method_args in additional_labels.items(): + # the key of the additional_methods is the label of the method (total_charge, partial_charges, etc.) 
+ # this will be useful for static shapes processing if needed + self.additional_methods[method_args['label']] = self.initialize_additional_method(method_name, method_args) + act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( cutoff_lower, cutoff_upper, num_rbf, trainable_rbf @@ -180,9 +190,7 @@ def __init__( trainable_rbf, max_z, dtype, - extra_fields, ) - self.layers = nn.ModuleList() if num_layers != 0: for _ in range(num_layers): @@ -195,6 +203,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype, + self.additional_methods, ) ) self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype) @@ -218,6 +227,12 @@ def __init__( self.reset_parameters() + def initialize_additional_method(self, method, args): + if method == 'tensornet_q': + return TensornetQ(args['init_value'], args['label'], args['learnable']) + else: + raise NotImplementedError(f"Method {method} not implemented") + def reset_parameters(self): self.tensor_embedding.reset_parameters() for layer in self.layers: @@ -231,7 +246,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - extra_fields_args: Dict[str, Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) @@ -241,9 +256,15 @@ def forward( ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) + if self.additional_labels is not None: + for label in self.additional_methods.keys(): + 
assert label in extra_args, f"Extra field {label} not found in extra_args" + extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=extra_args[label].device, dtype=extra_args[label].dtype)), dim=0) + # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, z.shape[0]) @@ -256,9 +277,9 @@ def forward( # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. # I avoid dividing by zero by setting the weight of self edges and self loops to 1 edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) - X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_fields_args) + X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, extra_fields_args) + X = layer(X, edge_index, edge_weight, edge_attr, extra_args) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -285,7 +306,6 @@ def __init__( trainable_rbf=False, max_z=128, dtype=torch.float32, - extra_fields=None, ): super(TensorEmbedding, self).__init__() @@ -361,9 +381,8 @@ def forward( edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor, - extra_fields_args: Dict[str, Any] = None, - ) -> Tensor: - Zij = self._get_atomic_number_message(z, edge_index, extra_fields_args) + ) -> Tensor: + Zij = self._get_atomic_number_message(z, edge_index) Iij, Aij, Sij = self._get_tensor_messages( Zij, edge_weight, edge_vec_norm, edge_attr ) @@ -419,6 +438,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype=torch.float32, + addtional_methods = None, ): super(Interaction, self).__init__() @@ -442,6 +462,8 @@ def __init__( ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group + 
self.addtional_methods = addtional_methods + self.reset_parameters() def reset_parameters(self): @@ -456,7 +478,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - extra_fields_args: Optional[Dict[str, Tensor]] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tensor: C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -480,18 +502,39 @@ def forward( edge_index, edge_attr[..., 2, None, None], S, X.shape[0] ) msg = Im + Am + Sm + + prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg) + if self.addtional_methods is not None: + for label, method in self.addtional_methods.items(): + tmp_ = method.forward(extra_args[label][..., None, None, None]) + prefactor *= tmp_ + if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) - I, A, S = decompose_tensor((1) * (A + B)) + I, A, S = decompose_tensor(prefactor * (A + B)) if self.equivariance_invariance_group == "SO(3)": B = torch.matmul(Y, msg) - I, A, S = decompose_tensor(2 * B) + I, A, S = decompose_tensor(prefactor * 2 * B) normp1 = (tensor_norm(I + A + S) + 1)[..., None, None] I, A, S = I / normp1, A / normp1, S / normp1 I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) # shape: (natoms, hidden_channels, 3, 3) dX = I + A + S - X = X + dX + (1) * torch.matrix_power(dX, 2) + for label in self.additional_labels.keys(): + assert label in extra_args, f"Extra field {label} not found in extra_args" + X = X + dX + (prefactor) * torch.matrix_power(dX, 2) return X + + +class TensornetQ(nn.Module): + def __init__(self, init_value, additional_label='total_charge', learnable=False): + super().__init__() + self.prmtr = nn.Parameter(init_value, requires_grad=learnable, 
dtype=torch.float32) + self.allowed_labels = ['total_charge', 'partial_charges'] + assert additional_label in self.allowed_labels, f"Label {additional_label} not allowed for this method" + self.additional_label = additional_label + + def forward(self, X): + return self.prmtr * X \ No newline at end of file From 91c7cdeb750b0fc5d91506c8cd494d72682733ec Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:46:56 +0100 Subject: [PATCH 26/82] update warning message --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 1ab302abb..1af47817f 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -47,7 +47,7 @@ def create_model(args, prior_model=None, mean=None, std=None): additional_labels = args["additional_labels"] else: additional_labels = None - warnings.warn("Extra fields should be a dictionary. Ignoring extra fields.") + warnings.warn("Additional labels should be a dictionary. 
Ignoring additional labels.") shared_args = dict( hidden_channels=args["embedding_dimension"], From fbfafaf855e0758cfb24bfabe8f89e5c22fb2899 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 15:53:28 +0100 Subject: [PATCH 27/82] force labels to be atom_wise --- torchmdnet/models/model.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 1af47817f..13fc813c5 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -388,13 +388,22 @@ def forward( if self.derivative: pos.requires_grad_(True) + if self.representation_model.additional_methods is not None: + extra_args_nnp = {} + # force the label to be atom wise + for label, t in extra_args.items(): + if label in self.representation_model.additional_methods.keys(): + if t.shape != z.shape: + t = t[batch] + extra_args_nnp[label] = t + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( z, pos, batch, box=box, - extra_args=extra_args, + extra_args=extra_args_nnp, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From 088dd5f9a294f52f4502e4860f36dc628aec227b Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:11:34 +0100 Subject: [PATCH 28/82] remove unused arg --- torchmdnet/models/tensornet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 7d7a61d19..3836c7de0 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -167,7 +167,7 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.additional_labels = additional_labels - # initialize additional methods as None if not provided + # initialize additional methods as None if not provided, also used by module.py self.additional_methods = None if additional_labels is not None: @@ -345,7 +345,7 @@ 
def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() - def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_fields_args: Optional[Dict[str, Tensor]]) -> Tensor: + def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor: Z = self.emb(z) Zij = self.emb2( Z.index_select(0, edge_index.t().reshape(-1)).view( From 9479ae73ec5888520d4dc27429ca3276879c7537 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:12:09 +0100 Subject: [PATCH 29/82] fix extra_args_nnp generation --- torchmdnet/models/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 13fc813c5..6b559c240 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -388,7 +388,9 @@ def forward( if self.derivative: pos.requires_grad_(True) - if self.representation_model.additional_methods is not None: + if self.representation_model.additional_methods is None: + extra_args_nnp = None + else: extra_args_nnp = {} # force the label to be atom wise for label, t in extra_args.items(): From cd8ab2db4bdbc9291627b699b51bd976f57c5f5a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:14:10 +0100 Subject: [PATCH 30/82] remove old code residue --- torchmdnet/models/tensornet.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 3836c7de0..ee751e6e0 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -522,8 +522,6 @@ def forward( A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) # shape: (natoms, hidden_channels, 3, 3) dX = I + A + S - for label in self.additional_labels.keys(): - assert label in extra_args, f"Extra field {label} not found in extra_args" X = X + dX + (prefactor) * torch.matrix_power(dX, 2) return X 
From f48e4083a541682752ce9b0e8ec7af829c18c560 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:14:54 +0100 Subject: [PATCH 31/82] use correct name in forward for extra_args --- torchmdnet/models/torchmd_t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 6663469ab..ce99bc805 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -196,7 +196,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - extra_fields_args: Optional[Dict[str, Tensor]] = None + extra_args: Optional[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) From ef64fc43cff9a2e798009a0d83679eee68915885 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:15:08 +0100 Subject: [PATCH 32/82] fix arg name --- torchmdnet/models/torchmd_t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index ce99bc805..88de2e441 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -130,7 +130,7 @@ def __init__( self.max_z = max_z self.additional_labels = additional_labels self.allowed_additional_labels = None - self.provided_additional_methods = None + self.additional_methods = None act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] From de17cc0e97e6dd5640fa78d5e4f5de8467cf19a0 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:15:18 +0100 Subject: [PATCH 33/82] fix arg name --- torchmdnet/models/torchmd_gn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 53b1fec59..9d581ab25 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -143,7 +143,7 @@ def __init__( self.aggr = aggr 
self.additional_labels = additional_labels self.allowed_additional_labels = None - self.provided_additional_methods = None + self.additional_methods = None act_class = act_class_mapping[activation] From f947af474187cfbc8ca7b0ec0bede1c29277dd53 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 12 Mar 2024 16:15:41 +0100 Subject: [PATCH 34/82] rename to additional_methods --- torchmdnet/models/torchmd_et.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index f13eee96d..a87986b1f 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -139,7 +139,7 @@ def __init__( self.dtype = dtype self.additional_labels = additional_labels self.allowed_additional_labels = None - self.provided_additional_methods = None + self.additional_methods = None act_class = act_class_mapping[activation] From 4e00f8ccbcd7741cd017f8ed1f431fc4a4adf9da Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 11:28:16 +0100 Subject: [PATCH 35/82] fix documentation --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index ee751e6e0..5b79b7dbe 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -120,7 +120,7 @@ class TensorNet(nn.Module): check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. 
- additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + additional_labels = {method_name1: {label_name1: values}, method_name2:{ label_name2: values}, ...} (default: :obj:`None`) """ From 2206a4c967216f7307e4a3ccbd0d9e6b9039277a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 12:30:19 +0100 Subject: [PATCH 36/82] fix ace dataloader with new extra_args name --- torchmdnet/datasets/ace.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 4f1f8335c..9554f1c1a 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -47,7 +47,7 @@ class Ace(MemmappedDataset): - `positions`: Atomic positions. Units: Angstrom. - `forces`: Forces on the atoms. Units: eV/Å. - `partial_charges`: Atomic partial charges. Units: electron charges. - - `dipole_moment` (version 1.0) or `dipole_moments` (version 2.0): Dipole moment (a vector of three components). Units: e*Å. + - `dipole_moment` (version 1.0) or `dipole_moment` (version 2.0): Dipole moment (a vector of three components). Units: e*Å. - `formation_energy` (version 1.0) or `formation_energies` (version 2.0): Formation energy. Units: eV. Each dataset should also have an `units` attribute specifying its units (i.e., `Å`, `eV`, `e*Å`). 
@@ -146,7 +146,7 @@ def __init__( transform, pre_transform, pre_filter, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ) @property @@ -188,14 +188,14 @@ def _load_confs_1_0(mol, n_atoms): assert neg_dy.shape == pos.shape assert conf["partial_charges"].attrs["units"] == "e" - pq = pt.tensor(conf["partial_charges"][:], dtype=pt.float32) - assert pq.shape == (n_atoms,) + partial_charges = pt.tensor(conf["partial_charges"][:], dtype=pt.float32) + assert partial_charges.shape == (n_atoms,) assert conf["dipole_moment"].attrs["units"] == "e*Å" - dp = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32) - assert dp.shape == (3,) + dipole_moment = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32) + assert dipole_moment.shape == (3,) - yield pos, y, neg_dy, pq, dp + yield pos, y, neg_dy, partial_charges, dipole_moment @staticmethod def _load_confs_2_0(mol, n_atoms): @@ -213,19 +213,19 @@ def _load_confs_2_0(mol, n_atoms): assert all_neg_dy.shape == all_pos.shape assert mol["partial_charges"].attrs["units"] == "e" - all_pq = pt.tensor(mol["partial_charges"][...], dtype=pt.float32) - assert all_pq.shape == (n_confs, n_atoms) + all_partial_charges = pt.tensor(mol["partial_charges"][...], dtype=pt.float32) + assert all_partial_charges.shape == (n_confs, n_atoms) assert mol["dipole_moments"].attrs["units"] == "e*Å" - all_dp = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32) - assert all_dp.shape == (n_confs, 3) + all_dipole_moment = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32) + assert all_dipole_moment.shape == (n_confs, 3) - for pos, y, neg_dy, pq, dp in zip(all_pos, all_y, all_neg_dy, all_pq, all_dp): + for pos, y, neg_dy, partial_charges, dipole_moment in zip(all_pos, all_y, all_neg_dy, all_partial_charges, all_dipole_moment): # Skip failed calculations if y.isnan(): continue - yield pos, y, neg_dy, pq, dp + yield pos, y, neg_dy, partial_charges, dipole_moment def 
sample_iter(self, mol_ids=False): assert self.subsample_molecules > 0 @@ -261,9 +261,9 @@ def sample_iter(self, mol_ids=False): z = pt.tensor(mol["atomic_numbers"], dtype=pt.long) fq = pt.tensor(mol["formal_charges"], dtype=pt.long) - q = fq.sum() + total_charge = fq.sum() - for i_conf, (pos, y, neg_dy, pq, dp) in enumerate( + for i_conf, (pos, y, neg_dy, partial_charges, dipole_moment) in enumerate( load_confs(mol, n_atoms=len(z)) ): # Skip samples with large forces @@ -273,7 +273,7 @@ def sample_iter(self, mol_ids=False): # Create a sample args = dict( - z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=q, partial_charges=pq, dipole_moment=dp + z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=total_charge, partial_charges=partial_charges, dipole_moment=dipole_moment ) if mol_ids: args["mol_id"] = mol_id From 582d1effa9a6389d5892b183d4536dc16ab25aeb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 12:31:10 +0100 Subject: [PATCH 37/82] prefactor to device and dtype --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 5b79b7dbe..6c83e905e 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -503,7 +503,7 @@ def forward( ) msg = Im + Am + Sm - prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg) + prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg).to(msg.device).to(msg.dtype) if self.addtional_methods is not None: for label, method in self.addtional_methods.items(): tmp_ = method.forward(extra_args[label][..., None, None, None]) From 7547f34db9e0077308dea59fda1354c21f9017eb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 12:31:46 +0100 Subject: [PATCH 38/82] initialize nn.Parameter with torch tensor --- torchmdnet/models/tensornet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 6c83e905e..9801ba820 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -529,10 +529,9 @@ def forward( class TensornetQ(nn.Module): def __init__(self, init_value, additional_label='total_charge', learnable=False): super().__init__() - self.prmtr = nn.Parameter(init_value, requires_grad=learnable, dtype=torch.float32) + self.prmtr = nn.Parameter(torch.tensor(init_value), requires_grad=learnable) self.allowed_labels = ['total_charge', 'partial_charges'] assert additional_label in self.allowed_labels, f"Label {additional_label} not allowed for this method" - self.additional_label = additional_label - + def forward(self, X): return self.prmtr * X \ No newline at end of file From 23886704ca97266d2e1c9e9161a43036abb4793f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 12:42:14 +0100 Subject: [PATCH 39/82] fix argparse name for additional labels --- torchmdnet/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 224f8c14f..4b4677d4a 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -60,7 +60,7 @@ def get_argparse(): parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm') parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. 
The dataset must be compatible with Atomref for this to be used.') # dataset specific - parser.add_argument('--extra-fields', default=None, help='Extra fields of the dataset to pass to the model, it could be a list of fields or a dictionary with the field name and addtionals arguments') + parser.add_argument('--additional-labels', default=None, help='Additional labels to be passed to the model, must be a dict like {"method_name":{args}}') parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') parser.add_argument('--dataset-arg', default=None, help='Additional dataset arguments. Needs to be a dictionary.') From ea56bd33a71e2781ae53674145598b88ee685c22 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Wed, 13 Mar 2024 12:49:22 +0100 Subject: [PATCH 40/82] fix prefactor operation --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 9801ba820..b2d60770f 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -507,7 +507,7 @@ def forward( if self.addtional_methods is not None: for label, method in self.addtional_methods.items(): tmp_ = method.forward(extra_args[label][..., None, None, None]) - prefactor *= tmp_ + prefactor += tmp_ if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) From 89b0d31f22c26c187d2d3f462cb65ee56c0a2bcb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 12:23:36 +0100 Subject: [PATCH 41/82] more efficient --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index b2d60770f..7baa72aa4 100644 --- a/torchmdnet/models/tensornet.py +++ 
b/torchmdnet/models/tensornet.py @@ -503,7 +503,7 @@ def forward( ) msg = Im + Am + Sm - prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg).to(msg.device).to(msg.dtype) + prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) if self.addtional_methods is not None: for label, method in self.addtional_methods.items(): tmp_ = method.forward(extra_args[label][..., None, None, None]) From 10ee140c05ef7b2222acca31a63b852c156a970b Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 12:24:33 +0100 Subject: [PATCH 42/82] specify also the name of the methods in addtional_methods dict --- torchmdnet/models/tensornet.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 7baa72aa4..9d0be3c47 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -175,7 +175,7 @@ def __init__( for method_name, method_args in additional_labels.items(): # the key of the additional_methods is the label of the method (total_charge, partial_charges, etc.) 
# this will be useful for static shapes processing if needed - self.additional_methods[method_args['label']] = self.initialize_additional_method(method_name, method_args) + self.additional_methods[method_args['label']] = {'name': method_name, 'method':self.initialize_additional_method(method_name, method_args)} act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( @@ -505,10 +505,13 @@ def forward( prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) if self.addtional_methods is not None: - for label, method in self.addtional_methods.items(): - tmp_ = method.forward(extra_args[label][..., None, None, None]) - prefactor += tmp_ - + for label, method_dict in self.addtional_methods.items(): + # appending to this list all the methods will be working in this way + if method_dict['name'] in ['tensornet_q']: + tmp_ = method_dict['method'](extra_args[label][..., None, None, None]) + #TODO: how do we want to handle prefactor if multiple methods are used here? 
+ prefactor += tmp_ + if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) From cef4e64988fd05c6c744c16f0142df9c81fa461d Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 12:42:21 +0100 Subject: [PATCH 43/82] remove unused trainable_rbf from tensornet embedding --- torchmdnet/models/tensornet.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 9d0be3c47..9a3f5036c 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -187,7 +187,6 @@ def __init__( act_class, cutoff_lower, cutoff_upper, - trainable_rbf, max_z, dtype, ) @@ -303,7 +302,6 @@ def __init__( activation, cutoff_lower, cutoff_upper, - trainable_rbf=False, max_z=128, dtype=torch.float32, ): From 895875031a70d9d1f0a97d8deb80eb4e680c603e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 12:49:01 +0100 Subject: [PATCH 44/82] update additional_labels documentation in models --- torchmdnet/models/tensornet.py | 5 ++++- torchmdnet/models/torchmd_et.py | 8 +++++--- torchmdnet/models/torchmd_gn.py | 9 +++++---- torchmdnet/models/torchmd_t.py | 9 +++++---- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 9a3f5036c..e7a82586e 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -120,7 +120,10 @@ class TensorNet(nn.Module): check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. 
- additional_labels = {method_name1: {label_name1: values}, method_name2:{ label_name2: values}, ...} + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} (default: :obj:`None`) """ diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index a87986b1f..4fa970172 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -79,9 +79,11 @@ class TorchMD_ET(nn.Module): (default: :obj:`False`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: - example: - additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} (default: :obj:`None`) """ diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 9d581ab25..10f805316 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -86,10 +86,11 @@ class TorchMD_GN(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. 
(default: :obj:`True`) - additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: - example: - additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} - (default: :obj:`None`) + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} """ diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 88de2e441..2452d8e4a 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -76,10 +76,11 @@ class TorchMD_T(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - additional_labels (Dict[str, Any], optional): Additional labels to be passed to the forward method of the model: - example: - additional_labels = {method_name: {label_name1: values, label_name2: values, ...}, ...} - (default: :obj:`None`) + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. 
+ example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} """ def __init__( From 4ef5aa9e38fdd69ceb885d3743fbdf3fd20e10f5 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 13:12:06 +0100 Subject: [PATCH 45/82] remove extra_args expansion --- torchmdnet/models/model.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 6b559c240..c2c9816b2 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -387,25 +387,14 @@ def forward( if self.derivative: pos.requires_grad_(True) - - if self.representation_model.additional_methods is None: - extra_args_nnp = None - else: - extra_args_nnp = {} - # force the label to be atom wise - for label, t in extra_args.items(): - if label in self.representation_model.additional_methods.keys(): - if t.shape != z.shape: - t = t[batch] - extra_args_nnp[label] = t - + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( z, pos, batch, box=box, - extra_args=extra_args_nnp, + extra_args=extra_args if self.representation_model.additional_methods else None, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From 815b8c3c900b8146c2b6187799e94675a02da01c Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 14 Mar 2024 13:12:32 +0100 Subject: [PATCH 46/82] add extra_args expansion inside the model --- torchmdnet/models/tensornet.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index e7a82586e..c37b1103d 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -258,7 +258,13 @@ def forward( ), "Distance module did not return 
directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z - + if extra_args is not None: + # we are assuming that extra args will be used, see model.py forward method how extra_args is passed + for label, t in extra_args.items(): + # molecule wise --> atom wise + if t.shape != z.shape: + extra_args[label] = t[batch] + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) From 3cb232e721eefae000c71d1ba27d9868eb404b29 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 09:49:42 +0100 Subject: [PATCH 47/82] add test for additional labels --- tests/test_model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index b792595b8..ec179141c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,22 +18,24 @@ @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("use_batch", [True, False]) -@mark.parametrize("explicit_q_s", [True, False]) +@mark.parametrize("use_extra_args", [True, False]) @mark.parametrize("precision", [32, 64]) -def test_forward(model_name, use_batch, explicit_q_s, precision): +@mark.parametrize("additional_labels", [None, {"tensornet_q": {"label": "total_charge", 'learnable': False, 'init_value': 0.1}}]) +def test_forward(model_name, use_batch, use_extra_args, precision, additional_labels): z, pos, batch = create_example_batch() pos = pos.to(dtype=dtype_mapping[precision]) - model = create_model(load_example_args(model_name, prior_model=None, precision=precision)) + model = create_model(load_example_args(model_name, prior_model=None, precision=precision, additional_labels=additional_labels)) batch = batch if use_batch else None - if explicit_q_s: - model(z, 
pos, batch=batch, q=None, s=None) + if use_extra_args: + model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)}) else: + print("No Extra Args provided") model(z, pos, batch=batch) @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("output_model", output_modules.__all__) -@mark.parametrize("precision", [32,64]) +@mark.parametrize("precision", [32, 64]) def test_forward_output_modules(model_name, output_model, precision): z, pos, batch = create_example_batch() args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision) From 2c8e47b17e9bfd0642a9d17b3bca5306d15ee0cb Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 09:50:14 +0100 Subject: [PATCH 48/82] add additional_labels to load_example_args for testing --- tests/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index ef8bcddb9..b533a2e14 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,6 +26,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs args["box_vecs"] = None if "remove_ref_energy" not in args: args["remove_ref_energy"] = False + if "additional_labels" not in args: + args["additional_labels"] = None for key, val in kwargs.items(): assert key in args, f"Broken test! Unknown key '{key}'." 
args[key] = val From 6f7fac9fcfd42c4f40024bd31f5d996554f8af10 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 09:51:01 +0100 Subject: [PATCH 49/82] double check with and --- torchmdnet/models/tensornet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index c37b1103d..992fb97ce 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -478,6 +478,7 @@ def reset_parameters(self): linear.reset_parameters() for linear in self.linears_tensor: linear.reset_parameters() + # TODO: should we reset the parameters of the additional methods here? def forward( self, @@ -511,7 +512,7 @@ def forward( msg = Im + Am + Sm prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) - if self.addtional_methods is not None: + if self.addtional_methods is not None and extra_args is not None: for label, method_dict in self.addtional_methods.items(): # appending to this list all the methods will be working in this way if method_dict['name'] in ['tensornet_q']: From e1d8918d9f53589648b1857e5c03aec12762bb1e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 11:11:56 +0100 Subject: [PATCH 50/82] fix condition when extra args are passed to the forward --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c2c9816b2..7f8a9f558 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -394,7 +394,7 @@ def forward( pos, batch, box=box, - extra_args=extra_args if self.representation_model.additional_methods else None, + extra_args = extra_args if self.representation_model.additional_methods is not None else None, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From 40234f29a127b35a11225ede41c6865f64615b16 Mon Sep 17 00:00:00 2001 From: 
Antonio Mirarchi Date: Fri, 15 Mar 2024 11:12:10 +0100 Subject: [PATCH 51/82] update to additional_labels --- torchmdnet/optimize.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 0c7f56513..633582582 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch as pt from NNPOps.CFConv import CFConv from NNPOps.CFConvNeighbors import CFConvNeighbors @@ -56,8 +56,7 @@ def forward( pos: pt.Tensor, batch: pt.Tensor, box: Optional[pt.Tensor] = None, - q: Optional[pt.Tensor] = None, - s: Optional[pt.Tensor] = None, + extra_args: Optional[Dict[str, pt.Tensor]] = None, ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]: assert pt.all(batch == 0) From 3b0dbf8d54ba46b92ad57cb6b64eb9902eff0855 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 11:37:47 +0100 Subject: [PATCH 52/82] update test_examples --- tests/test_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 8cd5155e6..30c90ab83 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,6 +5,7 @@ from pytest import mark import yaml import glob +import torch from os.path import dirname, join from torchmdnet.models.model import create_model from torchmdnet import priors @@ -27,4 +28,4 @@ def test_example_yamls(fname): z, pos, batch = create_example_batch() model(z, pos, batch) - model(z, pos, batch, q=None, s=None) + model(z, pos, batch, extra_args={"total_charge": torch.zeros_like(z)}) From 433017e76d965a495aaede6e710dbaee0dcf5d53 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 11:55:23 +0100 Subject: [PATCH 53/82] update test wrappers --- 
tests/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 8f8aff2aa..05b40e3f0 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -19,7 +19,7 @@ def test_atom_filter(remove_threshold, model_name): model = AtomFilter(model, remove_threshold) z, pos, batch = create_example_batch(n_atoms=100) - x, v, z, pos, batch = model(z, pos, batch, None, None) + x, v, z, pos, batch = model(z, pos, batch, None) assert (z > remove_threshold).all(), ( f"Lowest updated atomic number is {z.min()} but " From d3af9dd5242c8a4a937e1c0dc28e8fc5892a749e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 12:08:22 +0100 Subject: [PATCH 54/82] additional_methods to torchmd_GN_optimized --- torchmdnet/optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 633582582..2ff26f15b 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -33,7 +33,7 @@ def __init__(self, model): super().__init__() self.model = model - + self.additional_methods = None self.neighbors = CFConvNeighbors(self.model.cutoff_upper) offset = self.model.distance_expansion.offset From 4f3903a04cfcfa4b9dcbab5fdc3f0e932feba0c2 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 15 Mar 2024 12:26:11 +0100 Subject: [PATCH 55/82] small change, remove print from test_model --- tests/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index ec179141c..615cc561a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -29,7 +29,6 @@ def test_forward(model_name, use_batch, use_extra_args, precision, additional_la if use_extra_args: model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)}) else: - print("No Extra Args provided") model(z, pos, batch=batch) From 2b18f905f51c118ac1bcf30119bc6d02e4fd0669 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi 
Date: Fri, 15 Mar 2024 17:18:48 +0100 Subject: [PATCH 56/82] to shared extra args nomenclature --- torchmdnet/datasets/qm9q.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/datasets/qm9q.py b/torchmdnet/datasets/qm9q.py index d79bdfd37..abaf677c8 100644 --- a/torchmdnet/datasets/qm9q.py +++ b/torchmdnet/datasets/qm9q.py @@ -46,7 +46,7 @@ def __init__( transform, pre_transform, pre_filter, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ) @property From 604034a05d0bce29acf9eafee7628ba82a523409 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 10:08:13 +0100 Subject: [PATCH 57/82] fix dipole_moments in the documentation, ace v2.0 --- torchmdnet/datasets/ace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 9554f1c1a..d2b57638d 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -47,7 +47,7 @@ class Ace(MemmappedDataset): - `positions`: Atomic positions. Units: Angstrom. - `forces`: Forces on the atoms. Units: eV/Å. - `partial_charges`: Atomic partial charges. Units: electron charges. - - `dipole_moment` (version 1.0) or `dipole_moment` (version 2.0): Dipole moment (a vector of three components). Units: e*Å. + - `dipole_moment` (version 1.0) or `dipole_moments` (version 2.0): Dipole moment (a vector of three components). Units: e*Å. - `formation_energy` (version 1.0) or `formation_energies` (version 2.0): Formation energy. Units: eV. Each dataset should also have an `units` attribute specifying its units (i.e., `Å`, `eV`, `e*Å`). 
From 052261e8e1622d21fb3cad6fdcdfa9d9f09548a3 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:01:08 +0100 Subject: [PATCH 58/82] remove old comment --- torchmdnet/datasets/hdf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index cd5d19f5b..18bcadd61 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -57,7 +57,6 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs): self.fields.append( ("partial_charges", "partial_charges", torch.float32) ) - # total charge and spin, will be load as 'q' and 's' respectively to keep the same naming convention if "total_charge" in group: self.fields.append( ("total_charge", "total_charge", torch.float32) From c0902c5ab2559070ffc4d54ba0ba149b03bc42d2 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:13:20 +0100 Subject: [PATCH 59/82] update to get the additional labels from argparse as discussed in the PR --- torchmdnet/models/model.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 7f8a9f558..1d231b58d 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -39,13 +39,9 @@ def create_model(args, prior_model=None, mean=None, std=None): if "vector_cutoff" not in args: args["vector_cutoff"] = False - # Here we introduce the extra_fields_args, which is Dict[str, Any] - # This could be used from each model to initialize nn.embedding layers, nn.Parameter, etc. 
- if "additional_labels" not in args: - additional_labels = None - elif isinstance(args["additional_labels"], dict): - additional_labels = args["additional_labels"] - else: + # additional labels for the representation model + additional_labels = args.get("additional_labels") + if additional_labels is not None and not isinstance(additional_labels, dict): additional_labels = None warnings.warn("Additional labels should be a dictionary. Ignoring additional labels.") From 917cb7ac54d1bf211c78e817d2688f6c41ac88e4 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:22:37 +0100 Subject: [PATCH 60/82] fix documentation, include extra_args in the model's input --- torchmdnet/models/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 1d231b58d..e0ad906e6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -270,10 +270,11 @@ class TorchMD_Net(nn.Module): Parameters ---------- representation_model : nn.Module - A model that takes as input the atomic numbers, positions, batch indices. - It must return a tuple of the form (x, v, z, pos, batch), where x - are the atom features, v are the vector features (if any), z are the atomic numbers, - pos are the positions, and batch are the batch indices. See TorchMD_ET for more details. + A model that takes as input the atomic numbers, positions, batch indices and extra_args. The extra_args + are optional and will be used only if the representation model has additional_methods. It must return a + tuple of the form (x, v, z, pos, batch), where x are the atom features, v are the vector features (if any), + z are the atomic numbers, pos are the positions, and batch are the batch indices. See TorchMD_ET for more + details. output_model : nn.Module A model that takes as input the atom features, vector features (if any), atomic numbers, positions, and batch indices. See OutputModel for more details. 
From 452362b2f44751293949c1a81d4b1f7f48a85e98 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:22:59 +0100 Subject: [PATCH 61/82] to black --- torchmdnet/models/model.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index e0ad906e6..c47f135ec 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -38,13 +38,15 @@ def create_model(args, prior_model=None, mean=None, std=None): args["static_shapes"] = False if "vector_cutoff" not in args: args["vector_cutoff"] = False - + # additional labels for the representation model additional_labels = args.get("additional_labels") if additional_labels is not None and not isinstance(additional_labels, dict): additional_labels = None - warnings.warn("Additional labels should be a dictionary. Ignoring additional labels.") - + warnings.warn( + "Additional labels should be a dictionary. Ignoring additional labels." + ) + shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -384,14 +386,18 @@ def forward( if self.derivative: pos.requires_grad_(True) - + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( z, pos, batch, box=box, - extra_args = extra_args if self.representation_model.additional_methods is not None else None, + extra_args=( + extra_args + if self.representation_model.additional_methods is not None + else None + ), ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From cfd0f538c2e3c3c496b856792b0398b8dc4c2673 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:31:37 +0100 Subject: [PATCH 62/82] remove Any from typing import --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c47f135ec..42d680b31 100644 --- a/torchmdnet/models/model.py 
+++ b/torchmdnet/models/model.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import re -from typing import Optional, List, Tuple, Dict, Any +from typing import Optional, List, Tuple, Dict import torch from torch.autograd import grad from torch import nn, Tensor From 5e02fc7984041e3a2e5ec4224dea1dd0da8715c6 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:35:24 +0100 Subject: [PATCH 63/82] initialize_additional_method as free standing function --- torchmdnet/models/tensornet.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 992fb97ce..0b7a70f7c 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -62,6 +62,11 @@ def tensor_norm(tensor): """Computes Frobenius norm.""" return (tensor**2).sum((-2, -1)) +def initialize_additional_method(method, args): + if method == 'tensornet_q': + return TensornetQ(args['init_value'], args['label'], args['learnable']) + else: + raise NotImplementedError(f"Method {method} not implemented") class TensorNet(nn.Module): r"""TensorNet's architecture. From @@ -178,7 +183,7 @@ def __init__( for method_name, method_args in additional_labels.items(): # the key of the additional_methods is the label of the method (total_charge, partial_charges, etc.) 
# this will be useful for static shapes processing if needed - self.additional_methods[method_args['label']] = {'name': method_name, 'method':self.initialize_additional_method(method_name, method_args)} + self.additional_methods[method_args['label']] = {'name': method_name, 'method': initialize_additional_method(method_name, method_args)} act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( @@ -228,12 +233,6 @@ def __init__( ) self.reset_parameters() - - def initialize_additional_method(self, method, args): - if method == 'tensornet_q': - return TensornetQ(args['init_value'], args['label'], args['learnable']) - else: - raise NotImplementedError(f"Method {method} not implemented") def reset_parameters(self): self.tensor_embedding.reset_parameters() From 3470f7ed7c858b641ac30c31e510560a29c5bd29 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 11:53:39 +0100 Subject: [PATCH 64/82] remove Any from typing import because not used --- torchmdnet/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index da8991718..a7b78893a 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -8,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.nn.functional import local_response_norm, mse_loss, l1_loss from torch import Tensor -from typing import Optional, Dict, Tuple, Any +from typing import Optional, Dict, Tuple from lightning import LightningModule from torchmdnet.models.model import create_model, load_model From 2ec850ab3367930387841f9c21f65e44765f605e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 12:00:48 +0100 Subject: [PATCH 65/82] remove Any from typing import --- torchmdnet/models/tensornet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 0b7a70f7c..592e1541a 100644 --- 
a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import torch -from typing import Optional, Tuple, Dict, Any +from typing import Optional, Tuple, Dict from torch import Tensor, nn from torchmdnet.models.utils import ( CosineCutoff, From 781a3a799d6ae2beeb812304f73abb1910b653f1 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 12:04:37 +0100 Subject: [PATCH 66/82] update for loop in the forward step and extra_args/static_shapes assertion --- torchmdnet/models/tensornet.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 592e1541a..57a901494 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -257,20 +257,19 @@ def forward( ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z - if extra_args is not None: - # we are assuming that extra args will be used, see model.py forward method how extra_args is passed - for label, t in extra_args.items(): - # molecule wise --> atom wise - if t.shape != z.shape: - extra_args[label] = t[batch] - + + if self.additional_labels is not None: + assert extra_args is not None + for label in self.additional_labels: + assert label in extra_args, f"TensorNet expects {label} to be provided as part of extra_args" + if extra_args[label].shape != z.shape: + extra_args[label] = extra_args[label][batch] + if self.static_shapes: + extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=extra_args[label].device, dtype=extra_args[label].dtype)), dim=0) + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp 
= torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - if self.additional_labels is not None: - for label in self.additional_methods.keys(): - assert label in extra_args, f"Extra field {label} not found in extra_args" - extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=extra_args[label].device, dtype=extra_args[label].dtype)), dim=0) # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs From 1a30039eb98d3838a035b3f34519e97af5f171ec Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 12:06:01 +0100 Subject: [PATCH 67/82] send always extra_args --- torchmdnet/models/model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 42d680b31..fe2445131 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -393,11 +393,7 @@ def forward( pos, batch, box=box, - extra_args=( - extra_args - if self.representation_model.additional_methods is not None - else None - ), + extra_args=extra_args, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) From e811a5e28279bead33174d0c9afb8979503fcc3a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 12:30:07 +0100 Subject: [PATCH 68/82] add reset_parameters for tensornetQ --- torchmdnet/models/tensornet.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 57a901494..6a57137d1 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -476,7 +476,9 @@ def reset_parameters(self): linear.reset_parameters() for linear in self.linears_tensor: linear.reset_parameters() - # TODO: should we reset the parameters of the additional methods here? 
+ if self.addtional_methods is not None: + for method in self.addtional_methods: + self.addtional_methods[method]['method'].reset_parameters() def forward( self, @@ -536,11 +538,17 @@ def forward( class TensornetQ(nn.Module): - def __init__(self, init_value, additional_label='total_charge', learnable=False): - super().__init__() - self.prmtr = nn.Parameter(torch.tensor(init_value), requires_grad=learnable) - self.allowed_labels = ['total_charge', 'partial_charges'] - assert additional_label in self.allowed_labels, f"Label {additional_label} not allowed for this method" + def __init__(self, init_value, additional_label='total_charge', learnable=False): + super().__init__() + self.prmtr = nn.Parameter(torch.tensor(init_value), requires_grad=learnable) + self.learnable = learnable + self.init_value = init_value + self.allowed_labels = ['total_charge', 'partial_charges'] + assert additional_label in self.allowed_labels, f"Label {additional_label} not allowed for this method" - def forward(self, X): - return self.prmtr * X \ No newline at end of file + def forward(self, X): + return self.prmtr * X + + def reset_parameters(self): + if self.learnable: + self.prmtr = nn.Parameter(torch.tensor(self.init_value)) \ No newline at end of file From 27f13dbeaaa7967bd5edc7fcfef3c1c11261c649 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 12:56:40 +0100 Subject: [PATCH 69/82] add an assertion to verify that additional_labels is not specified for models that do not implement it --- torchmdnet/models/torchmd_et.py | 3 +-- torchmdnet/models/torchmd_gn.py | 3 +-- torchmdnet/models/torchmd_t.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 4fa970172..10852ef38 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -140,9 +140,8 @@ def __init__( self.max_z = max_z self.dtype = dtype self.additional_labels = additional_labels - 
self.allowed_additional_labels = None self.additional_methods = None - + assert additional_labels is None, "equivariant-transformer does not support this feature" act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 10f805316..72c8f1502 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -143,9 +143,8 @@ def __init__( self.max_z = max_z self.aggr = aggr self.additional_labels = additional_labels - self.allowed_additional_labels = None self.additional_methods = None - + assert additional_labels is None, "graph-network does not support this feature" act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 2452d8e4a..74290c332 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -130,8 +130,8 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.additional_labels = additional_labels - self.allowed_additional_labels = None self.additional_methods = None + assert additional_labels is None, "transformer does not support this feature" act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] From 3ee76de52af5c82dc5b58a1d0929e0f446d82c07 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 14:09:44 +0100 Subject: [PATCH 70/82] reintroduce charge and spin for backward compatibility --- torchmdnet/scripts/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 4b4677d4a..34a7825d5 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -78,6 +78,8 @@ def get_argparse(): parser.add_argument('--prior-model', type=str, default=None, help='Which prior 
model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*") # architectural args + parser.add_argument('--charge', type=bool, default=False, help='DEPRECATED: This argument is no longer in use and is maintained only for retro-compatibility') + parser.add_argument('--spin', type=bool, default=False, help='DEPRECATED: This argument is no longer in use and is maintained only for retro-compatibility') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model') @@ -138,7 +140,7 @@ def get_args(): args.inference_batch_size = args.batch_size os.makedirs(os.path.abspath(args.log_dir), exist_ok=True) - save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"]) + save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf", "charge", "spin"]) return args From d5a9e5f9bdbd325d21d154df7ce16daa11dbf002 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 14:15:29 +0100 Subject: [PATCH 71/82] to black --- torchmdnet/models/tensornet.py | 90 ++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 30 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 6a57137d1..54670be19 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -16,6 +16,7 @@ torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True + def vector_to_skewtensor(vector): """Creates a skew-symmetric tensor from a vector.""" batch_size = vector.size(0) @@ -62,11 +63,15 @@ def tensor_norm(tensor): """Computes 
Frobenius norm.""" return (tensor**2).sum((-2, -1)) + def initialize_additional_method(method, args): - if method == 'tensornet_q': - return TensornetQ(args['init_value'], args['label'], args['learnable']) + """Initialize additional methods to be used by the model. The additional methods are used to handle the extra_args provided to the model + using the addtional_labels argument.""" + if method == "tensornet_q": + return TensornetQ(args["init_value"], args["label"], args["learnable"]) else: - raise NotImplementedError(f"Method {method} not implemented") + raise NotImplementedError(f"Method {method} not implemented") + class TensorNet(nn.Module): r"""TensorNet's architecture. From @@ -175,15 +180,18 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.additional_labels = additional_labels - # initialize additional methods as None if not provided, also used by module.py + # initialize additional methods as None if not provided, also used by module.py self.additional_methods = None - + if additional_labels is not None: self.additional_methods = {} for method_name, method_args in additional_labels.items(): # the key of the additional_methods is the label of the method (total_charge, partial_charges, etc.) 
# this will be useful for static shapes processing if needed - self.additional_methods[method_args['label']] = {'name': method_name, 'method': initialize_additional_method(method_name, method_args)} + self.additional_methods[method_args["label"]] = { + "name": method_name, + "method": initialize_additional_method(method_name, method_args), + } act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( @@ -233,7 +241,7 @@ def __init__( ) self.reset_parameters() - + def reset_parameters(self): self.tensor_embedding.reset_parameters() for layer in self.layers: @@ -257,20 +265,32 @@ def forward( ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z - + if self.additional_labels is not None: assert extra_args is not None for label in self.additional_labels: - assert label in extra_args, f"TensorNet expects {label} to be provided as part of extra_args" + assert ( + label in extra_args + ), f"TensorNet expects {label} to be provided as part of extra_args" if extra_args[label].shape != z.shape: extra_args[label] = extra_args[label][batch] if self.static_shapes: - extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=extra_args[label].device, dtype=extra_args[label].dtype)), dim=0) - + extra_args[label] = torch.cat( + ( + extra_args[label], + torch.zeros( + 1, + device=extra_args[label].device, + dtype=extra_args[label].dtype, + ), + ), + dim=0, + ) + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - + # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, 
z.shape[0]) @@ -386,7 +406,7 @@ def forward( edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor, - ) -> Tensor: + ) -> Tensor: Zij = self._get_atomic_number_message(z, edge_index) Iij, Aij, Sij = self._get_tensor_messages( Zij, edge_weight, edge_vec_norm, edge_attr @@ -443,7 +463,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype=torch.float32, - addtional_methods = None, + addtional_methods=None, ): super(Interaction, self).__init__() @@ -468,7 +488,7 @@ def __init__( self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group self.addtional_methods = addtional_methods - + self.reset_parameters() def reset_parameters(self): @@ -478,7 +498,7 @@ def reset_parameters(self): linear.reset_parameters() if self.addtional_methods is not None: for method in self.addtional_methods: - self.addtional_methods[method]['method'].reset_parameters() + self.addtional_methods[method]["method"].reset_parameters() def forward( self, @@ -510,16 +530,22 @@ def forward( edge_index, edge_attr[..., 2, None, None], S, X.shape[0] ) msg = Im + Am + Sm - - prefactor = 1 if self.addtional_methods is not None else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) + + prefactor = ( + 1 + if self.addtional_methods is not None + else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) + ) if self.addtional_methods is not None and extra_args is not None: for label, method_dict in self.addtional_methods.items(): - # appending to this list all the methods will be working in this way - if method_dict['name'] in ['tensornet_q']: - tmp_ = method_dict['method'](extra_args[label][..., None, None, None]) - #TODO: how do we want to handle prefactor if multiple methods are used here? 
- prefactor += tmp_ - + # appending to this list all the methods will be working in this way + if method_dict["name"] in ["tensornet_q"]: + tmp_ = method_dict["method"]( + extra_args[label][..., None, None, None] + ) + # TODO: how do we want to handle prefactor if multiple methods are used here? + prefactor += tmp_ + if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) @@ -531,24 +557,28 @@ def forward( I, A, S = I / normp1, A / normp1, S / normp1 I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) # shape: (natoms, hidden_channels, 3, 3) + S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute( + 0, 3, 1, 2 + ) # shape: (natoms, hidden_channels, 3, 3) dX = I + A + S X = X + dX + (prefactor) * torch.matrix_power(dX, 2) return X class TensornetQ(nn.Module): - def __init__(self, init_value, additional_label='total_charge', learnable=False): + def __init__(self, init_value, additional_label="total_charge", learnable=False): super().__init__() self.prmtr = nn.Parameter(torch.tensor(init_value), requires_grad=learnable) self.learnable = learnable self.init_value = init_value - self.allowed_labels = ['total_charge', 'partial_charges'] - assert additional_label in self.allowed_labels, f"Label {additional_label} not allowed for this method" - + self.allowed_labels = ["total_charge", "partial_charges"] + assert ( + additional_label in self.allowed_labels + ), f"Label {additional_label} not allowed for this method" + def forward(self, X): return self.prmtr * X def reset_parameters(self): if self.learnable: - self.prmtr = nn.Parameter(torch.tensor(self.init_value)) \ No newline at end of file + self.prmtr = nn.Parameter(torch.tensor(self.init_value)) From 75c67267de5c4a0119fcb3cb01f0d23ed4b8768d Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 
14:55:52 +0100 Subject: [PATCH 72/82] remove comment --- torchmdnet/models/tensornet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 54670be19..de8d7974f 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -180,7 +180,6 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.additional_labels = additional_labels - # initialize additional methods as None if not provided, also used by module.py self.additional_methods = None if additional_labels is not None: From 591253cd9b3ddee7c7d339b5975fc49b6acb6485 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 16:02:52 +0100 Subject: [PATCH 73/82] let tensornetq's operations more efficient --- torchmdnet/models/tensornet.py | 71 +++++++++++++++------------------- 1 file changed, 31 insertions(+), 40 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index de8d7974f..d6966ada7 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -66,7 +66,7 @@ def tensor_norm(tensor): def initialize_additional_method(method, args): """Initialize additional methods to be used by the model. 
The additional methods are used to handle the extra_args provided to the model - using the addtional_labels argument.""" + using the additional_labels argument.""" if method == "tensornet_q": return TensornetQ(args["init_value"], args["label"], args["learnable"]) else: @@ -265,26 +265,26 @@ def forward( # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z - if self.additional_labels is not None: + if self.additional_methods is not None: assert extra_args is not None - for label in self.additional_labels: - assert ( - label in extra_args - ), f"TensorNet expects {label} to be provided as part of extra_args" - if extra_args[label].shape != z.shape: - extra_args[label] = extra_args[label][batch] - if self.static_shapes: - extra_args[label] = torch.cat( - ( - extra_args[label], - torch.zeros( - 1, - device=extra_args[label].device, - dtype=extra_args[label].dtype, + for label in self.additional_methods.keys(): + assert ( + label in extra_args + ), f"TensorNet expects {label} to be provided as part of extra_args" + if extra_args[label].shape != z.shape: + extra_args[label] = extra_args[label][batch] + if self.static_shapes: + extra_args[label] = torch.cat( + ( + extra_args[label], + torch.zeros( + 1, + device=extra_args[label].device, + dtype=extra_args[label].dtype, + ), ), - ), - dim=0, - ) + dim=0, + ) if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) @@ -462,7 +462,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype=torch.float32, - addtional_methods=None, + additional_methods=None, ): super(Interaction, self).__init__() @@ -486,7 +486,7 @@ def __init__( ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group - self.addtional_methods = addtional_methods + self.additional_methods = additional_methods self.reset_parameters() @@ 
-495,9 +495,9 @@ def reset_parameters(self): linear.reset_parameters() for linear in self.linears_tensor: linear.reset_parameters() - if self.addtional_methods is not None: - for method in self.addtional_methods: - self.addtional_methods[method]["method"].reset_parameters() + if self.additional_methods is not None: + for method in self.additional_methods: + self.additional_methods[method]["method"].reset_parameters() def forward( self, @@ -530,21 +530,12 @@ def forward( ) msg = Im + Am + Sm - prefactor = ( - 1 - if self.addtional_methods is not None - else torch.ones_like(msg, device=msg.device, dtype=msg.dtype) - ) - if self.addtional_methods is not None and extra_args is not None: - for label, method_dict in self.addtional_methods.items(): - # appending to this list all the methods will be working in this way - if method_dict["name"] in ["tensornet_q"]: - tmp_ = method_dict["method"]( - extra_args[label][..., None, None, None] - ) - # TODO: how do we want to handle prefactor if multiple methods are used here? 
- prefactor += tmp_ - + prefactor = torch.tensor([1], device=X.device, dtype=X.dtype) + if self.additional_methods is not None and extra_args is not None: + tensorq_labels = [callback['method'](extra_args[label]) for label, callback in self.additional_methods.items() if callback['name'] == 'tensornet_q'] + if len(tensorq_labels) > 0: + prefactor = prefactor + torch.prod(torch.stack(tensorq_labels), dim=0) + if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) @@ -576,7 +567,7 @@ def __init__(self, init_value, additional_label="total_charge", learnable=False) ), f"Label {additional_label} not allowed for this method" def forward(self, X): - return self.prmtr * X + return self.prmtr * X[..., None, None, None] def reset_parameters(self): if self.learnable: From 4cd63135f2f5e2e80d9436289fcb945bd3a136b9 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 16:25:33 +0100 Subject: [PATCH 74/82] move to label_callbacks instead of additional_methods --- torchmdnet/models/model.py | 9 ++++----- torchmdnet/models/tensornet.py | 28 ++++++++++++++-------------- torchmdnet/models/torchmd_et.py | 2 +- torchmdnet/models/torchmd_gn.py | 2 +- torchmdnet/models/torchmd_t.py | 2 +- torchmdnet/optimize.py | 2 +- 6 files changed, 22 insertions(+), 23 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index fe2445131..7979118af 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -272,11 +272,10 @@ class TorchMD_Net(nn.Module): Parameters ---------- representation_model : nn.Module - A model that takes as input the atomic numbers, positions, batch indices and extra_args. The extra_args - are optional and will be used only if the representation model has additional_methods. 
It must return a - tuple of the form (x, v, z, pos, batch), where x are the atom features, v are the vector features (if any), - z are the atomic numbers, pos are the positions, and batch are the batch indices. See TorchMD_ET for more - details. + A model that takes as input the atomic numbers, positions, batch indices and extra_args(optional). It must + return a tuple of the form (x, v, z, pos, batch), where x are the atom features, v are the vector features + (if any), z are the atomic numbers, pos are the positions, and batch are the batch indices. See TorchMD_ET + for more details. output_model : nn.Module A model that takes as input the atom features, vector features (if any), atomic numbers, positions, and batch indices. See OutputModel for more details. diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index d6966ada7..49fc6af14 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -180,14 +180,14 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.additional_labels = additional_labels - self.additional_methods = None + self.label_callbacks = None if additional_labels is not None: - self.additional_methods = {} + self.label_callbacks = {} for method_name, method_args in additional_labels.items(): - # the key of the additional_methods is the label of the method (total_charge, partial_charges, etc.) + # the key of the label_callbacks is the label of the method (total_charge, partial_charges, etc.) 
# this will be useful for static shapes processing if needed - self.additional_methods[method_args["label"]] = { + self.label_callbacks[method_args["label"]] = { "name": method_name, "method": initialize_additional_method(method_name, method_args), } @@ -217,7 +217,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype, - self.additional_methods, + self.label_callbacks, ) ) self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype) @@ -265,9 +265,9 @@ def forward( # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom zp = z - if self.additional_methods is not None: + if self.label_callbacks is not None: assert extra_args is not None - for label in self.additional_methods.keys(): + for label in self.label_callbacks.keys(): assert ( label in extra_args ), f"TensorNet expects {label} to be provided as part of extra_args" @@ -462,7 +462,7 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype=torch.float32, - additional_methods=None, + label_callbacks=None, ): super(Interaction, self).__init__() @@ -486,7 +486,7 @@ def __init__( ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group - self.additional_methods = additional_methods + self.label_callbacks = label_callbacks self.reset_parameters() @@ -495,9 +495,9 @@ def reset_parameters(self): linear.reset_parameters() for linear in self.linears_tensor: linear.reset_parameters() - if self.additional_methods is not None: - for method in self.additional_methods: - self.additional_methods[method]["method"].reset_parameters() + if self.label_callbacks is not None: + for method in self.label_callbacks: + self.label_callbacks[method]["method"].reset_parameters() def forward( self, @@ -531,8 +531,8 @@ def forward( msg = Im + Am + Sm prefactor = torch.tensor([1], device=X.device, dtype=X.dtype) - if 
self.additional_methods is not None and extra_args is not None: - tensorq_labels = [callback['method'](extra_args[label]) for label, callback in self.additional_methods.items() if callback['name'] == 'tensornet_q'] + if self.label_callbacks is not None and extra_args is not None: + tensorq_labels = [callback['method'](extra_args[label]) for label, callback in self.label_callbacks.items() if callback['name'] == 'tensornet_q'] if len(tensorq_labels) > 0: prefactor = prefactor + torch.prod(torch.stack(tensorq_labels), dim=0) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 10852ef38..2dd2a4860 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -140,7 +140,7 @@ def __init__( self.max_z = max_z self.dtype = dtype self.additional_labels = additional_labels - self.additional_methods = None + self.label_callbacks = None assert additional_labels is None, "equivariant-transformer does not support this feature" act_class = act_class_mapping[activation] diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 72c8f1502..fdad5bef2 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -143,7 +143,7 @@ def __init__( self.max_z = max_z self.aggr = aggr self.additional_labels = additional_labels - self.additional_methods = None + self.label_callbacks = None assert additional_labels is None, "graph-network does not support this feature" act_class = act_class_mapping[activation] diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 74290c332..028960e6b 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -130,7 +130,7 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.additional_labels = additional_labels - self.additional_methods = None + self.label_callbacks = None assert additional_labels is None, "transformer does not support this feature" act_class = 
act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 2ff26f15b..4e825e94c 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -33,7 +33,7 @@ def __init__(self, model): super().__init__() self.model = model - self.additional_methods = None + self.label_callbacks = None self.neighbors = CFConvNeighbors(self.model.cutoff_upper) offset = self.model.distance_expansion.offset From 6b72a00b51e62dd699729af1bc11cec75f6963fd Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 16:32:11 +0100 Subject: [PATCH 75/82] rename to additional_labels_handler --- torchmdnet/models/tensornet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 49fc6af14..f0cf29ec3 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -64,9 +64,9 @@ def tensor_norm(tensor): return (tensor**2).sum((-2, -1)) -def initialize_additional_method(method, args): - """Initialize additional methods to be used by the model. The additional methods are used to handle the extra_args provided to the model - using the additional_labels argument.""" +def additional_labels_handler(method, args): + """ Handler for additional labels. 
It returns the method to be used for the specific additional label + and the parameters to initialize it.""" if method == "tensornet_q": return TensornetQ(args["init_value"], args["label"], args["learnable"]) else: @@ -189,7 +189,7 @@ def __init__( # this will be useful for static shapes processing if needed self.label_callbacks[method_args["label"]] = { "name": method_name, - "method": initialize_additional_method(method_name, method_args), + "method": additional_labels_handler(method_name, method_args), } act_class = act_class_mapping[activation] From fc58fb3ae075eea35dd05db04de990cbdb3a17f3 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 16:48:19 +0100 Subject: [PATCH 76/82] get torch.jit compatibility, comprehension ifs are not supported yet --- torchmdnet/models/tensornet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index f0cf29ec3..fb74d55d8 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -532,7 +532,10 @@ def forward( prefactor = torch.tensor([1], device=X.device, dtype=X.dtype) if self.label_callbacks is not None and extra_args is not None: - tensorq_labels = [callback['method'](extra_args[label]) for label, callback in self.label_callbacks.items() if callback['name'] == 'tensornet_q'] + tensorq_labels = [] + for label, callback in self.label_callbacks.items(): + if callback['name'] == 'tensornet_q': + tensorq_labels.append(callback['method'](extra_args[label])) if len(tensorq_labels) > 0: prefactor = prefactor + torch.prod(torch.stack(tensorq_labels), dim=0) From 0ac63d9bee47c99b904e2d4555377ed588da918b Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 16:52:37 +0100 Subject: [PATCH 77/82] remove unused import --- torchmdnet/models/torchmd_et.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 2dd2a4860..019a51382 
100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -13,7 +13,6 @@ act_class_mapping, scatter, ) -from torchmdnet.utils import deprecated_class class TorchMD_ET(nn.Module): r"""Equivariant Transformer's architecture. From From 8a4a1814f8bdb0926d636c72d227698af4edd443 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 17:42:49 +0100 Subject: [PATCH 78/82] use warning instead of assertion --- torchmdnet/models/torchmd_et.py | 3 ++- torchmdnet/models/torchmd_gn.py | 3 ++- torchmdnet/models/torchmd_t.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 019a51382..1e12d459a 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -140,7 +140,8 @@ def __init__( self.dtype = dtype self.additional_labels = additional_labels self.label_callbacks = None - assert additional_labels is None, "equivariant-transformer does not support this feature" + if additional_labels is not None: + Warning("Found additional_labels, equivariant-transformer still does not support additional labels. Ignoring them.") act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index fdad5bef2..209535db9 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -144,7 +144,8 @@ def __init__( self.aggr = aggr self.additional_labels = additional_labels self.label_callbacks = None - assert additional_labels is None, "graph-network does not support this feature" + if additional_labels is not None: + Warning("Found additional_labels, graph-network still does not support additional labels. 
Ignoring them.") act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index 028960e6b..f397642b7 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -131,7 +131,8 @@ def __init__( self.max_z = max_z self.additional_labels = additional_labels self.label_callbacks = None - assert additional_labels is None, "transformer does not support this feature" + if additional_labels is not None: + Warning("Found additional_labels, transformer still does not support additional labels. Ignoring them.") act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] From fdf12aa3d86a0ae54bd6c4b6cf5628a1a36907f0 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 17:44:12 +0100 Subject: [PATCH 79/82] more readable format --- torchmdnet/models/tensornet.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index fb74d55d8..9b1b6fe95 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -266,11 +266,9 @@ def forward( zp = z if self.label_callbacks is not None: - assert extra_args is not None + assert extra_args is not None, "TensorNet expects extra_args to be provided when additional_labels are used" for label in self.label_callbacks.keys(): - assert ( - label in extra_args - ), f"TensorNet expects {label} to be provided as part of extra_args" + assert (label in extra_args), f"TensorNet expects {label} to be provided as part of extra_args" if extra_args[label].shape != z.shape: extra_args[label] = extra_args[label][batch] if self.static_shapes: @@ -285,7 +283,7 @@ def forward( ), dim=0, ) - + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) From 
580cd34ae791194808fed5324e40525c23fdb896 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 18 Mar 2024 17:47:45 +0100 Subject: [PATCH 80/82] update test_model considering extra_args will be always passed to the forward --- tests/test_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 615cc561a..3c2c8f0b6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -26,10 +26,11 @@ def test_forward(model_name, use_batch, use_extra_args, precision, additional_la pos = pos.to(dtype=dtype_mapping[precision]) model = create_model(load_example_args(model_name, prior_model=None, precision=precision, additional_labels=additional_labels)) batch = batch if use_batch else None - if use_extra_args: - model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)}) - else: + if not use_extra_args and additional_labels is None: model(z, pos, batch=batch) + else: + model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)}) + @mark.parametrize("model_name", models.__all_models__) From d4832306e63a832f41660cded9bc2c93765705d2 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 19 Mar 2024 11:53:07 +0100 Subject: [PATCH 81/82] move prefactor and tensornet_q operation from the interacton layer to tensornet nn.module --- torchmdnet/models/tensornet.py | 54 +++++++++++++--------------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 9b1b6fe95..de6a42101 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -191,7 +191,10 @@ def __init__( "name": method_name, "method": additional_labels_handler(method_name, method_args), } - + self.tensorq_labels = None + if self.label_callbacks is not None: + self.tensorq_labels = [label for label, callback in self.label_callbacks.items() if callback['name'] == 'tensornet_q'] + act_class = 
act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( cutoff_lower, cutoff_upper, num_rbf, trainable_rbf @@ -217,7 +220,6 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype, - self.label_callbacks, ) ) self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype) @@ -247,6 +249,9 @@ def reset_parameters(self): layer.reset_parameters() self.linear.reset_parameters() self.out_norm.reset_parameters() + if self.label_callbacks is not None: + for callback in self.label_callbacks: + self.label_callbacks[callback]["method"].reset_parameters() def forward( self, @@ -271,19 +276,9 @@ def forward( assert (label in extra_args), f"TensorNet expects {label} to be provided as part of extra_args" if extra_args[label].shape != z.shape: extra_args[label] = extra_args[label][batch] - if self.static_shapes: - extra_args[label] = torch.cat( - ( - extra_args[label], - torch.zeros( - 1, - device=extra_args[label].device, - dtype=extra_args[label].dtype, - ), - ), - dim=0, - ) - + if self.static_shapes: + extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) @@ -300,9 +295,17 @@ def forward( # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. 
# I avoid dividing by zero by setting the weight of self edges and self loops to 1 edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) + + prefactor = torch.ones(1, device=z.device, dtype=z.dtype) + if self.tensorq_labels is not None: + for label in self.tensorq_labels: + prefactor = prefactor * self.label_callbacks[label]["method"](extra_args[label]) + prefactor += 1 + X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) + for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, extra_args) + X = layer(X, edge_index, edge_weight, edge_attr, prefactor) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -460,7 +463,6 @@ def __init__( cutoff_upper, equivariance_invariance_group, dtype=torch.float32, - label_callbacks=None, ): super(Interaction, self).__init__() @@ -484,7 +486,6 @@ def __init__( ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group - self.label_callbacks = label_callbacks self.reset_parameters() @@ -493,9 +494,6 @@ def reset_parameters(self): linear.reset_parameters() for linear in self.linears_tensor: linear.reset_parameters() - if self.label_callbacks is not None: - for method in self.label_callbacks: - self.label_callbacks[method]["method"].reset_parameters() def forward( self, @@ -503,7 +501,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - extra_args: Optional[Dict[str, Tensor]] = None, + prefactor: Tensor, ) -> Tensor: C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -527,15 +525,6 @@ def forward( edge_index, edge_attr[..., 2, None, None], S, X.shape[0] ) msg = Im + Am + Sm - - prefactor = torch.tensor([1], device=X.device, dtype=X.dtype) - if self.label_callbacks is not None and extra_args is not None: - tensorq_labels = [] - for label, callback in self.label_callbacks.items(): - if callback['name'] == 'tensornet_q': - 
tensorq_labels.append(callback['method'](extra_args[label])) - if len(tensorq_labels) > 0: - prefactor = prefactor + torch.prod(torch.stack(tensorq_labels), dim=0) if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) @@ -571,5 +560,4 @@ def forward(self, X): return self.prmtr * X[..., None, None, None] def reset_parameters(self): - if self.learnable: - self.prmtr = nn.Parameter(torch.tensor(self.init_value)) + self.prmtr.data = torch.tensor(self.init_value) From 9e546dd064332d4fe1aeee7246e8b301f166b88c Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 19 Mar 2024 12:34:33 +0100 Subject: [PATCH 82/82] fix typo --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 3c2c8f0b6..f3e09d886 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -121,7 +121,7 @@ def test_cuda_graph_compatible(model_name): "prior_model": None, "atom_filter": -1, "derivative": True, - "check_error": False, + "check_errors": False, "static_shapes": True, "output_model": "Scalar", "reduce_op": "sum",