diff --git a/examples/ET-QM9.yaml b/examples/ET-QM9.yaml
index ef9048578..24d4ba242 100644
--- a/examples/ET-QM9.yaml
+++ b/examples/ET-QM9.yaml
@@ -55,3 +55,4 @@ train_size: 110000
 trainable_rbf: false
 val_size: 10000
 weight_decay: 0.0
+dtype: float
diff --git a/tests/priors.yaml b/tests/priors.yaml
index 1b49d97b0..a72a7996a 100644
--- a/tests/priors.yaml
+++ b/tests/priors.yaml
@@ -57,3 +57,4 @@ train_size: 110000
 trainable_rbf: false
 val_size: 10000
 weight_decay: 0.0
+dtype: float
diff --git a/tests/test_model.py b/tests/test_model.py
index 901641f5c..489262c85 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -14,9 +14,11 @@
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("use_batch", [True, False])
 @mark.parametrize("explicit_q_s", [True, False])
-def test_forward(model_name, use_batch, explicit_q_s):
+@mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_forward(model_name, use_batch, explicit_q_s, dtype):
     z, pos, batch = create_example_batch()
-    model = create_model(load_example_args(model_name, prior_model=None))
+    pos = pos.to(dtype=dtype)
+    model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype))
     batch = batch if use_batch else None
     if explicit_q_s:
         model(z, pos, batch=batch, q=None, s=None)
@@ -26,20 +28,22 @@
 
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("output_model", output_modules.__all__)
-def test_forward_output_modules(model_name, output_model):
+@mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_forward_output_modules(model_name, output_model, dtype):
     z, pos, batch = create_example_batch()
-    args = load_example_args(model_name, remove_prior=True, output_model=output_model)
+    args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype)
     model = create_model(args)
     model(z, pos, batch=batch)
 
 
 @mark.parametrize("model_name", models.__all__)
-def test_forward_torchscript(model_name):
+@mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_forward_torchscript(model_name, dtype):
     if model_name == "tensornet":
         pytest.skip("TensorNet does not support torchscript.")
     z, pos, batch = create_example_batch()
     model = torch.jit.script(
-        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+        create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype))
     )
     model(z, pos, batch=batch)
 
@@ -107,3 +111,26 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
         torch.testing.assert_allclose(
             deriv, expected[model_name][output_model]["deriv"]
         )
+
+
+@mark.parametrize("model_name", models.__all__)
+def test_gradients(model_name):
+    pl.seed_everything(1234)
+    dtype = torch.float64
+    output_model = "Scalar"
+    # create model and sample batch
+    derivative = output_model in ["Scalar", "EquivariantScalar"]
+    args = load_example_args(
+        model_name,
+        remove_prior=True,
+        output_model=output_model,
+        derivative=derivative,
+        dtype=dtype,
+    )
+    model = create_model(args)
+    z, pos, batch = create_example_batch(n_atoms=5)
+    pos.requires_grad_(True)
+    pos = pos.to(dtype)
+    torch.autograd.gradcheck(
+        model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
+    )
diff --git a/tests/test_priors.py b/tests/test_priors.py
index e2eda4bf3..1d6b477f9 100644
--- a/tests/test_priors.py
+++ b/tests/test_priors.py
@@ -63,9 +63,10 @@ def compute_interaction(pos1, pos2, z1, z2):
             expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
     torch.testing.assert_allclose(expected, energy)
 
-def test_coulomb():
-    pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=torch.float32) # Atom positions in nm
-    charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=torch.float32) # Partial charges
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_coulomb(dtype):
+    pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm
+    charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges
     types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
     distance_scale = 1e-9 # Convert nm to meters
     energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
@@ -89,12 +90,14 @@ def compute_interaction(pos1, pos2, z1, z2):
             expected += compute_interaction(pos[i], pos[j], charge[i], charge[j])
     torch.testing.assert_allclose(expected, energy)
 
-def test_multiple_priors():
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+def test_multiple_priors(dtype):
     # Create a model from a config file.
     dataset = DummyDataset(has_atomref=True)
     config_file = join(dirname(__file__), 'priors.yaml')
-    args = load_example_args('equivariant-transformer', config_file=config_file)
+    args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype)
     prior_models = create_prior_models(args, dataset)
     args['prior_args'] = [p.get_init_args() for p in prior_models]
     model = LNNP(args, prior_model=prior_models)
diff --git a/tests/utils.py b/tests/utils.py
index 5d0ab41c4..e5fa4ee1e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -12,6 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
         config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
     with open(config_file, "r") as f:
         args = yaml.load(f, Loader=yaml.FullLoader)
+    if "dtype" not in args:
+        args["dtype"] = "float"
     args["model"] = model_name
     args["seed"] = 1234
     if remove_prior:
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 3340745e7..27992e735 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -7,11 +7,26 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from torchmdnet.models import output_modules
 from torchmdnet.models.wrappers import AtomFilter
+from torchmdnet.models.utils import dtype_mapping
 from torchmdnet import priors
 import warnings
 
 
 def create_model(args, prior_model=None, mean=None, std=None):
+    """Create a model from the given arguments.
+    See :func:`get_args` in scripts/train.py for a description of the arguments.
+    Parameters
+    ----------
+    args (dict): Arguments for the model.
+    prior_model (nn.Module, optional): Prior model to use. Defaults to None.
+    mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
+    std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
+    Returns
+    -------
+    nn.Module: An instance of the TorchMD_Net model.
+    """
+    args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
+    args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
         num_layers=args["num_layers"],
@@ -23,6 +38,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         cutoff_upper=args["cutoff_upper"],
         max_z=args["max_z"],
         max_num_neighbors=args["max_num_neighbors"],
+        dtype=args["dtype"]
     )
 
     # representation network
@@ -86,6 +102,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
             args["embedding_dimension"],
             activation=args["activation"],
             reduce_op=args["reduce_op"],
+            dtype=args["dtype"],
         )
 
     # combine representation and output network
@@ -96,6 +113,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         mean=mean,
         std=std,
         derivative=args["derivative"],
+        dtype=args["dtype"],
     )
     return model
 
@@ -177,10 +195,11 @@ def __init__(
         mean=None,
         std=None,
         derivative=False,
+        dtype=torch.float32,
     ):
         super(TorchMD_Net, self).__init__()
-        self.representation_model = representation_model
-        self.output_model = output_model
+        self.representation_model = representation_model.to(dtype=dtype)
+        self.output_model = output_model.to(dtype=dtype)
 
         if not output_model.allow_prior_model and prior_model is not None:
             prior_model = None
@@ -192,14 +211,14 @@ def __init__(
             )
         if isinstance(prior_model, priors.base.BasePrior):
             prior_model = [prior_model]
-        self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model)
+        self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype)
 
         self.derivative = derivative
 
         mean = torch.scalar_tensor(0) if mean is None else mean
-        self.register_buffer("mean", mean)
+        self.register_buffer("mean", mean.to(dtype=dtype))
         std = torch.scalar_tensor(1) if std is None else std
-        self.register_buffer("std", std)
+        self.register_buffer("std", std.to(dtype=dtype))
 
         self.reset_parameters()
 
@@ -277,6 +296,7 @@ def forward(
             )[0]
             if dy is None:
                 raise RuntimeError("Autograd returned None for the force prediction.")
+            return y, -dy
         # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
         return y, None
 
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index 0bcec935f..d3d7e1b85 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -38,15 +38,16 @@ def __init__(
         activation="silu",
         allow_prior_model=True,
         reduce_op="sum",
+        dtype=torch.float
     ):
         super(Scalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
         act_class = act_class_mapping[activation]
         self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2),
+            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
             act_class(),
-            nn.Linear(hidden_channels // 2, 1),
+            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
         )
 
         self.reset_parameters()
@@ -68,6 +69,7 @@ def __init__(
         activation="silu",
         allow_prior_model=True,
         reduce_op="sum",
+        dtype=torch.float
     ):
         super(EquivariantScalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
@@ -79,8 +81,9 @@ def __init__(
                     hidden_channels // 2,
                     activation=activation,
                     scalar_activation=True,
+                    dtype=dtype
                 ),
-                GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation),
+                GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation, dtype=dtype),
             ]
         )
 
@@ -98,11 +101,11 @@ def pre_reduce(self, x, v, z, pos, batch):
 
 
 class DipoleMoment(Scalar):
-    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
         super(DipoleMoment, self).__init__(
-            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
+            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype
         )
-        atomic_mass = torch.from_numpy(atomic_masses).float()
+        atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
 
     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
@@ -119,11 +122,11 @@ def post_reduce(self, x):
 
 
 class EquivariantDipoleMoment(EquivariantScalar):
-    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
         super(EquivariantDipoleMoment, self).__init__(
-            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
+            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype
         )
-        atomic_mass = torch.from_numpy(atomic_masses).float()
+        atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
 
     def pre_reduce(self, x, v, z, pos, batch):
@@ -141,17 +144,17 @@ def post_reduce(self, x):
 
 
 class ElectronicSpatialExtent(OutputModel):
-    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
         super(ElectronicSpatialExtent, self).__init__(
             allow_prior_model=False, reduce_op=reduce_op
         )
         act_class = act_class_mapping[activation]
         self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2),
+            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
             act_class(),
-            nn.Linear(hidden_channels // 2, 1),
+            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
         )
-        atomic_mass = torch.from_numpy(atomic_masses).float()
+        atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
 
         self.reset_parameters()
@@ -178,9 +181,9 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):
 
 
 class EquivariantVectorOutput(EquivariantScalar):
-    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
         super(EquivariantVectorOutput, self).__init__(
-            hidden_channels, activation, allow_prior_model=False, reduce_op="sum"
+            hidden_channels, activation, allow_prior_model=False, reduce_op="sum", dtype=dtype
         )
 
     def pre_reduce(self, x, v, z, pos, batch):
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 956f683bd..0b0f4c932 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -15,7 +15,7 @@
 def vector_to_skewtensor(vector):
     tensor = torch.cross(
         *torch.broadcast_tensors(
-            vector[..., None], torch.eye(3, 3, device=vector.device)[None, None]
+            vector[..., None], torch.eye(3, 3, device=vector.device, dtype=vector.dtype)[None, None]
         )
     )
     return tensor.squeeze(0)
@@ -26,7 +26,7 @@ def vector_to_symtensor(vector):
     tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
     I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
         ..., None, None
-    ] * torch.eye(3, 3, device=tensor.device)
+    ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
     S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
     return S
 
@@ -35,7 +35,7 @@ def vector_to_symtensor(vector):
 def decompose_tensor(tensor):
     I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
         ..., None, None
-    ] * torch.eye(3, 3, device=tensor.device)
+    ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
     A = 0.5 * (tensor - tensor.transpose(-2, -1))
     S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
     return I, A, S
@@ -85,7 +85,7 @@ class TensorNet(nn.Module):
             will be invariant. O(3) or SO(3). (default :obj:`"O(3)"`)
 
     """
-    
+
     def __init__(
         self,
         hidden_channels=128,
@@ -99,6 +99,7 @@ def __init__(
         max_z=128,
         max_num_neighbors=64,
         equivariance_invariance_group="O(3)",
+        dtype=torch.float32,
     ):
         super(TensorNet, self).__init__()
         assert rbf_type in rbf_class_mapping, (
@@ -136,6 +137,7 @@ def __init__(
             cutoff_upper,
             trainable_rbf,
             max_z,
+            dtype,
         ).jittable()
         self.layers = nn.ModuleList()
         if num_layers != 0:
@@ -148,10 +150,11 @@ def __init__(
                     cutoff_lower,
                     cutoff_upper,
                     equivariance_invariance_group,
+                    dtype,
                 ).jittable()
             )
-        self.linear = nn.Linear(3 * hidden_channels, hidden_channels)
-        self.out_norm = nn.LayerNorm(3 * hidden_channels)
+        self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype)
+        self.out_norm = nn.LayerNorm(3 * hidden_channels, dtype=dtype)
         self.act = act_class()
 
         self.reset_parameters()
@@ -191,36 +194,37 @@ def __init__(
         cutoff_upper,
         trainable_rbf=False,
         max_z=128,
+        dtype=torch.float32,
     ):
         super(TensorEmbedding, self).__init__(aggr="add", node_dim=0)
 
         self.hidden_channels = hidden_channels
-        self.distance_proj1 = nn.Linear(num_rbf, hidden_channels)
-        self.distance_proj2 = nn.Linear(num_rbf, hidden_channels)
-        self.distance_proj3 = nn.Linear(num_rbf, hidden_channels)
+        self.distance_proj1 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
+        self.distance_proj2 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
+        self.distance_proj3 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
         self.max_z = max_z
-        self.emb = torch.nn.Embedding(max_z, hidden_channels)
-        self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels)
+        self.emb = torch.nn.Embedding(max_z, hidden_channels, dtype=dtype)
+        self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
        )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.linears_scalar = nn.ModuleList()
         self.linears_scalar.append(
-            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True)
+            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
         )
         self.linears_scalar.append(
-            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True)
+            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
-        self.init_norm = nn.LayerNorm(hidden_channels)
+        self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -244,7 +248,7 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
         mask = edge_index[0] != edge_index[1]
         edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
         Iij, Aij, Sij = new_radial_tensor(
-            torch.eye(3, 3, device=edge_vec.device)[None, None, :, :],
+            torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[None, None, :, :],
             vector_to_skewtensor(edge_vec)[..., None, :, :],
             vector_to_symtensor(edge_vec)[..., None, :, :],
             W1,
@@ -301,6 +305,7 @@ def __init__(
         cutoff_lower,
         cutoff_upper,
         equivariance_invariance_group,
+        dtype=torch.float32,
     ):
         super(Interaction, self).__init__(aggr="add", node_dim=0)
 
@@ -308,31 +313,31 @@ def __init__(
         self.hidden_channels = hidden_channels
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
         self.linears_scalar = nn.ModuleList()
-        self.linears_scalar.append(nn.Linear(num_rbf, hidden_channels, bias=True))
+        self.linears_scalar.append(nn.Linear(num_rbf, hidden_channels, bias=True, dtype=dtype))
         self.linears_scalar.append(
-            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True)
+            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
         )
         self.linears_scalar.append(
-            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True)
+            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
         self.linears_tensor = nn.ModuleList()
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
        )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
        )
         self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False)
+            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
         )
         self.act = activation()
         self.equivariance_invariance_group = equivariance_invariance_group
diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py
index 69977be93..9f93028cb 100644
--- a/torchmdnet/models/torchmd_et.py
+++ b/torchmdnet/models/torchmd_et.py
@@ -11,7 +11,6 @@
     act_class_mapping,
 )
 
-
 class TorchMD_ET(nn.Module):
     r"""The TorchMD equivariant Transformer architecture.
 
@@ -67,6 +66,7 @@ def __init__(
         cutoff_upper=5.0,
         max_z=100,
         max_num_neighbors=32,
+        dtype=torch.float32,
     ):
         super(TorchMD_ET, self).__init__()
 
@@ -97,10 +97,11 @@ def __init__(
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
+        self.dtype = dtype
 
         act_class = act_class_mapping[activation]
 
-        self.embedding = nn.Embedding(self.max_z, hidden_channels)
+        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
 
         self.distance = Distance(
             cutoff_lower,
@@ -114,7 +115,7 @@ def __init__(
         )
         self.neighbor_embedding = (
             NeighborEmbedding(
-                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z
+                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype
             ).jittable()
             if neighbor_embedding
             else None
@@ -131,10 +132,11 @@ def __init__(
                 attn_activation,
                 cutoff_lower,
                 cutoff_upper,
+                dtype,
             ).jittable()
             self.attention_layers.append(layer)
 
-        self.out_norm = nn.LayerNorm(hidden_channels)
+        self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
 
         self.reset_parameters()
 
@@ -147,6 +149,7 @@ def reset_parameters(self):
             attn.reset_parameters()
         self.out_norm.reset_parameters()
 
+
     def forward(
         self,
         z: Tensor,
@@ -159,6 +162,8 @@ def forward(
         x = self.embedding(z)
 
         edge_index, edge_weight, edge_vec = self.distance(pos, batch)
+        # This assert must be here to convince TorchScript that edge_vec is not None
+        # If you remove it, TorchScript will complain down below that you cannot use an Optional[Tensor]
         assert (
             edge_vec is not None
         ), "Distance module did not return directional information"
@@ -170,7 +175,7 @@ def forward(
         if self.neighbor_embedding is not None:
             x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr)
 
-        vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device)
+        vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device, dtype=x.dtype)
 
         for attn in self.attention_layers:
             dx, dvec = attn(x, vec, edge_index, edge_weight, edge_attr, edge_vec)
@@ -194,7 +199,8 @@ def __repr__(self):
             f"num_heads={self.num_heads}, "
             f"distance_influence={self.distance_influence}, "
             f"cutoff_lower={self.cutoff_lower}, "
-            f"cutoff_upper={self.cutoff_upper})"
+            f"cutoff_upper={self.cutoff_upper}, "
+            f"dtype={self.dtype})"
         )
 
 
@@ -209,6 +215,7 @@ def __init__(
         attn_activation,
         cutoff_lower,
         cutoff_upper,
+        dtype=torch.float32,
     ):
         super(EquivariantMultiHeadAttention, self).__init__(aggr="add", node_dim=0)
         assert hidden_channels % num_heads == 0, (
@@ -221,26 +228,25 @@ def __init__(
         self.num_heads = num_heads
         self.hidden_channels = hidden_channels
         self.head_dim = hidden_channels // num_heads
-
-        self.layernorm = nn.LayerNorm(hidden_channels)
+        self.layernorm = nn.LayerNorm(hidden_channels, dtype=dtype)
         self.act = activation()
         self.attn_activation = act_class_mapping[attn_activation]()
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
 
-        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
-        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
-        self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3)
-        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3)
+        self.q_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
+        self.k_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
+        self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype)
+        self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3, dtype=dtype)
 
-        self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False)
+        self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False, dtype=dtype)
 
         self.dk_proj = None
         if distance_influence in ["keys", "both"]:
-            self.dk_proj = nn.Linear(num_rbf, hidden_channels)
+            self.dk_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
         self.dv_proj = None
         if distance_influence in ["values", "both"]:
-            self.dv_proj = nn.Linear(num_rbf, hidden_channels * 3)
+            self.dv_proj = nn.Linear(num_rbf, hidden_channels * 3, dtype=dtype)
 
         self.reset_parameters()
 
     def reset_parameters(self):
diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py
index 6f1f22853..8c78e766c 100644
--- a/torchmdnet/models/torchmd_gn.py
+++ b/torchmdnet/models/torchmd_gn.py
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple
+import torch
 from torch import Tensor, nn
 from torch_geometric.nn import MessagePassing
 from torchmdnet.models.utils import (
@@ -71,6 +72,7 @@ def __init__(
         max_z=100,
         max_num_neighbors=32,
         aggr="add",
+        dtype=torch.float32
     ):
         super(TorchMD_GN, self).__init__()
 
@@ -103,17 +105,17 @@ def __init__(
 
         act_class = act_class_mapping[activation]
 
-        self.embedding = nn.Embedding(self.max_z, hidden_channels)
+        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
 
         self.distance = Distance(
             cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors
         )
         self.distance_expansion = rbf_class_mapping[rbf_type](
-            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
+            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
         )
         self.neighbor_embedding = (
             NeighborEmbedding(
-                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z
+                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype
             ).jittable()
             if neighbor_embedding
             else None
@@ -129,6 +131,7 @@ def __init__(
                 cutoff_lower,
                 cutoff_upper,
                 aggr=self.aggr,
+                dtype=dtype
             )
             self.interactions.append(block)
 
@@ -191,12 +194,13 @@ def __init__(
         cutoff_lower,
         cutoff_upper,
         aggr="add",
+        dtype=torch.float32
     ):
         super(InteractionBlock, self).__init__()
         self.mlp = nn.Sequential(
-            nn.Linear(num_rbf, num_filters),
+            nn.Linear(num_rbf, num_filters, dtype=dtype),
             activation(),
-            nn.Linear(num_filters, num_filters),
+            nn.Linear(num_filters, num_filters, dtype=dtype),
         )
         self.conv = CFConv(
             hidden_channels,
@@ -206,9 +210,10 @@ def __init__(
             cutoff_lower,
             cutoff_upper,
             aggr=aggr,
+            dtype=dtype
         ).jittable()
         self.act = activation()
-        self.lin = nn.Linear(hidden_channels, hidden_channels)
+        self.lin = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
 
         self.reset_parameters()
 
@@ -238,10 +243,11 @@ def __init__(
         cutoff_lower,
         cutoff_upper,
         aggr="add",
+        dtype=torch.float32
     ):
         super(CFConv, self).__init__(aggr=aggr)
-        self.lin1 = nn.Linear(in_channels, num_filters, bias=False)
-        self.lin2 = nn.Linear(num_filters, out_channels)
+        self.lin1 = nn.Linear(in_channels, num_filters, bias=False, dtype=dtype)
+        self.lin2 = nn.Linear(num_filters, out_channels, bias=True, dtype=dtype)
         self.net = net
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
 
diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py
index f6458c7d4..dce869880 100644
--- a/torchmdnet/models/torchmd_t.py
+++ b/torchmdnet/models/torchmd_t.py
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple
+import torch
 from torch import Tensor, nn
 from torch_geometric.nn import MessagePassing
 from torchmdnet.models.utils import (
@@ -65,6 +66,7 @@ def __init__(
         cutoff_upper=5.0,
         max_z=100,
         max_num_neighbors=32,
+        dtype=torch.float
     ):
         super(TorchMD_T, self).__init__()
 
@@ -95,17 +97,17 @@ def __init__(
         act_class = act_class_mapping[activation]
         attn_act_class = act_class_mapping[attn_activation]
 
-        self.embedding = nn.Embedding(self.max_z, hidden_channels)
+        self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
 
         self.distance = Distance(
             cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors, loop=True
         )
         self.distance_expansion = rbf_class_mapping[rbf_type](
-            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
+            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
         )
         self.neighbor_embedding = (
             NeighborEmbedding(
-                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z
+                hidden_channels, num_rbf, cutoff_lower, cutoff_upper, self.max_z, dtype=dtype
             ).jittable()
             if neighbor_embedding
             else None
@@ -122,10 +124,11 @@ def __init__(
                 attn_act_class,
                 cutoff_lower,
                 cutoff_upper,
+                dtype=dtype,
             ).jittable()
             self.attention_layers.append(layer)
 
-        self.out_norm = nn.LayerNorm(hidden_channels)
+        self.out_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
 
         self.reset_parameters()
 
@@ -190,6 +193,7 @@ def __init__(
         attn_activation,
         cutoff_lower,
         cutoff_upper,
+        dtype=torch.float,
     ):
         super(MultiHeadAttention, self).__init__(aggr="add", node_dim=0)
         assert hidden_channels % num_heads == 0, (
@@ -202,23 +206,23 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = hidden_channels // num_heads
 
-        self.layernorm = nn.LayerNorm(hidden_channels)
+        self.layernorm = nn.LayerNorm(hidden_channels, dtype=dtype)
         self.act = activation()
         self.attn_activation = attn_activation()
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
 
-        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
-        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
-        self.v_proj = nn.Linear(hidden_channels, hidden_channels)
-        self.o_proj = nn.Linear(hidden_channels, hidden_channels)
+        self.q_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
+        self.k_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
+        self.v_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
+        self.o_proj = nn.Linear(hidden_channels, hidden_channels, dtype=dtype)
 
         self.dk_proj = None
         if distance_influence in ["keys", "both"]:
-            self.dk_proj = nn.Linear(num_rbf, hidden_channels)
+            self.dk_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
         self.dv_proj = None
         if distance_influence in ["values", "both"]:
-            self.dv_proj = nn.Linear(num_rbf, hidden_channels)
+            self.dv_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
 
         self.reset_parameters()
 
     def reset_parameters(self):
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index 912d0c147..c0aa439b1 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -41,25 +41,24 @@ def visualize_basis(basis_type, num_rbf=50, cutoff_lower=0, cutoff_upper=5):
 
 
 class NeighborEmbedding(MessagePassing):
-    """
-    The ET architecture assigns two learned vectors to each atom type
-    zi. One is used to encode information specific to an atom, the
-    other (this class) takes the role of a neighborhood embedding.
-    The neighborhood embedding, which is an embedding of the types of
-    neighboring atoms, is multiplied by a distance filter.
-
+    def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100, dtype=torch.float32):
+        """
+        The ET architecture assigns two learned vectors to each atom type
+        zi. One is used to encode information specific to an atom, the
+        other (this class) takes the role of a neighborhood embedding.
+        The neighborhood embedding, which is an embedding of the types of
+        neighboring atoms, is multiplied by a distance filter.
 
-    This embedding allows the network to store information about the
-    interaction of atom pairs.
-    See eq. 3 in https://arxiv.org/pdf/2202.02541.pdf for more details.
-    """
+        This embedding allows the network to store information about the
+        interaction of atom pairs.
 
-    def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100):
+        See eq. 3 in https://arxiv.org/pdf/2202.02541.pdf for more details.
+        """
         super(NeighborEmbedding, self).__init__(aggr="add")
-        self.embedding = nn.Embedding(max_z, hidden_channels)
-        self.distance_proj = nn.Linear(num_rbf, hidden_channels)
-        self.combine = nn.Linear(hidden_channels * 2, hidden_channels)
+        self.embedding = nn.Embedding(max_z, hidden_channels, dtype=dtype)
+        self.distance_proj = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
+        self.combine = nn.Linear(hidden_channels * 2, hidden_channels, dtype=dtype)
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
 
         self.reset_parameters()
@@ -266,13 +265,13 @@ def forward(
 
 
 class GaussianSmearing(nn.Module):
-    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
+    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, dtype=torch.float32):
         super(GaussianSmearing, self).__init__()
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.num_rbf = num_rbf
         self.trainable = trainable
-
+        self.dtype = dtype
         offset, coeff = self._initial_params()
         if trainable:
             self.register_parameter("coeff", nn.Parameter(coeff))
@@ -282,7 +281,7 @@ def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
         self.register_buffer("offset", offset)
 
     def _initial_params(self):
-        offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf)
+        offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf, dtype=self.dtype)
         coeff = -0.5 / (offset[1] - offset[0]) ** 2
         return offset, coeff
 
@@ -297,13 +296,13 @@ def forward(self, dist):
 
 
 class ExpNormalSmearing(nn.Module):
-    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True):
+    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True, dtype=torch.float32):
         super(ExpNormalSmearing, self).__init__()
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.num_rbf = num_rbf
         self.trainable = trainable
-
+        self.dtype = dtype
         self.cutoff_fn = CosineCutoff(0, cutoff_upper)
         self.alpha = 5.0 / (cutoff_upper - cutoff_lower)
 
@@ -319,11 +318,11 @@ def _initial_params(self):
         # initialize means and betas according to the default values in PhysNet
         # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
         start_value = torch.exp(
-            torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower)
+            torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower, dtype=self.dtype)
         )
-        means = torch.linspace(start_value, 1, self.num_rbf)
+        means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype)
         betas = torch.tensor(
-            [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf
+            [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf, dtype=self.dtype
         )
         return means, betas
 
@@ -376,13 +375,13 @@ def forward(self, distances):
                 + 1.0
             )
             # remove contributions below the cutoff radius
-            cutoffs = cutoffs * (distances < self.cutoff_upper).float()
-            cutoffs = cutoffs * (distances > self.cutoff_lower).float()
+            cutoffs = cutoffs * (distances < self.cutoff_upper)
+            cutoffs = cutoffs * (distances > self.cutoff_lower)
             return cutoffs
         else:
             cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
             # remove contributions beyond the cutoff radius
-            cutoffs = cutoffs * (distances < self.cutoff_upper).float()
+            cutoffs = cutoffs * (distances < self.cutoff_upper)
             return cutoffs
 
 
@@ -461,6 +460,7 @@ def __init__(
         intermediate_channels=None,
         activation="silu",
         scalar_activation=False,
+        dtype=torch.float,
     ):
         super(GatedEquivariantBlock, self).__init__()
         self.out_channels = out_channels
@@ -468,14 +468,14 @@ def __init__(
         if intermediate_channels is None:
             intermediate_channels = hidden_channels
 
-        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False)
-        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False)
+        self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
+        self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias=False, dtype=dtype)
 
         act_class = act_class_mapping[activation]
         self.update_net = nn.Sequential(
-            nn.Linear(hidden_channels * 2, intermediate_channels),
+            nn.Linear(hidden_channels * 2, intermediate_channels, dtype=dtype),
             act_class(),
-            nn.Linear(intermediate_channels, out_channels * 2),
+            nn.Linear(intermediate_channels, out_channels * 2, dtype=dtype),
         )
 
         self.act = act_class() if scalar_activation else None
@@ -493,7 +493,7 @@ def forward(self, x, v):
 
         # detach zero-entries to avoid NaN gradients during force loss backpropagation
         vec1 = torch.zeros(
-            vec1_buffer.size(0), vec1_buffer.size(2), device=vec1_buffer.device
+            vec1_buffer.size(0), vec1_buffer.size(2), device=vec1_buffer.device, dtype=vec1_buffer.dtype
         )
         mask = (vec1_buffer != 0).view(vec1_buffer.size(0), -1).any(dim=1)
         if not mask.all():
@@ -525,3 +525,5 @@ def forward(self, x, v):
     "tanh": nn.Tanh,
     "sigmoid": nn.Sigmoid,
 }
+
+dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64}
diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py
index 91a72aad6..f526c0d1e 100644
--- a/torchmdnet/priors/d2.py
+++ b/torchmdnet/priors/d2.py
@@ -103,7 +103,8 @@ class D2(BasePrior):
             [31.74, 1.892],  # 52 Te
             [31.50, 1.892],  # 53 I
             [29.99, 1.881],  # 54 Xe
-        ]
+        ], dtype=pt.float64
+    )
 
     C_6_R_r[:, 1] *= 0.1  # Å --> nm
 
@@ -115,17 +116,21 @@ def __init__(
         distance_scale=None,
         energy_scale=None,
         dataset=None,
+        dtype=pt.float32,
     ):
         super().__init__()
-        self.cutoff_distance = float(cutoff_distance)
+        one = pt.tensor(1.0, dtype=dtype).item()
+        self.cutoff_distance = cutoff_distance * one
         self.max_num_neighbors = int(max_num_neighbors)
+
+        self.C_6_R_r = self.C_6_R_r.to(dtype=dtype)
         self.atomic_number = list(
             dataset.atomic_number if atomic_number is None else atomic_number
         )
-        self.distance_scale = float(
+        self.distance_scale = one * (
             dataset.distance_scale if distance_scale is None else distance_scale
         )
-        self.energy_scale = float(
+        self.energy_scale = one * (
             dataset.energy_scale if energy_scale is None else energy_scale
         )
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 777be2a19..8f67476f1 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -12,9 +12,9 @@
 from torchmdnet.data import DataModule
 from torchmdnet.models import output_modules
 from torchmdnet.models.model import create_prior_models
-from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping
+from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
 from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
-
+import torch
 
 def get_args():
     # fmt: off
@@ -67,6 +67,7 @@
     parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
 
     # architectural args
+    parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. One of float/float32 or double/float64')
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
    parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
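
Usage sketch (illustrative only, not part of the patch above): a minimal example of how the new dtype option is meant to be exercised, assuming the test helpers in tests/utils.py from this branch; the model name and argument values are arbitrary.

    import torch
    from torchmdnet.models.model import create_model
    from utils import load_example_args, create_example_batch  # helpers in tests/utils.py

    # "float"/"float32" map to torch.float32 and "double"/"float64" to torch.float64
    # via dtype_mapping; create_model also accepts a torch.dtype directly.
    args = load_example_args(
        "equivariant-transformer", remove_prior=True, derivative=True, dtype="float64"
    )
    model = create_model(args)

    # run a double-precision forward pass; with derivative=True the model
    # returns the energy and the negative gradient with respect to positions
    z, pos, batch = create_example_batch()
    energy, forces = model(z, pos.to(torch.float64), batch=batch)

From the command line the same precision is selected with --dtype float64, or by setting dtype: float64 in a YAML config such as examples/ET-QM9.yaml.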