diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py
index e2470dcb2..347702735 100644
--- a/torchmdnet/datasets/ace.py
+++ b/torchmdnet/datasets/ace.py
@@ -15,12 +15,14 @@ def __init__(
         pre_transform=None,
         pre_filter=None,
         paths=None,
+        atomic_numbers=None,
         max_gradient=None,
         subsample_molecules=1,
     ):
         assert isinstance(paths, (str, list))
 
-        arg_hash = f"{paths}{max_gradient}{subsample_molecules}"
+        self.atomic_numbers = set([] if atomic_numbers is None else atomic_numbers)
+        arg_hash = f"{paths}{self.atomic_numbers}{max_gradient}{subsample_molecules}"
         arg_hash = hashlib.md5(arg_hash.encode()).hexdigest()
         self.name = f"{self.__class__.__name__}-{arg_hash}"
         self.paths = paths
@@ -180,6 +182,11 @@ def sample_iter(self, mol_ids=False):
             fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
             q = fq.sum()
 
+            # Keep molecules with specific elements
+            if self.atomic_numbers:
+                if not set(z.numpy()).issubset(self.atomic_numbers):
+                    continue
+
             for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))):
 
                 # Skip samples with large forces
@@ -220,6 +227,7 @@ def processed_file_names(self):
     def process(self):
 
         print("Arguments")
+        print(f"  atomic_numbers: {self.atomic_numbers}")
         print(f"  max_gradient: {self.max_gradient} eV/A")
         print(f"  subsample_molecules: {self.subsample_molecules}\n")
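
For reference, a minimal usage sketch of the new element filter, assuming the dataset class defined in torchmdnet/datasets/ace.py is named Ace and takes a PyTorch-Geometric-style root directory as its first argument (neither is shown in this diff); the paths and numeric values below are hypothetical:

    from torchmdnet.datasets.ace import Ace

    # Hypothetical example: only molecules composed entirely of the listed
    # elements (H, C, N, O) survive the issubset() check added in sample_iter().
    dataset = Ace(
        root="data/ace",                 # assumed cache/processing directory
        paths="data/ace_molecules.h5",   # hypothetical raw HDF5 file
        atomic_numbers=[1, 6, 7, 8],     # allowed atomic numbers
        max_gradient=25.7,               # hypothetical force cutoff in eV/A
        subsample_molecules=1,
    )

Because atomic_numbers is folded into arg_hash, changing the element set changes self.name and so triggers a fresh processing pass rather than silently reusing files cached for a different filter.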