Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lambench/metrics/downstream_tasks_metrics.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ elastic:
metrics: [MAE_G_VRH, MAE_K_VRH]
penalty: success_rate
dummy: {"MAE_G_VRH": 67.5431, "MAE_K_VRH": 136.2597}
vacancy:
domain: Inorganic Materials
metrics: [MAE]
dummy: {"MAE": 4.381}
1 change: 1 addition & 0 deletions lambench/metrics/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel):
"neb",
"wiggle150",
"elastic",
"vacancy",
]:
applicability_results[record.task_name] = record.metrics
return applicability_results
Expand Down
5 changes: 5 additions & 0 deletions lambench/models/ase_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ def evaluate(
fmax = task.calculator_params.get("fmax", 1e-3)
max_steps = task.calculator_params.get("max_steps", 500)
return {"metrics": run_inference(self, task.test_data, fmax, max_steps)}
elif task.task_name == "vacancy":
from lambench.tasks.calculator.vacancy.vacancy import run_inference

assert task.test_data is not None
return {"metrics": run_inference(self, task.test_data)}
else:
raise NotImplementedError(f"Task {task.task_name} is not implemented.")

Expand Down
3 changes: 3 additions & 0 deletions lambench/tasks/calculator/calculator_tasks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ elastic:
calculator_params:
fmax: 0.001
max_steps: 500
vacancy:
test_data: /bohr/lambench-vacancy-a2xo/v1
calculator_params: null
79 changes: 79 additions & 0 deletions lambench/tasks/calculator/vacancy/vacancy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
The test data is retrieved from:
Chem. Mater. 2023, 35, 24, 10619–10634

https://pubs.acs.org/doi/10.1021/acs.chemmater.3c02251

Only 1813 structure pairs are used.

"""

from ase.io import read
import numpy as np
from ase import Atoms
from tqdm import tqdm
from pathlib import Path

from sklearn.metrics import root_mean_squared_error, mean_absolute_error

from lambench.models.ase_models import ASEModel
import logging


def get_oxygen_reference_energy(calc) -> float:
vacuum_size = 30 # Ångströms: Large cell size to ensure vacuum separation
o_o_bond_length = 1.23 # Ångströms: Experimental O-O bond length for O2
cell_vector = vacuum_size
cell = [cell_vector, cell_vector, cell_vector]
center = cell_vector / 2

positions = [
(center, center, center - o_o_bond_length / 2),
(center, center, center + o_o_bond_length / 2),
]

molecular_oxygen = Atoms("O2", positions=positions, cell=cell, pbc=True)
molecular_oxygen.calc = calc
return molecular_oxygen.get_potential_energy() / 2


def run_inference(
model: ASEModel,
test_data: Path,
) -> dict[str, float]:
pristine_structures = read(test_data / "vacancy_pristine_structures.traj", ":")
defect_structures = read(test_data / "vacancy_defect_structures.traj", ":")
labels = np.load(test_data / "vacancy_evf_label.npy")

evf_lab = []
evf_pred = []
calc = model.calc

# Calculate reference energy for oxygen atom
E_o = get_oxygen_reference_energy(calc)

for pristine, defect, label in tqdm(
zip(pristine_structures, defect_structures, labels)
):
natoms_pri = len(pristine)
natoms_def = len(defect)

n_oxygen = natoms_pri - natoms_def

pristine.calc = calc
defect.calc = calc
try:
final = defect.get_potential_energy()
initial = pristine.get_potential_energy()

e_vf = final + n_oxygen * E_o - initial
evf_lab.append(label)
evf_pred.append(e_vf)

except Exception as e:
logging.error(f"Error occurred while processing structures: {e}")

return {
"MAE": mean_absolute_error(evf_lab, evf_pred), # eV
"RMSE": root_mean_squared_error(evf_lab, evf_pred), # eV
}