diff --git a/tests/test_examples.py b/tests/test_examples.py index 8cd5155e6..30c90ab83 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,6 +5,7 @@ from pytest import mark import yaml import glob +import torch from os.path import dirname, join from torchmdnet.models.model import create_model from torchmdnet import priors @@ -27,4 +28,4 @@ def test_example_yamls(fname): z, pos, batch = create_example_batch() model(z, pos, batch) - model(z, pos, batch, q=None, s=None) + model(z, pos, batch, extra_args={"total_charge": torch.zeros_like(z)}) diff --git a/tests/test_model.py b/tests/test_model.py index f606559ef..21576b29d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,19 +18,19 @@ @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("use_batch", [True, False]) -@mark.parametrize("explicit_q_s", [True, False]) +@mark.parametrize("use_extra_args", [True, False]) @mark.parametrize("precision", [32, 64]) -def test_forward(model_name, use_batch, explicit_q_s, precision): +@mark.parametrize("additional_labels", [None, {"tensornet_q": {"label": "total_charge", 'learnable': False, 'init_value': 0.1}}]) +def test_forward(model_name, use_batch, use_extra_args, precision, additional_labels): z, pos, batch = create_example_batch() pos = pos.to(dtype=dtype_mapping[precision]) - model = create_model( - load_example_args(model_name, prior_model=None, precision=precision) - ) + model = create_model(load_example_args(model_name, prior_model=None, precision=precision, additional_labels=additional_labels)) batch = batch if use_batch else None - if explicit_q_s: - model(z, pos, batch=batch, q=None, s=None) - else: + if not use_extra_args and additional_labels is None: model(z, pos, batch=batch) + else: + model(z, pos, batch=batch, extra_args={'total_charge': torch.zeros_like(z)}) + @mark.parametrize("model_name", models.__all_models__) @@ -137,7 +137,6 @@ def test_cuda_graph_compatible(model_name): "output_model": "Scalar", "reduce_op": "sum", "precision": 32, - } model = create_model(args).to(device="cuda") model.eval() z = z.to("cuda") diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 8f8aff2aa..05b40e3f0 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -19,7 +19,7 @@ def test_atom_filter(remove_threshold, model_name): model = AtomFilter(model, remove_threshold) z, pos, batch = create_example_batch(n_atoms=100) - x, v, z, pos, batch = model(z, pos, batch, None, None) + x, v, z, pos, batch = model(z, pos, batch, None) assert (z > remove_threshold).all(), ( f"Lowest updated atomic number is {z.min()} but " diff --git a/tests/utils.py b/tests/utils.py index ef8bcddb9..b533a2e14 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,6 +26,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs args["box_vecs"] = None if "remove_ref_energy" not in args: args["remove_ref_energy"] = False + if "additional_labels" not in args: + args["additional_labels"] = None for key, val in kwargs.items(): assert key in args, f"Broken test! Unknown key '{key}'." args[key] = val diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e8c59785d..d2b57638d 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -146,7 +146,7 @@ def __init__( transform, pre_transform, pre_filter, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ) @property @@ -188,14 +188,14 @@ def _load_confs_1_0(mol, n_atoms): assert neg_dy.shape == pos.shape assert conf["partial_charges"].attrs["units"] == "e" - pq = pt.tensor(conf["partial_charges"][:], dtype=pt.float32) - assert pq.shape == (n_atoms,) + partial_charges = pt.tensor(conf["partial_charges"][:], dtype=pt.float32) + assert partial_charges.shape == (n_atoms,) assert conf["dipole_moment"].attrs["units"] == "e*Å" - dp = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32) - assert dp.shape == (3,) + dipole_moment = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32) + assert dipole_moment.shape == (3,) - yield pos, y, neg_dy, pq, dp + yield pos, y, neg_dy, partial_charges, dipole_moment @staticmethod def _load_confs_2_0(mol, n_atoms): @@ -213,19 +213,19 @@ def _load_confs_2_0(mol, n_atoms): assert all_neg_dy.shape == all_pos.shape assert mol["partial_charges"].attrs["units"] == "e" - all_pq = pt.tensor(mol["partial_charges"][...], dtype=pt.float32) - assert all_pq.shape == (n_confs, n_atoms) + all_partial_charges = pt.tensor(mol["partial_charges"][...], dtype=pt.float32) + assert all_partial_charges.shape == (n_confs, n_atoms) assert mol["dipole_moments"].attrs["units"] == "e*Å" - all_dp = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32) - assert all_dp.shape == (n_confs, 3) + all_dipole_moment = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32) + assert all_dipole_moment.shape == (n_confs, 3) - for pos, y, neg_dy, pq, dp in zip(all_pos, all_y, all_neg_dy, all_pq, all_dp): + for pos, y, neg_dy, partial_charges, dipole_moment in zip(all_pos, all_y, all_neg_dy, all_partial_charges, all_dipole_moment): # Skip failed calculations if y.isnan(): continue - yield pos, y, neg_dy, pq, dp + yield pos, y, neg_dy, partial_charges, dipole_moment def sample_iter(self, mol_ids=False): assert self.subsample_molecules > 0 @@ -261,9 +261,9 @@ def sample_iter(self, mol_ids=False): z = pt.tensor(mol["atomic_numbers"], dtype=pt.long) fq = pt.tensor(mol["formal_charges"], dtype=pt.long) - q = fq.sum() + total_charge = fq.sum() - for i_conf, (pos, y, neg_dy, pq, dp) in enumerate( + for i_conf, (pos, y, neg_dy, partial_charges, dipole_moment) in enumerate( load_confs(mol, n_atoms=len(z)) ): # Skip samples with large forces @@ -273,7 +273,7 @@ def sample_iter(self, mol_ids=False): # Create a sample args = dict( - z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp + z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=total_charge, partial_charges=partial_charges, dipole_moment=dipole_moment ) if mol_ids: args["mol_id"] = mol_id diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index c647b7d2a..d9cb895d4 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -57,6 +57,14 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs): self.fields.append( ("partial_charges", "partial_charges", torch.float32) ) + if "total_charge" in group: + self.fields.append( + ("total_charge", "total_charge", torch.float32) + ) + if "spin" in group: + self.fields.append( + ("spin", "spin", torch.float32) + ) assert ("energy" in group) or ( "forces" in group ), "Each group must contain at least energies or forces" diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index a97b54be8..bb3576ed6 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -17,9 +17,9 @@ class MemmappedDataset(Dataset): - :obj:`pos`: Positions of the atoms. - :obj:`y`: Energy of the conformation. - :obj:`neg_dy`: Forces on the atoms. - - :obj:`q`: Total charge of the conformation. - - :obj:`pq`: Partial charges of the atoms. - - :obj:`dp`: Dipole moment of the conformation. + - :obj:`total_charge`: Total charge of the conformation. + - :obj:`partial_charges`: Partial charges of the atoms. + - :obj:`dipole_moment`: Dipole moment of the conformation. The data is stored in the following files: @@ -28,9 +28,9 @@ class MemmappedDataset(Dataset): - :obj:`name.pos.mmap`: Positions of all the atoms. - :obj:`name.y.mmap`: Energy of each conformation. - :obj:`name.neg_dy.mmap`: Forces on all the atoms. - - :obj:`name.q.mmap`: Total charge of each conformation. - - :obj:`name.pq.mmap`: Partial charges of all the atoms. - - :obj:`name.dp.mmap`: Dipole moment of each conformation. + - :obj:`name.total_charge.mmap`: Total charge of each conformation. + - :obj:`name.partial_charges.mmap`: Partial charges of all the atoms. + - :obj:`name.dipole_moment.mmap`: Dipole moment of each conformation. Args: root (str): Root directory where the dataset should be stored. @@ -45,8 +45,8 @@ class MemmappedDataset(Dataset): indicating whether the data object should be included in the final dataset. properties (tuple of str, optional): The properties to include in the - dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`q`, - :obj:`pq`, and :obj:`dp`. + dataset. Can be any subset of :obj:`y`, :obj:`neg_dy`, :obj:`total_charge`, + :obj:`partial_charges`, and :obj:`dipole_moment`. """ def __init__( @@ -55,7 +55,7 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ): self.name = self.__class__.__name__ self.properties = properties @@ -76,13 +76,13 @@ def __init__( self.neg_dy_mm = np.memmap( fnames["neg_dy"], mode="r", dtype=np.float32, shape=(num_all_atoms, 3) ) - if "q" in self.properties: - self.q_mm = np.memmap(fnames["q"], mode="r", dtype=np.int8) - if "pq" in self.properties: - self.pq_mm = np.memmap(fnames["pq"], mode="r", dtype=np.float32) - if "dp" in self.properties: + if "total_charge" in self.properties: + self.q_mm = np.memmap(fnames["total_charge"], mode="r", dtype=np.int8) + if "partial_charges" in self.properties: + self.pq_mm = np.memmap(fnames["partial_charges"], mode="r", dtype=np.float32) + if "dipole_moment" in self.properties: self.dp_mm = np.memmap( - fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3) + fnames["dipole_moment"], mode="r", dtype=np.float32, shape=(num_all_confs, 3) ) assert self.idx_mm[0] == 0 @@ -151,20 +151,20 @@ def process(self): dtype=np.float32, shape=(num_all_atoms, 3), ) - if "q" in self.properties: + if "total_charge" in self.properties: q_mm = np.memmap( - fnames["q"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs + fnames["total_charge"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs ) - if "pq" in self.properties: + if "partial_charges" in self.properties: pq_mm = np.memmap( - fnames["pq"] + ".tmp", + fnames["partial_charges"] + ".tmp", mode="w+", dtype=np.float32, shape=num_all_atoms, ) - if "dp" in self.properties: + if "dipole_moment" in self.properties: dp_mm = np.memmap( - fnames["dp"] + ".tmp", + fnames["dipole_moment"] + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_confs, 3), @@ -182,12 +182,12 @@ def process(self): y_mm[i_conf] = data.y if "neg_dy" in self.properties: neg_dy_mm[i_atom:i_next_atom] = data.neg_dy - if "q" in self.properties: - q_mm[i_conf] = data.q.to(pt.int8) - if "pq" in self.properties: - pq_mm[i_atom:i_next_atom] = data.pq - if "dp" in self.properties: - dp_mm[i_conf] = data.dp + if "total_charge" in self.properties: + q_mm[i_conf] = data.total_charge.to(pt.int8) + if "partial_charges" in self.properties: + pq_mm[i_atom:i_next_atom] = data.partial_charges + if "dipole_moment" in self.properties: + dp_mm[i_conf] = data.dipole_moment i_atom = i_next_atom idx_mm[-1] = num_all_atoms @@ -200,11 +200,11 @@ def process(self): y_mm.flush() if "neg_dy" in self.properties: neg_dy_mm.flush() - if "q" in self.properties: + if "total_charge" in self.properties: q_mm.flush() - if "pq" in self.properties: + if "partial_charges" in self.properties: pq_mm.flush() - if "dp" in self.properties: + if "dipole_moment" in self.properties: dp_mm.flush() os.rename(idx_mm.filename, fnames["idx"]) @@ -214,12 +214,12 @@ def process(self): os.rename(y_mm.filename, fnames["y"]) if "neg_dy" in self.properties: os.rename(neg_dy_mm.filename, fnames["neg_dy"]) - if "q" in self.properties: - os.rename(q_mm.filename, fnames["q"]) - if "pq" in self.properties: - os.rename(pq_mm.filename, fnames["pq"]) - if "dp" in self.properties: - os.rename(dp_mm.filename, fnames["dp"]) + if "total_charge" in self.properties: + os.rename(q_mm.filename, fnames["total_charge"]) + if "partial_charges" in self.properties: + os.rename(pq_mm.filename, fnames["partial_charges"]) + if "dipole_moment" in self.properties: + os.rename(dp_mm.filename, fnames["dipole_moment"]) def len(self): return len(self.idx_mm) - 1 @@ -233,9 +233,9 @@ def get(self, idx): - :obj:`pos`: Positions of the atoms. - :obj:`y`: Formation energy of the molecule. - :obj:`neg_dy`: Forces on the atoms. - - :obj:`q`: Total charge of the molecule. - - :obj:`pq`: Partial charges of the atoms. - - :obj:`dp`: Dipole moment of the molecule. + - :obj:`total_charge`: Total charge of the molecule. + - :obj:`partial_charges`: Partial charges of the atoms. + - :obj:`dipole_moment`: Dipole moment of the molecule. Args: idx (int): Index of the data object. @@ -252,10 +252,10 @@ def get(self, idx): props["y"] = pt.tensor(self.y_mm[idx]).view(1, 1) if "neg_dy" in self.properties: props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms]) - if "q" in self.properties: - props["q"] = pt.tensor(self.q_mm[idx], dtype=pt.long) - if "pq" in self.properties: - props["pq"] = pt.tensor(self.pq_mm[atoms]) - if "dp" in self.properties: - props["dp"] = pt.tensor(self.dp_mm[idx]) + if "total_charge" in self.properties: + props["total_charge"] = pt.tensor(self.q_mm[idx], dtype=pt.long) + if "partial_charges" in self.properties: + props["partial_charges"] = pt.tensor(self.pq_mm[atoms]) + if "dipole_moment" in self.properties: + props["dipole_moment"] = pt.tensor(self.dp_mm[idx]) return Data(z=z, pos=pos, **props) diff --git a/torchmdnet/datasets/qm9q.py b/torchmdnet/datasets/qm9q.py index 63a262a30..abaf677c8 100644 --- a/torchmdnet/datasets/qm9q.py +++ b/torchmdnet/datasets/qm9q.py @@ -46,7 +46,7 @@ def __init__( transform, pre_transform, pre_filter, - properties=("y", "neg_dy", "q", "pq", "dp"), + properties=("y", "neg_dy", "total_charge", "partial_charges", "dipole_moment"), ) @property @@ -150,7 +150,7 @@ def sample_iter(self, mol_ids=False): # Create a sample args = dict( - z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp + z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, total_charge=q, partial_charges=pq, dipole_moment=dp ) if mol_ids: args["mol_id"] = mol_id diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 913693043..c6a06f394 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -42,6 +42,14 @@ def create_model(args, prior_model=None, mean=None, std=None): if "vector_cutoff" not in args: args["vector_cutoff"] = False + # additional labels for the representation model + additional_labels = args.get("additional_labels") + if additional_labels is not None and not isinstance(additional_labels, dict): + additional_labels = None + warnings.warn( + "Additional labels should be a dictionary. Ignoring additional labels." + ) + shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -60,6 +68,7 @@ def create_model(args, prior_model=None, mean=None, std=None): else None ), dtype=dtype, + additional_labels=additional_labels, ) # representation network @@ -322,10 +331,10 @@ class TorchMD_Net(nn.Module): Parameters ---------- representation_model : nn.Module - A model that takes as input the atomic numbers, positions, batch indices, and optionally - charges and spins. It must return a tuple of the form (x, v, z, pos, batch), where x - are the atom features, v are the vector features (if any), z are the atomic numbers, - pos are the positions, and batch are the batch indices. See TorchMD_ET for more details. + A model that takes as input the atomic numbers, positions, batch indices and extra_args(optional). It must + return a tuple of the form (x, v, z, pos, batch), where x are the atom features, v are the vector features + (if any), z are the atomic numbers, pos are the positions, and batch are the batch indices. See TorchMD_ET + for more details. output_model : nn.Module A model that takes as input the atom features, vector features (if any), atomic numbers, positions, and batch indices. See OutputModel for more details. @@ -395,8 +404,6 @@ def forward( pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Tensor]: """ @@ -427,9 +434,7 @@ def forward( The vectors defining the periodic box. This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. If this is omitted, periodic boundary conditions are not applied. - q (Tensor, optional): Atomic charges in the molecule. Shape: (N,). - s (Tensor, optional): Atomic spins in the molecule. Shape: (N,). - extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model. + extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model and to the representation model. Returns: Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise. @@ -439,9 +444,14 @@ def forward( if self.derivative: pos.requires_grad_(True) + # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( - z, pos, batch, box=box, q=q, s=s + z, + pos, + batch, + box=box, + extra_args=extra_args, ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e7..de6a42101 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import torch -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict from torch import Tensor, nn from torchmdnet.models.utils import ( CosineCutoff, @@ -64,6 +64,15 @@ def tensor_norm(tensor): return (tensor**2).sum((-2, -1)) +def additional_labels_handler(method, args): + """ Handler for additional labels. It returns the method to be used for the specific additional label + and the parameters to initialize it.""" + if method == "tensornet_q": + return TensornetQ(args["init_value"], args["label"], args["learnable"]) + else: + raise NotImplementedError(f"Method {method} not implemented") + + class TensorNet(nn.Module): r"""TensorNet's architecture. From TensorNet: Cartesian Tensor Representations for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis. @@ -120,6 +129,12 @@ class TensorNet(nn.Module): (default: :obj:`True`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} + (default: :obj:`None`) """ def __init__( @@ -139,6 +154,7 @@ def __init__( check_errors=True, dtype=torch.float32, box_vecs=None, + additional_labels=None, ): super(TensorNet, self).__init__() @@ -163,6 +179,22 @@ def __init__( self.activation = activation self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper + self.additional_labels = additional_labels + self.label_callbacks = None + + if additional_labels is not None: + self.label_callbacks = {} + for method_name, method_args in additional_labels.items(): + # the key of the label_callbacks is the label of the method (total_charge, partial_charges, etc.) + # this will be useful for static shapes processing if needed + self.label_callbacks[method_args["label"]] = { + "name": method_name, + "method": additional_labels_handler(method_name, method_args), + } + self.tensorq_labels = None + if self.label_callbacks is not None: + self.tensorq_labels = [label for label, callback in self.label_callbacks.items() if callback['name'] == 'tensornet_q'] + act_class = act_class_mapping[activation] self.distance_expansion = rbf_class_mapping[rbf_type]( cutoff_lower, cutoff_upper, num_rbf, trainable_rbf @@ -173,11 +205,9 @@ def __init__( act_class, cutoff_lower, cutoff_upper, - trainable_rbf, max_z, dtype, ) - self.layers = nn.ModuleList() if num_layers != 0: for _ in range(num_layers): @@ -219,6 +249,9 @@ def reset_parameters(self): layer.reset_parameters() self.linear.reset_parameters() self.out_norm.reset_parameters() + if self.label_callbacks is not None: + for callback in self.label_callbacks: + self.label_callbacks[callback]["method"].reset_parameters() def forward( self, @@ -226,8 +259,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) @@ -236,16 +268,21 @@ def forward( edge_vec is not None ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom - # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q - if q is None: - q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - else: - q = q[batch] zp = z + + if self.label_callbacks is not None: + assert extra_args is not None, "TensorNet expects extra_args to be provided when additional_labels are used" + for label in self.label_callbacks.keys(): + assert (label in extra_args), f"TensorNet expects {label} to be provided as part of extra_args" + if extra_args[label].shape != z.shape: + extra_args[label] = extra_args[label][batch] + if self.static_shapes: + extra_args[label] = torch.cat((extra_args[label], torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) + if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0) + # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, z.shape[0]) @@ -258,9 +295,17 @@ def forward( # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. # I avoid dividing by zero by setting the weight of self edges and self loops to 1 edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) + + prefactor = torch.ones(1, device=z.device, dtype=z.dtype) + if self.tensorq_labels is not None: + for label in self.tensorq_labels: + prefactor = prefactor * self.label_callbacks[label]["method"](extra_args[label]) + prefactor += 1 + X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) + for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, q) + X = layer(X, edge_index, edge_weight, edge_attr, prefactor) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -284,7 +329,6 @@ def __init__( activation, cutoff_lower, cutoff_upper, - trainable_rbf=False, max_z=128, dtype=torch.float32, ): @@ -442,6 +486,7 @@ def __init__( ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group + self.reset_parameters() def reset_parameters(self): @@ -456,7 +501,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - q: Tensor, + prefactor: Tensor, ) -> Tensor: C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -480,18 +525,39 @@ def forward( edge_index, edge_attr[..., 2, None, None], S, X.shape[0] ) msg = Im + Am + Sm + if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) - I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B)) + I, A, S = decompose_tensor(prefactor * (A + B)) if self.equivariance_invariance_group == "SO(3)": B = torch.matmul(Y, msg) - I, A, S = decompose_tensor(2 * B) + I, A, S = decompose_tensor(prefactor * 2 * B) normp1 = (tensor_norm(I + A + S) + 1)[..., None, None] I, A, S = I / normp1, A / normp1, S / normp1 I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute( + 0, 3, 1, 2 + ) # shape: (natoms, hidden_channels, 3, 3) dX = I + A + S - X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2) + X = X + dX + (prefactor) * torch.matrix_power(dX, 2) return X + + +class TensornetQ(nn.Module): + def __init__(self, init_value, additional_label="total_charge", learnable=False): + super().__init__() + self.prmtr = nn.Parameter(torch.tensor(init_value), requires_grad=learnable) + self.learnable = learnable + self.init_value = init_value + self.allowed_labels = ["total_charge", "partial_charges"] + assert ( + additional_label in self.allowed_labels + ), f"Label {additional_label} not allowed for this method" + + def forward(self, X): + return self.prmtr * X[..., None, None, None] + + def reset_parameters(self): + self.prmtr.data = torch.tensor(self.init_value) diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 5ff168d54..1e12d459a 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -13,7 +13,6 @@ act_class_mapping, scatter, ) -from torchmdnet.utils import deprecated_class class TorchMD_ET(nn.Module): r"""Equivariant Transformer's architecture. From @@ -79,7 +78,12 @@ class TorchMD_ET(nn.Module): (default: :obj:`False`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} + (default: :obj:`None`) """ def __init__( @@ -102,6 +106,7 @@ def __init__( box_vecs=None, vector_cutoff=False, dtype=torch.float32, + additional_labels=None, ): super(TorchMD_ET, self).__init__() @@ -133,7 +138,10 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.dtype = dtype - + self.additional_labels = additional_labels + self.label_callbacks = None + if additional_labels is not None: + Warning("Found additional_labels, equivariant-transformer still does not support additional labels. Ignoring them.") act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) @@ -194,8 +202,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 31d68ae03..209535db9 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -86,6 +86,11 @@ class TorchMD_GN(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} """ @@ -107,6 +112,7 @@ def __init__( aggr="add", dtype=torch.float32, box_vecs=None, + additional_labels=None ): super(TorchMD_GN, self).__init__() @@ -136,7 +142,10 @@ def __init__( self.cutoff_upper = cutoff_upper self.max_z = max_z self.aggr = aggr - + self.additional_labels = additional_labels + self.label_callbacks = None + if additional_labels is not None: + Warning("Found additional_labels, graph-network still does not support additional labels. Ignoring them.") act_class = act_class_mapping[activation] self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) @@ -196,8 +205,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, - q: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index c11efc080..f397642b7 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -76,7 +76,11 @@ class TorchMD_T(nn.Module): (default: :obj:`None`) check_errors (bool, optional): Whether to check for errors in the distance module. (default: :obj:`True`) - + additional_labels (Dict[str, Any], optional): Define the additional method to be used by the model, and the parameters to initialize it. + example: + additional_labels = {method_name1: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + method_name2: {label_name: 'extra_arg_label', 'method_prm1': method_prm1, 'method_prm2': method_prm2}, + ...} """ def __init__( @@ -98,6 +102,7 @@ def __init__( max_num_neighbors=32, dtype=torch.float, box_vecs=None, + additional_labels=None ): super(TorchMD_T, self).__init__() @@ -124,7 +129,10 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.max_z = max_z - + self.additional_labels = additional_labels + self.label_callbacks = None + if additional_labels is not None: + Warning("Found additional_labels, transformer still does not support additional labels. Ignoring them.") act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] @@ -190,8 +198,7 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, - q: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/wrappers.py b/torchmdnet/models/wrappers.py index 444805e06..ce987ce01 100644 --- a/torchmdnet/models/wrappers.py +++ b/torchmdnet/models/wrappers.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) from abc import abstractmethod, ABCMeta -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict from torch import nn, Tensor @@ -43,10 +43,9 @@ def forward( z: Tensor, pos: Tensor, batch: Tensor, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: - x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s) + x, v, z, pos, batch = self.model(z, pos, batch=batch, extra_args=extra_args) n_samples = len(batch.unique()) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 108a1915e..a7b78893a 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -61,11 +61,6 @@ class LNNP(LightningModule): def __init__(self, hparams, prior_model=None, mean=None, std=None): super(LNNP, self).__init__() - if "charge" not in hparams: - hparams["charge"] = False - if "spin" not in hparams: - hparams["spin"] = False - self.save_hyperparameters(hparams) if self.hparams.load_model: @@ -119,11 +114,9 @@ def forward( pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor]]: - return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) + return self.model(z, pos, batch=batch, box=box, extra_args=extra_args) def training_step(self, batch, batch_idx): return self.step(batch, [mse_loss], "train") @@ -199,7 +192,7 @@ def step(self, batch, loss_fn_list, stage): batch = self.data_transform(batch) with torch.set_grad_enabled(stage == "train" or self.hparams.derivative): extra_args = batch.to_dict() - for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"): + for a in ("y", "neg_dy", "z", "pos", "batch", "box"): if a in extra_args: del extra_args[a] # TODO: the model doesn't necessarily need to return a derivative once @@ -209,8 +202,6 @@ def step(self, batch, loss_fn_list, stage): batch.pos, batch=batch.batch, box=batch.box if "box" in batch else None, - q=batch.q if self.hparams.charge else None, - s=batch.s if self.hparams.spin else None, extra_args=extra_args, ) if self.hparams.derivative and "y" not in batch: diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 0c7f56513..4e825e94c 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch as pt from NNPOps.CFConv import CFConv from NNPOps.CFConvNeighbors import CFConvNeighbors @@ -33,7 +33,7 @@ def __init__(self, model): super().__init__() self.model = model - + self.label_callbacks = None self.neighbors = CFConvNeighbors(self.model.cutoff_upper) offset = self.model.distance_expansion.offset @@ -56,8 +56,7 @@ def forward( pos: pt.Tensor, batch: pt.Tensor, box: Optional[pt.Tensor] = None, - q: Optional[pt.Tensor] = None, - s: Optional[pt.Tensor] = None, + extra_args: Optional[Dict[str, pt.Tensor]] = None, ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]: assert pt.all(batch == 0) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index a51cfe45f..34a7825d5 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -60,6 +60,7 @@ def get_argparse(): parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm') parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.') # dataset specific + parser.add_argument('--additional-labels', default=None, help='Additional labels to be passed to the model, must be a dict like {"method_name":{args}}') parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') parser.add_argument('--dataset-arg', default=None, help='Additional dataset arguments. Needs to be a dictionary.') @@ -77,8 +78,8 @@ def get_argparse(): parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*") # architectural args - parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.') - parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.') + parser.add_argument('--charge', type=bool, default=False, help='DEPRECATED: This argument is no longer in use and is maintained only for retro-compatibility') + parser.add_argument('--spin', type=bool, default=False, help='DEPRECATED: This argument is no longer in use and is maintained only for retro-compatibility') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model') @@ -139,7 +140,7 @@ def get_args(): args.inference_batch_size = args.batch_size os.makedirs(os.path.abspath(args.log_dir), exist_ok=True) - save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"]) + save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf", "charge", "spin"]) return args