Merged
Changes from 18 commits
1 change: 1 addition & 0 deletions examples/ET-QM9.yaml
@@ -55,3 +55,4 @@ train_size: 110000
trainable_rbf: false
val_size: 10000
weight_decay: 0.0
dtype: float
1 change: 1 addition & 0 deletions tests/priors.yaml
@@ -57,3 +57,4 @@ train_size: 110000
trainable_rbf: false
val_size: 10000
weight_decay: 0.0
dtype: float
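
Both example configs now carry an explicit `dtype` key. Below is a minimal sketch of how such a string key could be turned into a `torch.dtype`; the accepted values shown here are an assumption, and the actual table in this repository is `torchmdnet.models.utils.dtype_mapping`.

```python
import torch

# Assumed string-to-dtype table; the real keys accepted by
# torchmdnet.models.utils.dtype_mapping may differ.
DTYPE_MAPPING = {
    "float": torch.float32,
    "float32": torch.float32,
    "double": torch.float64,
    "float64": torch.float64,
}

def dtype_from_config(value):
    # Translate a YAML "dtype" entry (e.g. "float" in the configs above)
    # into a torch dtype; pass torch dtypes through unchanged.
    return DTYPE_MAPPING[value] if isinstance(value, str) else value

print(dtype_from_config("float"))  # torch.float32
```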
39 changes: 33 additions & 6 deletions tests/test_model.py
@@ -14,9 +14,11 @@
@mark.parametrize("model_name", models.__all__)
@mark.parametrize("use_batch", [True, False])
@mark.parametrize("explicit_q_s", [True, False])
def test_forward(model_name, use_batch, explicit_q_s):
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward(model_name, use_batch, explicit_q_s, dtype):
z, pos, batch = create_example_batch()
model = create_model(load_example_args(model_name, prior_model=None))
pos = pos.to(dtype=dtype)
model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype))
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
@@ -26,20 +26,22 @@ def test_forward(model_name, use_batch, explicit_q_s):

@mark.parametrize("model_name", models.__all__)
@mark.parametrize("output_model", output_modules.__all__)
def test_forward_output_modules(model_name, output_model):
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward_output_modules(model_name, output_model, dtype):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model)
args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype)
model = create_model(args)
model(z, pos, batch=batch)


@mark.parametrize("model_name", models.__all__)
def test_forward_torchscript(model_name):
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward_torchscript(model_name, dtype):
if model_name == "tensornet":
pytest.skip("TensorNet does not support torchscript.")
z, pos, batch = create_example_batch()
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype))
)
model(z, pos, batch=batch)

@@ -103,3 +107,26 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
torch.testing.assert_allclose(
deriv, expected[model_name][output_model]["deriv"]
)


@mark.parametrize("model_name", models.__all__)
def test_gradients(model_name):
pl.seed_everything(1234)
dtype = torch.float64
output_model = "Scalar"
# create model and sample batch
derivative = output_model in ["Scalar", "EquivariantScalar"]
args = load_example_args(
model_name,
remove_prior=True,
output_model=output_model,
derivative=derivative,
dtype=dtype,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
pos.requires_grad_(True)
pos = pos.to(dtype)
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)
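
For context, `torch.autograd.gradcheck` compares analytic gradients against finite differences, which is why the new test above runs in float64 and loosens `eps`, `atol`, `rtol`, and `nondet_tol`. A minimal, self-contained sketch with a toy energy function (not the TorchMD-Net model):

```python
import torch

def toy_energy(pos):
    # Simple differentiable "energy": sum of squared coordinates.
    return (pos ** 2).sum()

# gradcheck needs double-precision inputs with requires_grad=True,
# otherwise the finite-difference comparison is too noisy to pass.
pos = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(toy_energy, (pos,), eps=1e-6, atol=1e-4)
```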
13 changes: 8 additions & 5 deletions tests/test_priors.py
@@ -63,9 +63,10 @@ def compute_interaction(pos1, pos2, z1, z2):
expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
torch.testing.assert_allclose(expected, energy)

def test_coulomb():
pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=torch.float32) # Atom positions in nm
charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=torch.float32) # Partial charges
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_coulomb(dtype):
pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm
charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
distance_scale = 1e-9 # Convert nm to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
@@ -89,12 +90,14 @@ def compute_interaction(pos1, pos2, z1, z2):
expected += compute_interaction(pos[i], pos[j], charge[i], charge[j])
torch.testing.assert_allclose(expected, energy)

def test_multiple_priors():

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_multiple_priors(dtype):
# Create a model from a config file.

dataset = DummyDataset(has_atomref=True)
config_file = join(dirname(__file__), 'priors.yaml')
args = load_example_args('equivariant-transformer', config_file=config_file)
args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype)
prior_models = create_prior_models(args, dataset)
args['prior_args'] = [p.get_init_args() for p in prior_models]
model = LNNP(args, prior_model=prior_models)
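
The Coulomb prior test now runs in both single and double precision, with positions and charges created directly in the parametrized dtype, so the reference energy is accumulated in that same precision. A minimal sketch of an analogous pairwise Coulomb sum (constants and units here are illustrative, not the test's exact conversion factors):

```python
import torch

def pairwise_coulomb(pos, charge):
    # Accumulate q_i * q_j / r_ij over all pairs, in the dtype of the inputs.
    energy = torch.zeros((), dtype=pos.dtype)
    n = pos.shape[0]
    for i in range(n):
        for j in range(i + 1, n):
            r = torch.linalg.norm(pos[i] - pos[j])
            energy = energy + charge[i] * charge[j] / r
    return energy

pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]], dtype=torch.float64)
charge = torch.tensor([1.0, -1.0], dtype=torch.float64)
print(pairwise_coulomb(pos, charge).dtype)  # torch.float64
```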
2 changes: 2 additions & 0 deletions tests/utils.py
@@ -12,6 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
with open(config_file, "r") as f:
args = yaml.load(f, Loader=yaml.FullLoader)
if "dtype" not in args:
args["dtype"] = "float"
args["model"] = model_name
args["seed"] = 1234
if remove_prior:
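
The test helper now back-fills `dtype` for older configs that predate the key. A sketch of the same defaulting pattern on a freshly loaded YAML dict (the inline config text is only a stand-in for the real example files):

```python
import yaml

# Hypothetical config text standing in for an older examples/ET-QM9.yaml
# that has no "dtype" entry.
raw = """
model: equivariant-transformer
seed: 1234
"""

args = yaml.safe_load(raw)
# Older configs carry no "dtype" key, so fall back to single precision.
if "dtype" not in args:
    args["dtype"] = "float"
print(args["dtype"])  # float
```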
30 changes: 25 additions & 5 deletions torchmdnet/models/model.py
@@ -7,11 +7,26 @@
from pytorch_lightning.utilities import rank_zero_warn
from torchmdnet.models import output_modules
from torchmdnet.models.wrappers import AtomFilter
from torchmdnet.models.utils import dtype_mapping
from torchmdnet import priors
import warnings


def create_model(args, prior_model=None, mean=None, std=None):
"""Create a model from the given arguments.
See :func:`get_args` in scripts/train.py for a description of the arguments.
Parameters
----------
args (dict): Arguments for the model.
prior_model (nn.Module, optional): Prior model to use. Defaults to None.
mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
Returns
-------
nn.Module: An instance of the TorchMD_Net model.
"""
args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
@@ -23,6 +38,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
max_num_neighbors=args["max_num_neighbors"],
dtype=args["dtype"]
)

# representation network
@@ -86,6 +102,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
args["embedding_dimension"],
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=args["dtype"],
)

# combine representation and output network
@@ -96,6 +113,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
mean=mean,
std=std,
derivative=args["derivative"],
dtype=args["dtype"],
)
return model

@@ -168,10 +186,11 @@ def __init__(
mean=None,
std=None,
derivative=False,
dtype=torch.float32,
):
super(TorchMD_Net, self).__init__()
self.representation_model = representation_model
self.output_model = output_model
self.representation_model = representation_model.to(dtype=dtype)
self.output_model = output_model.to(dtype=dtype)

if not output_model.allow_prior_model and prior_model is not None:
prior_model = None
@@ -183,14 +202,14 @@ def __init__(
)
if isinstance(prior_model, priors.base.BasePrior):
prior_model = [prior_model]
self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model)
self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model).to(dtype=dtype)

self.derivative = derivative

mean = torch.scalar_tensor(0) if mean is None else mean
self.register_buffer("mean", mean)
self.register_buffer("mean", mean.to(dtype=dtype))
std = torch.scalar_tensor(1) if std is None else std
self.register_buffer("std", std)
self.register_buffer("std", std.to(dtype=dtype))

self.reset_parameters()

@@ -259,6 +278,7 @@ def forward(
)[0]
if dy is None:
raise RuntimeError("Autograd returned None for the force prediction.")

return y, -dy
# TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
return y, None
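
The wrapper now casts its submodules and its `mean`/`std` buffers to the requested dtype at construction time. A minimal sketch of that pattern outside TorchMD-Net (class and attribute names here are placeholders, not the library's API):

```python
import torch
from torch import nn

class PrecisionWrapper(nn.Module):
    def __init__(self, representation, head, mean=None, std=None, dtype=torch.float32):
        super().__init__()
        # Cast the trainable parts once, up front, instead of per forward pass.
        self.representation = representation.to(dtype=dtype)
        self.head = head.to(dtype=dtype)
        mean = torch.scalar_tensor(0) if mean is None else mean
        std = torch.scalar_tensor(1) if std is None else std
        # Buffers follow the same dtype so de-normalization stays consistent.
        self.register_buffer("mean", mean.to(dtype=dtype))
        self.register_buffer("std", std.to(dtype=dtype))

    def forward(self, x):
        return self.head(self.representation(x)) * self.std + self.mean

model = PrecisionWrapper(nn.Linear(8, 8), nn.Linear(8, 1), dtype=torch.float64)
print(model(torch.randn(4, 8, dtype=torch.float64)).dtype)  # torch.float64
```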
33 changes: 18 additions & 15 deletions torchmdnet/models/output_modules.py
@@ -38,15 +38,16 @@ def __init__(
activation="silu",
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float
):
super(Scalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2),
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
)

self.reset_parameters()
@@ -68,6 +69,7 @@ def __init__(
activation="silu",
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float
):
super(EquivariantScalar, self).__init__(
allow_prior_model=allow_prior_model, reduce_op=reduce_op
@@ -79,8 +81,9 @@
hidden_channels // 2,
activation=activation,
scalar_activation=True,
dtype=dtype
),
GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation),
GatedEquivariantBlock(hidden_channels // 2, 1, activation=activation, dtype=dtype),
]
)

@@ -98,11 +101,11 @@ def pre_reduce(self, x, v, z, pos, batch):


class DipoleMoment(Scalar):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
super(DipoleMoment, self).__init__(
hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype
)
atomic_mass = torch.from_numpy(atomic_masses).float()
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)

def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
@@ -119,11 +122,11 @@ def post_reduce(self, x):


class EquivariantDipoleMoment(EquivariantScalar):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
super(EquivariantDipoleMoment, self).__init__(
hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op, dtype=dtype
)
atomic_mass = torch.from_numpy(atomic_masses).float()
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)

def pre_reduce(self, x, v, z, pos, batch):
@@ -141,17 +144,17 @@ def post_reduce(self, x):


class ElectronicSpatialExtent(OutputModel):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
super(ElectronicSpatialExtent, self).__init__(
allow_prior_model=False, reduce_op=reduce_op
)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2),
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
)
atomic_mass = torch.from_numpy(atomic_masses).float()
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
self.register_buffer("atomic_mass", atomic_mass)

self.reset_parameters()
@@ -178,9 +181,9 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):


class EquivariantVectorOutput(EquivariantScalar):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
def __init__(self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float):
super(EquivariantVectorOutput, self).__init__(
hidden_channels, activation, allow_prior_model=False, reduce_op="sum"
hidden_channels, activation, allow_prior_model=False, reduce_op="sum", dtype=dtype
)

def pre_reduce(self, x, v, z, pos, batch):
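
Each output head now forwards `dtype` to its `nn.Linear` layers and casts the `atomic_mass` buffer, so parameters are created in the requested precision rather than converted afterwards. A minimal sketch of the same head-building pattern (the function name is illustrative, not part of the library):

```python
import torch
from torch import nn

def make_scalar_head(hidden_channels: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
    # Mirror of the Scalar-style MLP head above: pass dtype straight to nn.Linear.
    return nn.Sequential(
        nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
        nn.SiLU(),
        nn.Linear(hidden_channels // 2, 1, dtype=dtype),
    )

head = make_scalar_head(64, dtype=torch.float64)
x = torch.randn(10, 64, dtype=torch.float64)
print(head(x).dtype)  # torch.float64
```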