From 223037f984e68bd42e318758296ee3bb33f35026 Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Wed, 19 Jul 2023 18:49:01 +0200 Subject: [PATCH 1/2] Implement element filtering in the Ace datasets --- torchmdnet/datasets/ace.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e2470dcb2..64197fbaa 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -15,15 +15,17 @@ def __init__( pre_transform=None, pre_filter=None, paths=None, + atomic_numbers=None, max_gradient=None, subsample_molecules=1, ): assert isinstance(paths, (str, list)) - arg_hash = f"{paths}{max_gradient}{subsample_molecules}" + arg_hash = f"{paths}{atomic_numbers}{max_gradient}{subsample_molecules}" arg_hash = hashlib.md5(arg_hash.encode()).hexdigest() self.name = f"{self.__class__.__name__}-{arg_hash}" self.paths = paths + self.atomic_numbers = atomic_numbers self.max_gradient = max_gradient self.subsample_molecules = int(subsample_molecules) super().__init__(root, transform, pre_transform, pre_filter) @@ -180,6 +182,11 @@ def sample_iter(self, mol_ids=False): fq = pt.tensor(mol["formal_charges"], dtype=pt.long) q = fq.sum() + # Keep molecules with specific elements + if self.atomic_numbers: + if not set(z.numpy()).issubset(self.atomic_numbers): + continue + for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))): # Skip samples with large forces @@ -220,6 +227,7 @@ def processed_file_names(self): def process(self): print("Arguments") + print(f" atomic_numbers: {self.atomic_numbers}") print(f" max_gradient: {self.max_gradient} eV/A") print(f" subsample_molecules: {self.subsample_molecules}\n") From 6f5478b8f95b484402423b0356364dce6d5514df Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Fri, 29 Sep 2023 14:44:24 +0200 Subject: [PATCH 2/2] Store selected atomic_numbers as a set --- torchmdnet/datasets/ace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 64197fbaa..347702735 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -21,11 +21,11 @@ def __init__( ): assert isinstance(paths, (str, list)) - arg_hash = f"{paths}{atomic_numbers}{max_gradient}{subsample_molecules}" + self.atomic_numbers = set([] if atomic_numbers is None else atomic_numbers) + arg_hash = f"{paths}{self.atomic_numbers}{max_gradient}{subsample_molecules}" arg_hash = hashlib.md5(arg_hash.encode()).hexdigest() self.name = f"{self.__class__.__name__}-{arg_hash}" self.paths = paths - self.atomic_numbers = atomic_numbers self.max_gradient = max_gradient self.subsample_molecules = int(subsample_molecules) super().__init__(root, transform, pre_transform, pre_filter)