From deb86fc2677afbac089f2603c717b7c9b09c3324 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 11:41:29 +0200 Subject: [PATCH 01/16] Add test using gradcheck --- tests/test_model.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index cdf4040d9..ac85e8ccd 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -5,7 +5,7 @@ import torch import pytorch_lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model +from torchmdnet.models.model import create_model, load_model from torchmdnet.models import output_modules from utils import load_example_args, create_example_batch @@ -60,7 +60,7 @@ def test_seed(model_name): output_modules.__all__, ) def test_forward_output(model_name, output_model, overwrite_reference=False): - pl.seed_everything(1234) + pl.seed_everynthing(1234) # create model and sample batch derivative = output_model in ["Scalar", "EquivariantScalar"] @@ -101,3 +101,27 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): torch.testing.assert_allclose( deriv, expected[model_name][output_model]["deriv"] ) + +@mark.parametrize("model_name", models.__all__) +@mark.parametrize( + "output_model", + output_modules.__all__, +) +def test_gradients(model_name, output_model): + pl.seed_everything(1234) + if model_name != "equivariant-transformer": + pytest.skip("Gradients are not implemented for this model.") + # create model and sample batch + derivative = output_model in ["Scalar", "EquivariantScalar"] + args = load_example_args( + model_name, + remove_prior=True, + output_model=output_model, + derivative=derivative, + ) + model = create_model(args).to(torch.float64) + print(model) + z, pos, batch = create_example_batch(n_atoms=5) + pos.requires_grad = True + pos.to(torch.float64) + torch.autograd.gradcheck(model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3) From f85e7dc60012bc09f3466f262260ad72ce243dd6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 14:02:17 +0200 Subject: [PATCH 02/16] Add the dtype parameter and ensure every module respects it. 
--- examples/ET-QM9.yaml | 1 + tests/test_model.py | 30 +++++++++++------------ torchmdnet/models/model.py | 28 ++++++++++++++++++---- torchmdnet/models/output_modules.py | 29 ++++++++++++---------- torchmdnet/models/torchmd_et.py | 36 ++++++++++++++++------------ torchmdnet/models/torchmd_gn.py | 22 ++++++++++------- torchmdnet/models/torchmd_t.py | 26 +++++++++++--------- torchmdnet/models/utils.py | 37 ++++++++++++++++------------- torchmdnet/scripts/train.py | 5 ++-- 9 files changed, 129 insertions(+), 85 deletions(-) diff --git a/examples/ET-QM9.yaml b/examples/ET-QM9.yaml index ef9048578..24d4ba242 100644 --- a/examples/ET-QM9.yaml +++ b/examples/ET-QM9.yaml @@ -55,3 +55,4 @@ train_size: 110000 trainable_rbf: false val_size: 10000 weight_decay: 0.0 +dtype: float diff --git a/tests/test_model.py b/tests/test_model.py index ac85e8ccd..07ceb50cf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -14,9 +14,11 @@ @mark.parametrize("model_name", models.__all__) @mark.parametrize("use_batch", [True, False]) @mark.parametrize("explicit_q_s", [True, False]) -def test_forward(model_name, use_batch, explicit_q_s): +@mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_forward(model_name, use_batch, explicit_q_s, dtype): z, pos, batch = create_example_batch() - model = create_model(load_example_args(model_name, prior_model=None)) + pos = pos.to(dtype=dtype) + model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype)) batch = batch if use_batch else None if explicit_q_s: model(z, pos, batch=batch, q=None, s=None) @@ -60,7 +62,7 @@ def test_seed(model_name): output_modules.__all__, ) def test_forward_output(model_name, output_model, overwrite_reference=False): - pl.seed_everynthing(1234) + pl.seed_everything(1234) # create model and sample batch derivative = output_model in ["Scalar", "EquivariantScalar"] @@ -103,14 +105,10 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ) @mark.parametrize("model_name", models.__all__) -@mark.parametrize( - "output_model", - output_modules.__all__, -) -def test_gradients(model_name, output_model): +def test_gradients(model_name): pl.seed_everything(1234) - if model_name != "equivariant-transformer": - pytest.skip("Gradients are not implemented for this model.") + dtype = torch.float64 + output_model = "Scalar" # create model and sample batch derivative = output_model in ["Scalar", "EquivariantScalar"] args = load_example_args( @@ -118,10 +116,12 @@ def test_gradients(model_name, output_model): remove_prior=True, output_model=output_model, derivative=derivative, + dtype=dtype, ) - model = create_model(args).to(torch.float64) - print(model) + model = create_model(args) z, pos, batch = create_example_batch(n_atoms=5) - pos.requires_grad = True - pos.to(torch.float64) - torch.autograd.gradcheck(model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3) + pos.requires_grad_(True) + pos = pos.to(dtype) + torch.autograd.gradcheck( + model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 + ) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 508bf3fd6..be0cf72af 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -7,11 +7,26 @@ from pytorch_lightning.utilities import rank_zero_warn from torchmdnet.models import output_modules from torchmdnet.models.wrappers import AtomFilter +from torchmdnet.models.utils import dtype_mapping from torchmdnet import priors import warnings def create_model(args, prior_model=None, 
mean=None, std=None): + """Create a model from the given arguments. + See :func:`get_args` in scripts/train.py for a description of the arguments. + Parameters + ---------- + args (dict): Arguments for the model. + prior_model (nn.Module, optional): Prior model to use. Defaults to None. + mean (torch.Tensor, optional): Mean of the training data. Defaults to None. + std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None. + Returns + ------- + nn.Module: An instance of the TorchMD_Net model. + """ + #Use mapping if args["dtype"] is a string, else use args["dtype"] + args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"] shared_args = dict( hidden_channels=args["embedding_dimension"], num_layers=args["num_layers"], @@ -24,6 +39,7 @@ def create_model(args, prior_model=None, mean=None, std=None): cutoff_upper=args["cutoff_upper"], max_z=args["max_z"], max_num_neighbors=args["max_num_neighbors"], + dtype=args["dtype"] ) # representation network @@ -74,6 +90,7 @@ def create_model(args, prior_model=None, mean=None, std=None): args["embedding_dimension"], activation=args["activation"], reduce_op=args["reduce_op"], + dtype=args["dtype"], ) # combine representation and output network @@ -84,6 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None): mean=mean, std=std, derivative=args["derivative"], + dtype=args["dtype"], ) return model @@ -156,10 +174,11 @@ def __init__( mean=None, std=None, derivative=False, + dtype=torch.float32, ): super(TorchMD_Net, self).__init__() self.representation_model = representation_model - self.output_model = output_model + self.output_model = output_model.to(dtype=dtype) if not output_model.allow_prior_model and prior_model is not None: prior_model = None @@ -171,14 +190,14 @@ def __init__( ) if isinstance(prior_model, priors.base.BasePrior): prior_model = [prior_model] - self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model) + self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype) self.derivative = derivative mean = torch.scalar_tensor(0) if mean is None else mean - self.register_buffer("mean", mean) + self.register_buffer("mean", mean.to(dtype=dtype)) std = torch.scalar_tensor(1) if std is None else std - self.register_buffer("std", std) + self.register_buffer("std", std.to(dtype=dtype)) self.reset_parameters() @@ -247,6 +266,7 @@ def forward( )[0] if dy is None: raise RuntimeError("Autograd returned None for the force prediction.") + return y, -dy # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180) return y, None diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index 0bcec935f..f2d51a058 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -38,15 +38,16 @@ def __init__( activation="silu", allow_prior_model=True, reduce_op="sum", + dtype=torch.float ): super(Scalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op ) act_class = act_class_mapping[activation] self.output_network = nn.Sequential( - nn.Linear(hidden_channels, hidden_channels // 2), + nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype), act_class(), - nn.Linear(hidden_channels // 2, 1), + nn.Linear(hidden_channels // 2, 1, dtype=dtype), ) self.reset_parameters() @@ -68,6 +69,7 @@ def __init__( activation="silu", allow_prior_model=True, reduce_op="sum", + 
dtype=torch.float ): super(EquivariantScalar, self).__init__( allow_prior_model=allow_prior_model, reduce_op=reduce_op @@ -79,8 +81,9 @@ def __init__( hidden_channels // 2, activation=activation, scalar_activation=True, + dtype=dtype ), - GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation), + GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation, dtype=dtype), ] ) @@ -98,11 +101,11 @@ def pre_reduce(self, x, v, z, pos, batch): class DipoleMoment(Scalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum"): + def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): super(DipoleMoment, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op + hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype ) - atomic_mass = torch.from_numpy(atomic_masses).float() + atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch): @@ -119,11 +122,11 @@ def post_reduce(self, x): class EquivariantDipoleMoment(EquivariantScalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum"): + def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): super(EquivariantDipoleMoment, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op + hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype ) - atomic_mass = torch.from_numpy(atomic_masses).float() + atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) def pre_reduce(self, x, v, z, pos, batch): @@ -141,17 +144,17 @@ def post_reduce(self, x): class ElectronicSpatialExtent(OutputModel): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum"): + def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): super(ElectronicSpatialExtent, self).__init__( allow_prior_model=False, reduce_op=reduce_op ) act_class = act_class_mapping[activation] self.output_network = nn.Sequential( - nn.Linear(hidden_channels, hidden_channels // 2), + nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype), act_class(), - nn.Linear(hidden_channels // 2, 1), + nn.Linear(hidden_channels // 2, 1, dtype=dtype), ) - atomic_mass = torch.from_numpy(atomic_masses).float() + atomic_mass = torch.from_numpy(atomic_masses).to(dtype) self.register_buffer("atomic_mass", atomic_mass) self.reset_parameters() diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 69977be93..9f93028cb 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -11,7 +11,6 @@ act_class_mapping, ) - class TorchMD_ET(nn.Module): r"""The TorchMD equivariant Transformer architecture. 
@@ -67,6 +66,7 @@ def __init__( cutoff_upper=5.0, max_z=100, max_num_neighbors=32, + dtype=torch.float32, ): super(TorchMD_ET, self).__init__() @@ -97,10 +97,11 @@ def __init__( self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.max_z = max_z + self.dtype = dtype act_class = act_class_mapping[activation] - self.embedding = nn.Embedding(self.max_z, hidden_channels) + self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) self.distance = Distance( cutoff_lower, @@ -114,7 +115,7 @@ def __init__( ) self.neighbor_embedding = ( NeighborEmbedding( - hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z + hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype ).jittable() if neighbor_embedding else None @@ -131,10 +132,11 @@ def __init__( attn_activation, cutoff_lower, cutoff_upper, + dtype, ).jittable() self.attention_layers.append(layer) - self.out_norm = nn.LayerNorm(hidden_channels) + self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype) self.reset_parameters() @@ -147,6 +149,7 @@ def reset_parameters(self): attn.reset_parameters() self.out_norm.reset_parameters() + def forward( self, z: Tensor, @@ -159,6 +162,8 @@ def forward( x = self.embedding(z) edge_index, edge_weight, edge_vec = self.distance(pos, batch) + # This assert must be here to convince TorchScript that edge_vec is not None + # If you remove it TorchScript will complain down below that you cannot use an Optional[Tensor] assert ( edge_vec is not None ), "Distance module did not return directional information" @@ -170,7 +175,7 @@ def forward( if self.neighbor_embedding is not None: x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) - vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device, dtype=x.dtype) for attn in self.attention_layers: dx, dvec = attn(x, vec, edge_index, edge_weight, edge_attr, edge_vec) @@ -194,7 +199,8 @@ def __repr__(self): f"num_heads={self.num_heads}, " f"distance_influence={self.distance_influence}, " f"cutoff_lower={self.cutoff_lower}, " - f"cutoff_upper={self.cutoff_upper})" + f"cutoff_upper={self.cutoff_upper}), " + f"dtype={self.dtype}" ) @@ -209,6 +215,7 @@ def __init__( attn_activation, cutoff_lower, cutoff_upper, + dtype=torch.float32, ): super(EquivariantMultiHeadAttention, self).__init__(aggr="add", node_dim=0) assert hidden_channels % num_heads == 0, ( @@ -221,26 +228,25 @@ def __init__( self.num_heads = num_heads self.hidden_channels = hidden_channels self.head_dim = hidden_channels // num_heads - - self.layernorm = nn.LayerNorm(hidden_channels) + self.layernorm = nn.LayerNorm(hidden_channels, dtype=dtype) self.act = activation() self.attn_activation = act_class_mapping[attn_activation]() self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) - self.q_proj = nn.Linear(hidden_channels, hidden_channels) - self.k_proj = nn.Linear(hidden_channels, hidden_channels) - self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3) - self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) + self.q_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) + self.k_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) + self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype) + self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype) - self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False) + self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False, 
dtype=dtype) self.dk_proj = None if distance_influence in ["keys", "both"]: - self.dk_proj = nn.Linear(num_rbf, hidden_channels) + self.dk_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype) self.dv_proj = None if distance_influence in ["values", "both"]: - self.dv_proj = nn.Linear(num_rbf, hidden_channels * 3) + self.dv_proj = nn.Linear(num_rbf, hidden_channels * 3, dtype=dtype) self.reset_parameters() diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 6f1f22853..8c78e766c 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple +import torch from torch import Tensor, nn from torch_geometric.nn import MessagePassing from torchmdnet.models.utils import ( @@ -71,6 +72,7 @@ def __init__( max_z=100, max_num_neighbors=32, aggr="add", + dtype=torch.float32 ): super(TorchMD_GN, self).__init__() @@ -103,17 +105,17 @@ def __init__( act_class = act_class_mapping[activation] - self.embedding = nn.Embedding(self.max_z, hidden_channels) + self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) self.distance = Distance( cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors ) self.distance_expansion = rbf_class_mapping[rbf_type]( - cutoff_lower, cutoff_upper, num_rbf, trainable_rbf + cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype ) self.neighbor_embedding = ( NeighborEmbedding( - hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z + hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype ).jittable() if neighbor_embedding else None @@ -129,6 +131,7 @@ def __init__( cutoff_lower, cutoff_upper, aggr=self.aggr, + dtype=dtype ) self.interactions.append(block) @@ -191,12 +194,13 @@ def __init__( cutoff_lower, cutoff_upper, aggr="add", + dtype=torch.float32 ): super(InteractionBlock, self).__init__() self.mlp = nn.Sequential( - nn.Linear(num_rbf, num_filters), + nn.Linear(num_rbf, num_filters, dtype=dtype), activation(), - nn.Linear(num_filters, num_filters), + nn.Linear(num_filters, num_filters, dtype=dtype), ) self.conv = CFConv( hidden_channels, @@ -206,9 +210,10 @@ def __init__( cutoff_lower, cutoff_upper, aggr=aggr, + dtype=dtype ).jittable() self.act = activation() - self.lin = nn.Linear(hidden_channels, hidden_channels) + self.lin = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) self.reset_parameters() @@ -238,10 +243,11 @@ def __init__( cutoff_lower, cutoff_upper, aggr="add", + dtype=torch.float32 ): super(CFConv, self).__init__(aggr=aggr) - self.lin1 = nn.Linear(in_channels, num_filters, bias=False) - self.lin2 = nn.Linear(num_filters, out_channels) + self.lin1 = nn.Linear(in_channels, num_filters, bias=False, dtype=dtype) + self.lin2 = nn.Linear(num_filters, out_channels, bias=True, dtype=dtype) self.net = net self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index f6458c7d4..dce869880 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple +import torch from torch import Tensor, nn from torch_geometric.nn import MessagePassing from torchmdnet.models.utils import ( @@ -65,6 +66,7 @@ def __init__( cutoff_upper=5.0, max_z=100, max_num_neighbors=32, + dtype=torch.float ): super(TorchMD_T, self).__init__() @@ -95,17 +97,17 @@ def __init__( act_class = act_class_mapping[activation] attn_act_class = act_class_mapping[attn_activation] - self.embedding 
= nn.Embedding(self.max_z, hidden_channels) + self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype) self.distance = Distance( cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors, loop=True ) self.distance_expansion = rbf_class_mapping[rbf_type]( - cutoff_lower, cutoff_upper, num_rbf, trainable_rbf + cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype ) self.neighbor_embedding = ( NeighborEmbedding( - hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z + hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype ).jittable() if neighbor_embedding else None @@ -122,10 +124,11 @@ def __init__( attn_act_class, cutoff_lower, cutoff_upper, + dtype=dtype, ).jittable() self.attention_layers.append(layer) - self.out_norm = nn.LayerNorm(hidden_channels) + self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype) self.reset_parameters() @@ -190,6 +193,7 @@ def __init__( attn_activation, cutoff_lower, cutoff_upper, + dtype=torch.float, ): super(MultiHeadAttention, self).__init__(aggr="add", node_dim=0) assert hidden_channels % num_heads == 0, ( @@ -202,23 +206,23 @@ def __init__( self.num_heads = num_heads self.head_dim = hidden_channels // num_heads - self.layernorm = nn.LayerNorm(hidden_channels) + self.layernorm = nn.LayerNorm(hidden_channels, dtype=dtype) self.act = activation() self.attn_activation = attn_activation() self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) - self.q_proj = nn.Linear(hidden_channels, hidden_channels) - self.k_proj = nn.Linear(hidden_channels, hidden_channels) - self.v_proj = nn.Linear(hidden_channels, hidden_channels) - self.o_proj = nn.Linear(hidden_channels, hidden_channels) + self.q_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) + self.k_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) + self.v_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) + self.o_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype) self.dk_proj = None if distance_influence in ["keys", "both"]: - self.dk_proj = nn.Linear(num_rbf, hidden_channels) + self.dk_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype) self.dv_proj = None if distance_influence in ["values", "both"]: - self.dv_proj = nn.Linear(num_rbf, hidden_channels) + self.dv_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype) self.reset_parameters() diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 2ad1fd382..d7cb0ec85 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -42,11 +42,11 @@ def visualize_basis(basis_type, num_rbf=50, cutoff_lower=0, cutoff_upper=5): class NeighborEmbedding(MessagePassing): - def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100): + def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100, dtype=torch.float32): super(NeighborEmbedding, self).__init__(aggr="add") - self.embedding = nn.Embedding(max_z, hidden_channels) - self.distance_proj = nn.Linear(num_rbf, hidden_channels) - self.combine = nn.Linear(hidden_channels * 2, hidden_channels) + self.embedding = nn.Embedding(max_z, hidden_channels, dtype=dtype) + self.distance_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype) + self.combine = nn.Linear(hidden_channels * 2, hidden_channels, dtype=dtype) self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) self.reset_parameters() @@ -269,13 +269,13 @@ def forward(self, dist): class ExpNormalSmearing(nn.Module): - def __init__(self, cutoff_lower=0.0, 
cutoff_upper=5.0, num_rbf=50, trainable=True): + def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, dtype=torch.float32): super(ExpNormalSmearing, self).__init__() self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.num_rbf = num_rbf self.trainable = trainable - + self.dtype = dtype self.cutoff_fn = CosineCutoff(0, cutoff_upper) self.alpha = 5.0 / (cutoff_upper - cutoff_lower) @@ -291,11 +291,11 @@ def _initial_params(self): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 start_value = torch.exp( - torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower) + torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower, dtype=self.dtype) ) - means = torch.linspace(start_value, 1, self.num_rbf) + means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype) betas = torch.tensor( - [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf + [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf, dtype=self.dtype ) return means, betas @@ -342,13 +342,13 @@ def forward(self, distances): + 1.0 ) # remove contributions below the cutoff radius - cutoffs = cutoffs * (distances < self.cutoff_upper).float() - cutoffs = cutoffs * (distances > self.cutoff_lower).float() + cutoffs = cutoffs * (distances < self.cutoff_upper) + cutoffs = cutoffs * (distances > self.cutoff_lower) return cutoffs else: cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0) # remove contributions beyond the cutoff radius - cutoffs = cutoffs * (distances < self.cutoff_upper).float() + cutoffs = cutoffs * (distances < self.cutoff_upper) return cutoffs @@ -425,6 +425,7 @@ def __init__( intermediate_channels=None, activation="silu", scalar_activation=False, + dtype=torch.float, ): super(GatedEquivariantBlock, self).__init__() self.out_channels = out_channels @@ -432,14 +433,14 @@ def __init__( if intermediate_channels is None: intermediate_channels = hidden_channels - self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False) - self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False) + self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) + self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False, dtype=dtype) act_class = act_class_mapping[activation] self.update_net = nn.Sequential( - nn.Linear(hidden_channels * 2, intermediate_channels), + nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype), act_class(), - nn.Linear(intermediate_channels, out_channels * 2), + nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype), ) self.act = act_class() if scalar_activation else None @@ -457,7 +458,7 @@ def forward(self, x, v): # detach zero-entries to avoid NaN gradients during force loss backpropagation vec1 = torch.zeros( - vec1_buffer.size(0), vec1_buffer.size(2), device=vec1_buffer.device + vec1_buffer.size(0), vec1_buffer.size(2), device=vec1_buffer.device, dtype=vec1_buffer.dtype ) mask = (vec1_buffer != 0).view(vec1_buffer.size(0), -1).any(dim=1) if not mask.all(): @@ -489,3 +490,5 @@ def forward(self, x, v): "tanh": nn.Tanh, "sigmoid": nn.Sigmoid, } + +dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64} diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 3d6f91937..86dfa0d6d 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -12,9 +12,9 @@ from torchmdnet.data 
import DataModule from torchmdnet.models import output_modules from torchmdnet.models.model import create_prior_models -from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping +from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number - +import torch def get_args(): # fmt: off @@ -66,6 +66,7 @@ def get_args(): parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') # architectural args + parser.add_argument('--dtype', type=str, default="float", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float or float64') parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') From c79fecbe3ba1a64425156adf1867e3aa0065b902 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 14:11:19 +0200 Subject: [PATCH 03/16] Default dtype to float in create_model --- torchmdnet/models/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index be0cf72af..b0330a85d 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -25,6 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None): ------- nn.Module: An instance of the TorchMD_Net model. """ + args["dtype"] = "float" if "dtype" not in args else args["dtype"] #Use mapping if args["dtype"] is a string, else use args["dtype"] args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"] shared_args = dict( From 0326dd4a983ff9525b06cd22555c3716d4908a85 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 14:11:34 +0200 Subject: [PATCH 04/16] Add dtype to EquivariantVectorOutput --- torchmdnet/models/output_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index f2d51a058..d3d7e1b85 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -181,9 +181,9 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent): class EquivariantVectorOutput(EquivariantScalar): - def __init__(self, hidden_channels, activation="silu", reduce_op="sum"): + def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float): super(EquivariantVectorOutput, self).__init__( - hidden_channels, activation, allow_prior_model=False, reduce_op="sum" + hidden_channels, activation, allow_prior_model=False, reduce_op="sum", dtype=dtype ) def pre_reduce(self, x, v, z, pos, batch): From 90543de690b6b14721253c4fe82909a9919f3ac3 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 14:24:53 +0200 Subject: [PATCH 05/16] Add GaussianSmearing --- torchmdnet/models/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index d7cb0ec85..60566602e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -238,13 +238,13 @@ def forward( class GaussianSmearing(nn.Module): - def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): + def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, 
dtype=torch.float32): super(GaussianSmearing, self).__init__() self.cutoff_lower = cutoff_lower self.cutoff_upper = cutoff_upper self.num_rbf = num_rbf self.trainable = trainable - + self.dtype = dtype offset, coeff = self._initial_params() if trainable: self.register_parameter("coeff", nn.Parameter(coeff)) @@ -254,7 +254,7 @@ def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=Tru self.register_buffer("offset", offset) def _initial_params(self): - offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf) + offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf, dtype=self.dtype) coeff = -0.5 / (offset[1] - offset[0]) ** 2 return offset, coeff From 170a059742ca90041c8a610ec7821ac0d2e2251f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 16:03:27 +0200 Subject: [PATCH 06/16] Change default to float32 --- torchmdnet/models/model.py | 2 +- torchmdnet/scripts/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index b0330a85d..7c2a93bdc 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -25,7 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None): ------- nn.Module: An instance of the TorchMD_Net model. """ - args["dtype"] = "float" if "dtype" not in args else args["dtype"] + args["dtype"] = "float32" if "dtype" not in args else args["dtype"] #Use mapping if args["dtype"] is a string, else use args["dtype"] args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"] shared_args = dict( diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 86dfa0d6d..21466a6b5 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -66,7 +66,7 @@ def get_args(): parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') # architectural args - parser.add_argument('--dtype', type=str, default="float", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float or float64') + parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float or float64') parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') From eaf867e5febcd17e23b9bda87e9489287cc92a97 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 6 Jun 2023 16:04:01 +0200 Subject: [PATCH 07/16] Update some comments --- torchmdnet/models/model.py | 1 - torchmdnet/scripts/train.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 7c2a93bdc..384d84afd 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -26,7 +26,6 @@ def create_model(args, prior_model=None, mean=None, std=None): nn.Module: An instance of the TorchMD_Net model. 
""" args["dtype"] = "float32" if "dtype" not in args else args["dtype"] - #Use mapping if args["dtype"] is a string, else use args["dtype"] args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"] shared_args = dict( hidden_channels=args["embedding_dimension"], diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 21466a6b5..3ff371667 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -66,7 +66,7 @@ def get_args(): parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') # architectural args - parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float or float64') + parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64') parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') From b3283527a58b6fbdacce5cbea4598cd584cc4c29 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 12:09:37 +0200 Subject: [PATCH 08/16] Add double support for D2 --- tests/priors.yaml | 1 + tests/test_priors.py | 5 +++-- torchmdnet/priors/d2.py | 13 +++++++++---- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/priors.yaml b/tests/priors.yaml index 1b49d97b0..a72a7996a 100644 --- a/tests/priors.yaml +++ b/tests/priors.yaml @@ -57,3 +57,4 @@ train_size: 110000 trainable_rbf: false val_size: 10000 weight_decay: 0.0 +dtype: float diff --git a/tests/test_priors.py b/tests/test_priors.py index 493bd613f..e72bcabd5 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -63,12 +63,13 @@ def compute_interaction(pos1, pos2, z1, z2): expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]) torch.testing.assert_allclose(expected, energy) -def test_multiple_priors(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_multiple_priors(dtype): # Create a model from a config file. 
dataset = DummyDataset(has_atomref=True) config_file = join(dirname(__file__), 'priors.yaml') - args = load_example_args('equivariant-transformer', config_file=config_file) + args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype) prior_models = create_prior_models(args, dataset) args['prior_args'] = [p.get_init_args() for p in prior_models] model = LNNP(args, prior_model=prior_models) diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py index 91a72aad6..f526c0d1e 100644 --- a/torchmdnet/priors/d2.py +++ b/torchmdnet/priors/d2.py @@ -103,7 +103,8 @@ class D2(BasePrior): [31.74, 1.892], # 52 Te [31.50, 1.892], # 53 I [29.99, 1.881], # 54 Xe - ] + ], dtype=pt.float64 + ) C_6_R_r[:, 1] *= 0.1 # Å --> nm @@ -115,17 +116,21 @@ def __init__( distance_scale=None, energy_scale=None, dataset=None, + dtype=pt.float32, ): super().__init__() - self.cutoff_distance = float(cutoff_distance) + one = pt.tensor(1.0, dtype=dtype).item() + self.cutoff_distance = cutoff_distance * one self.max_num_neighbors = int(max_num_neighbors) + + self.C_6_R_r = self.C_6_R_r.to(dtype=dtype) self.atomic_number = list( dataset.atomic_number if atomic_number is None else atomic_number ) - self.distance_scale = float( + self.distance_scale = one * ( dataset.distance_scale if distance_scale is None else distance_scale ) - self.energy_scale = float( + self.energy_scale = one * ( dataset.energy_scale if energy_scale is None else energy_scale ) From ef4b099b5df93d107f5e13fc4eee9acd0fab97f1 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 13:38:10 +0200 Subject: [PATCH 09/16] Test Coulomb prior also in double precision --- tests/test_priors.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_priors.py b/tests/test_priors.py index ba7e1405b..1d6b477f9 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -63,9 +63,10 @@ def compute_interaction(pos1, pos2, z1, z2): expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]) torch.testing.assert_allclose(expected, energy) -def test_coulomb(): - pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=torch.float32) # Atom positions in nm - charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=torch.float32) # Partial charges +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_coulomb(dtype): + pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm + charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types distance_scale = 1e-9 # Convert nm to meters energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules From 953f52c8d322046c1e46969b7fe5e1e9e5d703ee Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 13:40:28 +0200 Subject: [PATCH 10/16] Remove unnecessary import --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 07ceb50cf..bfa567656 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -5,7 +5,7 @@ import torch import pytorch_lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model, load_model +from torchmdnet.models.model import create_model from torchmdnet.models import output_modules from utils import load_example_args, create_example_batch From 
f9531130376fb1472380599eb1ec468ffe14b44b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 13:41:04 +0200 Subject: [PATCH 11/16] Fix formatting Acked-by: RaulPPealez --- tests/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model.py b/tests/test_model.py index bfa567656..65fa986b1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -104,6 +104,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): deriv, expected[model_name][output_model]["deriv"] ) + @mark.parametrize("model_name", models.__all__) def test_gradients(model_name): pl.seed_everything(1234) From bffbd6a800e564f22ac544b57f1ae17af321d125 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 13:44:45 +0200 Subject: [PATCH 12/16] representation_model.to(dtype) in TorchMD_Net --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 384d84afd..a8bbb3ace 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -177,7 +177,7 @@ def __init__( dtype=torch.float32, ): super(TorchMD_Net, self).__init__() - self.representation_model = representation_model + self.representation_model = representation_model.to(dtype=dtype) self.output_model = output_model.to(dtype=dtype) if not output_model.allow_prior_model and prior_model is not None: From 45efcb483771eeb1d1e4373e596a8cbcbc38962b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 8 Jun 2023 13:45:05 +0200 Subject: [PATCH 13/16] Add more torch.float64 to test_model --- tests/test_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 65fa986b1..3d35d6cac 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -28,18 +28,20 @@ def test_forward(model_name, use_batch, explicit_q_s, dtype): @mark.parametrize("model_name", models.__all__) @mark.parametrize("output_model", output_modules.__all__) -def test_forward_output_modules(model_name, output_model): +@mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_forward_output_modules(model_name, output_model, dtype): z, pos, batch = create_example_batch() - args = load_example_args(model_name, remove_prior=True, output_model=output_model) + args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype) model = create_model(args) model(z, pos, batch=batch) @mark.parametrize("model_name", models.__all__) -def test_forward_torchscript(model_name): +@mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_forward_torchscript(model_name, dtype): z, pos, batch = create_example_batch() model = torch.jit.script( - create_model(load_example_args(model_name, remove_prior=True, derivative=True)) + create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype)) ) model(z, pos, batch=batch) From afc1cb5e29dbcfa0d9d72e2c906119511448c4ff Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 14 Jun 2023 11:08:22 +0200 Subject: [PATCH 14/16] Default to torch.float32 in test.utils.load_example_args --- tests/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 5d0ab41c4..b15560adb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,6 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") with open(config_file, "r") as f: args = 
yaml.load(f, Loader=yaml.FullLoader) + if "dtype" not in args: + args["dtype"] = torch.float32 args["model"] = model_name args["seed"] = 1234 if remove_prior: From 6b43c24d052af5707ce5e505be76fd4c7d8c2a4e Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 14 Jun 2023 11:09:00 +0200 Subject: [PATCH 15/16] Make TensorNet compatible with float64 --- torchmdnet/models/tensornet.py | 59 ++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 956f683bd..0b0f4c932 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -15,7 +15,7 @@ def vector_to_skewtensor(vector): tensor = torch.cross( *torch.broadcast_tensors( - vector[..., None], torch.eye(3, 3, device=vector.device)[None, None] + vector[..., None], torch.eye(3, 3, device=vector.device, dtype=vector.dtype)[None, None] ) ) return tensor.squeeze(0) @@ -26,7 +26,7 @@ def vector_to_symtensor(vector): tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2)) I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ ..., None, None - ] * torch.eye(3, 3, device=tensor.device) + ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype) S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I return S @@ -35,7 +35,7 @@ def vector_to_symtensor(vector): def decompose_tensor(tensor): I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ ..., None, None - ] * torch.eye(3, 3, device=tensor.device) + ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype) A = 0.5 * (tensor - tensor.transpose(-2, -1)) S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I return I, A, S @@ -85,7 +85,7 @@ class TensorNet(nn.Module): will be invariant. O(3) or SO(3). (default :obj:`"O(3)"`) """ - + def __init__( self, hidden_channels=128, @@ -99,6 +99,7 @@ def __init__( max_z=128, max_num_neighbors=64, equivariance_invariance_group="O(3)", + dtype=torch.float32, ): super(TensorNet, self).__init__() assert rbf_type in rbf_class_mapping, ( @@ -136,6 +137,7 @@ def __init__( cutoff_upper, trainable_rbf, max_z, + dtype, ).jittable() self.layers = nn.ModuleList() if num_layers != 0: @@ -148,10 +150,11 @@ def __init__( cutoff_lower, cutoff_upper, equivariance_invariance_group, + dtype, ).jittable() ) - self.linear = nn.Linear(3 * hidden_channels, hidden_channels) - self.out_norm = nn.LayerNorm(3 * hidden_channels) + self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype) + self.out_norm = nn.LayerNorm(3 * hidden_channels, dtype=dtype) self.act = act_class() self.reset_parameters() @@ -191,36 +194,37 @@ def __init__( cutoff_upper, trainable_rbf=False, max_z=128, + dtype=torch.float32, ): super(TensorEmbedding, self).__init__(aggr="add", node_dim=0) self.hidden_channels = hidden_channels - self.distance_proj1 = nn.Linear(num_rbf, hidden_channels) - self.distance_proj2 = nn.Linear(num_rbf, hidden_channels) - self.distance_proj3 = nn.Linear(num_rbf, hidden_channels) + self.distance_proj1 = nn.Linear(num_rbf, hidden_channels, dtype=dtype) + self.distance_proj2 = nn.Linear(num_rbf, hidden_channels, dtype=dtype) + self.distance_proj3 = nn.Linear(num_rbf, hidden_channels, dtype=dtype) self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) self.max_z = max_z - self.emb = torch.nn.Embedding(max_z, hidden_channels) - self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels) + self.emb = torch.nn.Embedding(max_z, hidden_channels, dtype=dtype) + self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype) 
self.act = activation() self.linears_tensor = nn.ModuleList() self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_scalar = nn.ModuleList() self.linears_scalar.append( - nn.Linear(hidden_channels, 2 * hidden_channels, bias=True) + nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype) ) self.linears_scalar.append( - nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True) + nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype) ) - self.init_norm = nn.LayerNorm(hidden_channels) + self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype) self.reset_parameters() def reset_parameters(self): @@ -244,7 +248,7 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr): mask = edge_index[0] != edge_index[1] edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1) Iij, Aij, Sij = new_radial_tensor( - torch.eye(3, 3, device=edge_vec.device)[None, None, :, :], + torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[None, None, :, :], vector_to_skewtensor(edge_vec)[..., None, :, :], vector_to_symtensor(edge_vec)[..., None, :, :], W1, @@ -301,6 +305,7 @@ def __init__( cutoff_lower, cutoff_upper, equivariance_invariance_group, + dtype=torch.float32, ): super(Interaction, self).__init__(aggr="add", node_dim=0) @@ -308,31 +313,31 @@ def __init__( self.hidden_channels = hidden_channels self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) self.linears_scalar = nn.ModuleList() - self.linears_scalar.append(nn.Linear(num_rbf, hidden_channels, bias=True)) + self.linears_scalar.append(nn.Linear(num_rbf, hidden_channels, bias=True, dtype=dtype)) self.linears_scalar.append( - nn.Linear(hidden_channels, 2 * hidden_channels, bias=True) + nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype) ) self.linears_scalar.append( - nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True) + nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype) ) self.linears_tensor = nn.ModuleList() self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) + nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype) ) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group From 37a6ad94ab02eaea1a94d7ded8bbd766032fb8e4 Mon Sep 17 
00:00:00 2001 From: RaulPPealez Date: Wed, 14 Jun 2023 11:21:07 +0200 Subject: [PATCH 16/16] Default to "float" in test.utils.load_example_args --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index b15560adb..e5fa4ee1e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,7 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs with open(config_file, "r") as f: args = yaml.load(f, Loader=yaml.FullLoader) if "dtype" not in args: - args["dtype"] = torch.float32 + args["dtype"] = "float" args["model"] = model_name args["seed"] = 1234 if remove_prior:
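
For reference, a minimal usage sketch (not part of the patch series) assembled from the test_gradients test introduced in PATCH 01/16 and reworked in PATCH 02/16, showing how the new dtype argument is exercised end to end. It assumes execution from the repository's tests/ directory so that the load_example_args and create_example_batch helpers from tests/utils.py resolve, and it picks "equivariant-transformer" as the model name.

    import torch
    from torchmdnet.models.model import create_model
    # Test helpers from tests/utils.py; run this from the tests/ directory.
    from utils import load_example_args, create_example_batch

    # Request double precision through the new dtype argument; create_model maps
    # string values via dtype_mapping and passes torch dtypes through unchanged.
    args = load_example_args(
        "equivariant-transformer",
        remove_prior=True,
        output_model="Scalar",
        derivative=True,
        dtype=torch.float64,
    )
    model = create_model(args)

    # Build a small batch in float64 and compare autograd gradients of the
    # model outputs against numerical finite differences.
    z, pos, batch = create_example_batch(n_atoms=5)
    pos.requires_grad_(True)
    pos = pos.to(torch.float64)
    torch.autograd.gradcheck(
        model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
    )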