diff --git a/lambench/metrics/downstream_tasks_metrics.yml b/lambench/metrics/downstream_tasks_metrics.yml index 2d8ecc09..e31d8a7b 100644 --- a/lambench/metrics/downstream_tasks_metrics.yml +++ b/lambench/metrics/downstream_tasks_metrics.yml @@ -20,3 +20,7 @@ elastic: metrics: [MAE_G_VRH, MAE_K_VRH] penalty: success_rate dummy: {"MAE_G_VRH": 67.5431, "MAE_K_VRH": 136.2597} +vacancy: + domain: Inorganic Materials + metrics: [MAE] + dummy: {"MAE": 4.381} diff --git a/lambench/metrics/post_process.py b/lambench/metrics/post_process.py index 75c832e6..80cb9cab 100644 --- a/lambench/metrics/post_process.py +++ b/lambench/metrics/post_process.py @@ -117,6 +117,7 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel): "neb", "wiggle150", "elastic", + "vacancy", ]: applicability_results[record.task_name] = record.metrics return applicability_results diff --git a/lambench/models/ase_models.py b/lambench/models/ase_models.py index 15b1b765..43e898a7 100644 --- a/lambench/models/ase_models.py +++ b/lambench/models/ase_models.py @@ -258,6 +258,11 @@ def evaluate( fmax = task.calculator_params.get("fmax", 1e-3) max_steps = task.calculator_params.get("max_steps", 500) return {"metrics": run_inference(self, task.test_data, fmax, max_steps)} + elif task.task_name == "vacancy": + from lambench.tasks.calculator.vacancy.vacancy import run_inference + + assert task.test_data is not None + return {"metrics": run_inference(self, task.test_data)} else: raise NotImplementedError(f"Task {task.task_name} is not implemented.") diff --git a/lambench/tasks/calculator/calculator_tasks.yml b/lambench/tasks/calculator/calculator_tasks.yml index d9869abd..ff33c49c 100644 --- a/lambench/tasks/calculator/calculator_tasks.yml +++ b/lambench/tasks/calculator/calculator_tasks.yml @@ -28,3 +28,6 @@ elastic: calculator_params: fmax: 0.001 max_steps: 500 +vacancy: + test_data: /bohr/lambench-vacancy-a2xo/v1 + calculator_params: null diff --git a/lambench/tasks/calculator/vacancy/vacancy.py b/lambench/tasks/calculator/vacancy/vacancy.py new file mode 100644 index 00000000..5edcdd97 --- /dev/null +++ b/lambench/tasks/calculator/vacancy/vacancy.py @@ -0,0 +1,79 @@ +""" +The test data is retrieved from: +Chem. Mater. 2023, 35, 24, 10619–10634 + +https://pubs.acs.org/doi/10.1021/acs.chemmater.3c02251 + +Only 1813 structure pairs are used. + +""" + +from ase.io import read +import numpy as np +from ase import Atoms +from tqdm import tqdm +from pathlib import Path + +from sklearn.metrics import root_mean_squared_error, mean_absolute_error + +from lambench.models.ase_models import ASEModel +import logging + + +def get_oxygen_reference_energy(calc) -> float: + vacuum_size = 30 # Ångströms: Large cell size to ensure vacuum separation + o_o_bond_length = 1.23 # Ångströms: Experimental O-O bond length for O2 + cell_vector = vacuum_size + cell = [cell_vector, cell_vector, cell_vector] + center = cell_vector / 2 + + positions = [ + (center, center, center - o_o_bond_length / 2), + (center, center, center + o_o_bond_length / 2), + ] + + molecular_oxygen = Atoms("O2", positions=positions, cell=cell, pbc=True) + molecular_oxygen.calc = calc + return molecular_oxygen.get_potential_energy() / 2 + + +def run_inference( + model: ASEModel, + test_data: Path, +) -> dict[str, float]: + pristine_structures = read(test_data / "vacancy_pristine_structures.traj", ":") + defect_structures = read(test_data / "vacancy_defect_structures.traj", ":") + labels = np.load(test_data / "vacancy_evf_label.npy") + + evf_lab = [] + evf_pred = [] + calc = model.calc + + # Calculate reference energy for oxygen atom + E_o = get_oxygen_reference_energy(calc) + + for pristine, defect, label in tqdm( + zip(pristine_structures, defect_structures, labels) + ): + natoms_pri = len(pristine) + natoms_def = len(defect) + + n_oxygen = natoms_pri - natoms_def + + pristine.calc = calc + defect.calc = calc + try: + final = defect.get_potential_energy() + initial = pristine.get_potential_energy() + + e_vf = final + n_oxygen * E_o - initial + evf_lab.append(label) + evf_pred.append(e_vf) + + except Exception as e: + logging.error(f"Error occurred while processing structures: {e}") + + return { + "MAE": mean_absolute_error(evf_lab, evf_pred), # eV + "RMSE": root_mean_squared_error(evf_lab, evf_pred), # eV + }