Fairchem v2 support #211
Conversation
.github/workflows/test.yml
Outdated
```
fi
uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system
uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system
uv pip install "torch-scatter" -f https://data.pyg.org/whl/torch-2.6.0+cpu.html --system
```
you shouldn't need torch-scatter and torch-sparse anymore; torch needs to be >=2.6
Actionable comments posted: 2
🧹 Nitpick comments (3)
torch_sim/models/fairchem.py (1)

235-238: Variable name shadowing – batch overrides imported function

`batch = atomicdata_list_to_batch(...)` shadows the earlier imported `atomicdata_list_to_batch` symbol, hampering interactive debugging and potential reuse later in the method. Consider renaming the local variable (e.g. `batched_data`) to preserve clarity.
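A minimal, self-contained sketch of the suggested rename; the tiny stub below only stands in for fairchem's `atomicdata_list_to_batch`, and the surrounding method body is hypothetical:

```python
def atomicdata_list_to_batch(data_list: list) -> dict:
    """Stub standing in for the imported fairchem collate helper."""
    return {"n_systems": len(data_list), "items": data_list}


def forward(data_list: list) -> dict:
    # A distinct local name keeps the imported helper callable later
    # in the same method (e.g. for a second conversion while debugging).
    batched_data = atomicdata_list_to_batch(data_list)
    return batched_data


print(forward([{"numbers": [14, 14]}, {"numbers": [8]}]))
```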
tests/models/test_fairchem.py (1)

24-28: eqv2_uma_model_non_pbc duplicates PBC fixture without disabling PBC

Both fixtures instantiate the exact same model (`task_name="omat"`) without changing periodic boundary conditions, so the "non-PBC" variant is currently indistinguishable. Either drop the duplicate or set `task_name="umat"` (or whatever non-PBC task is appropriate).
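A sketch of what a genuinely distinct fixture could look like, assuming `"omol"` is the appropriate non-periodic (molecular) UMA task and that `MODEL_NAME` matches whatever checkpoint the existing PBC fixture uses; both names are assumptions, not taken from the PR:

```python
import pytest
import torch

from torch_sim.models.fairchem import FairChemModel

MODEL_NAME = "uma-s-1"  # hypothetical placeholder; reuse the PBC fixture's model name


@pytest.fixture
def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel:
    # Differs from the PBC fixture by task: "omol" targets gas-phase molecules,
    # so downstream tests can feed non-periodic systems without ambiguity.
    return FairChemModel(
        model=None,
        model_name=MODEL_NAME,
        task_name="omol",
        cpu=device.type == "cpu",
    )
```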
examples/scripts/1_Introduction/1.3_fairchem.py (1)

27-31: Example never requests stress but prints it

FairChemModel defaults to `compute_stress=False`, so `results.get("stress")` will always be `None`, making the stress-printing blocks dead code. Either set `compute_stress=True` here or drop the conditional prints.

```diff
 model = FairChemModel(
     model=None,
     model_name=MODEL_NAME,
+    compute_stress=True,
     task_name="omat",  # Open Materials task for crystalline systems
     cpu=False,
 )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
.github/workflows/test.yml (3 hunks)
examples/scripts/1_Introduction/1.3_fairchem.py (2 hunks)
tests/models/test_fairchem.py (2 hunks)
torch_sim/models/fairchem.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- .github/workflows/test.yml
🧰 Additional context used
🧬 Code Graph Analysis (3)
examples/scripts/1_Introduction/1.3_fairchem.py (2)
torch_sim/models/fairchem.py (4)
  FairChemModel (41-50), FairChemModel (60-255), device (171-173), dtype (166-168)
torch_sim/io.py (1)
  atoms_to_state (180-245)
tests/models/test_fairchem.py (2)
tests/models/conftest.py (1)
  make_validate_model_outputs_test (125-230)
torch_sim/models/fairchem.py (3)
  device (171-173), FairChemModel (41-50), FairChemModel (60-255)
torch_sim/models/fairchem.py (5)
torch_sim/models/interface.py (7)
  dtype (119-121), dtype (124-128), compute_stress (131-133), compute_stress (136-140), device (107-109), device (112-116), forward (165-199)
torch_sim/state.py (6)
  dtype (142-144), device (137-139), SimState (26-312), row_vector_cell (185-187), row_vector_cell (190-196), to (280-294)
torch_sim/models/mace.py (1)
  forward (238-368)
torch_sim/models/graphpes.py (1)
  forward (177-191)
torch_sim/models/mattersim.py (1)
  forward (113-156)
🪛 Pylint (3.3.7)
torch_sim/models/fairchem.py
[refactor] 87-87: Too many arguments (8/5)
(R0913)
[refactor] 175-175: Too many local variables (17/15)
(R0914)
🪛 GitHub Check: codecov/patch
torch_sim/models/fairchem.py
[warning] 131-131: torch_sim/models/fairchem.py#L131
Added line #L131 was not covered by tests
[warning] 144-144: torch_sim/models/fairchem.py#L144
Added line #L144 was not covered by tests
[warning] 163-163: torch_sim/models/fairchem.py#L163
Added line #L163 was not covered by tests
[warning] 249-249: torch_sim/models/fairchem.py#L249
Added line #L249 was not covered by tests
[warning] 251-253: torch_sim/models/fairchem.py#L251-L253
Added lines #L251 - L253 were not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms (31)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
torch_sim/models/fairchem.py
Outdated
```python
@property
def dtype(self) -> torch.dtype:
    """Return the data type used by the model."""
    return self._dtype

    Args:
        checkpoint_path (str): Path to the trained model checkpoint file
        checkpoint (dict | None): A pretrained checkpoint dictionary. If provided,
            this dictionary is used instead of loading from checkpoint_path.

    Notes:
        If loading fails, a message is printed but no exception is raised.
    """
    try:
        self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True)
    except NotImplementedError:
        print("Unable to load checkpoint!")

@property
def device(self) -> torch.device:
    """Return the device where the model is located."""
    return self._device
```
Public compute_stress / compute_forces attributes are missing — tests will crash

tests/models/conftest.py expects every ModelInterface implementation to expose public compute_stress and compute_forces flags that are both readable and writable. The current class only stores the private counterparts (_compute_stress, _compute_forces). pytest will raise AttributeError, causing the whole FairChem test suite to fail.
@@
def device(self) -> torch.device:
"""Return the device where the model is located."""
return self._device
+
+ # --------------------------------------------------------------------- #
+ # Public flags required by ModelInterface & test suite
+ # --------------------------------------------------------------------- #
+ @property
+ def compute_stress(self) -> bool:
+ return self._compute_stress
+
+ @compute_stress.setter
+ def compute_stress(self, value: bool) -> None:
+ self._compute_stress = bool(value)
+
+ @property
+ def compute_forces(self) -> bool:
+ return self._compute_forces
+
+ @compute_forces.setter
+ def compute_forces(self, value: bool) -> None:
+ self._compute_forces = bool(value)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
@property
def dtype(self) -> torch.dtype:
    """Return the data type used by the model."""
    return self._dtype

@property
def device(self) -> torch.device:
    """Return the device where the model is located."""
    return self._device

# --------------------------------------------------------------------- #
# Public flags required by ModelInterface & test suite
# --------------------------------------------------------------------- #
@property
def compute_stress(self) -> bool:
    return self._compute_stress

@compute_stress.setter
def compute_stress(self, value: bool) -> None:
    self._compute_stress = bool(value)

@property
def compute_forces(self) -> bool:
    return self._compute_forces

@compute_forces.setter
def compute_forces(self, value: bool) -> None:
    self._compute_forces = bool(value)
```
🤖 Prompt for AI Agents
In torch_sim/models/fairchem.py around lines 165 to 174, the class defines
private attributes _compute_stress and _compute_forces but lacks corresponding
public properties for them. To fix this, add public @property methods for
compute_stress and compute_forces that return the private attributes, and also
add setter methods to allow these properties to be writable. This will ensure
compatibility with tests/models/conftest.py which expects these flags to be
publicly accessible and mutable.
torch_sim/models/fairchem.py
Outdated
```python
    torch.float32: "float32",
    torch.float64: "float64",
}


class FairChemModel(torch.nn.Module, ModelInterface):
```
just a note that the task name is only specific to UMA models, might be safer to name this something like FairChemUMAModel for future-proofing (or make the task name optional if we don't want to specialize just for the UMA model here)
good catch, thanks for letting me know. i made it optional in 0db53df. i think it's best if FairChemModel can handle both UMA and earlier models
Actionable comments posted: 0
♻️ Duplicate comments (3)
torch_sim/models/fairchem.py (3)
80-80: Implementation addresses the concern about task_name specialization.

The task_name parameter is already optional (defaults to None), which allows the class to work with non-UMA models as suggested in the past review. The conditional conversion to UMATask only happens when a string is provided.

Also applies to: 126-127
144-152: ⚠️ Potential issue: Missing required properties will cause test failures.

The class still lacks the public compute_stress and compute_forces properties required by the ModelInterface contract and test suite.

🧰 Tools
🪛 GitHub Check: codecov/patch
[warning] 144-145: torch_sim/models/fairchem.py#L144-L145
Added lines #L144 - L145 were not covered by tests
[warning] 147-147: torch_sim/models/fairchem.py#L147
Added line #L147 was not covered by tests
[warning] 149-150: torch_sim/models/fairchem.py#L149-L150
Added lines #L149 - L150 were not covered by tests
[warning] 152-152: torch_sim/models/fairchem.py#L152
Added line #L152 was not covered by tests
174-176: ⚠️ Potential issue: Batch tensor dtype will cause a runtime error with torch.bincount.

The batch tensor must use torch.long (64-bit) instead of torch.int (32-bit) for compatibility with torch.bincount. Additionally, it should be created on the same device as state.positions.
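A small, runnable sketch of the suggested initialization, using a stand-in positions tensor rather than the real SimState:

```python
import torch


def default_system_idx(positions: torch.Tensor) -> torch.Tensor:
    # Single-system default: every atom belongs to system 0.
    # torch.long is the safe dtype for index-like tensors (the comment above
    # cites torch.bincount compatibility), and creating the tensor on
    # positions.device avoids a CPU/GPU mismatch downstream.
    return torch.zeros(positions.shape[0], dtype=torch.long, device=positions.device)


positions = torch.rand(8, 3)  # stand-in for state.positions
system_idx = default_system_idx(positions)
print(torch.bincount(system_idx))  # tensor([8])
```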
🧹 Nitpick comments (1)

tests/models/test_fairchem.py (1)

90-121: Consider adding negative test cases.

The test effectively covers valid task-system combinations. Consider adding a test case that verifies appropriate error handling when incompatible system types are used with specific tasks (e.g., using a molecule with the "omat" task).
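One possible shape for such a negative test, assuming the PBC model fixture is named `eqv2_uma_model` and that an incompatible system raises rather than silently mis-predicting; the exception tuple is a guess, and if the model happily accepts the input the test would instead need to assert on whatever warning or fallback behavior the implementation chooses:

```python
import pytest
import torch
from ase.build import molecule

from torch_sim.io import atoms_to_state


def test_molecule_with_omat_task_errors(eqv2_uma_model, device: torch.device) -> None:
    # "omat" targets crystalline materials, so feeding a lone gas-phase
    # molecule should fail loudly instead of returning a meaningless energy.
    water = molecule("H2O")
    state = atoms_to_state([water], device=device, dtype=torch.float32)
    with pytest.raises((ValueError, RuntimeError, KeyError)):
        eqv2_uma_model(state)
```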
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/models/test_fairchem.py (1 hunks)
torch_sim/models/fairchem.py (6 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/models/test_fairchem.py (3)
tests/models/conftest.py (1)
  make_validate_model_outputs_test (125-230)
torch_sim/models/fairchem.py (4)
  FairChemModel (31-40), FairChemModel (50-230), device (150-152), dtype (145-147)
torch_sim/io.py (1)
  atoms_to_state (180-245)
🪛 GitHub Check: codecov/patch
torch_sim/models/fairchem.py
[warning] 23-26: torch_sim/models/fairchem.py#L23-L26
Added lines #L23 - L26 were not covered by tests
[warning] 71-71: torch_sim/models/fairchem.py#L71
Added line #L71 was not covered by tests
[warning] 109-110: torch_sim/models/fairchem.py#L109-L110
Added lines #L109 - L110 were not covered by tests
[warning] 120-120: torch_sim/models/fairchem.py#L120
Added line #L120 was not covered by tests
[warning] 122-123: torch_sim/models/fairchem.py#L122-L123
Added lines #L122 - L123 were not covered by tests
[warning] 126-127: torch_sim/models/fairchem.py#L126-L127
Added lines #L126 - L127 were not covered by tests
[warning] 130-132: torch_sim/models/fairchem.py#L130-L132
Added lines #L130 - L132 were not covered by tests
[warning] 135-135: torch_sim/models/fairchem.py#L135
Added line #L135 was not covered by tests
[warning] 140-142: torch_sim/models/fairchem.py#L140-L142
Added lines #L140 - L142 were not covered by tests
[warning] 144-145: torch_sim/models/fairchem.py#L144-L145
Added lines #L144 - L145 were not covered by tests
[warning] 147-147: torch_sim/models/fairchem.py#L147
Added line #L147 was not covered by tests
[warning] 149-150: torch_sim/models/fairchem.py#L149-L150
Added lines #L149 - L150 were not covered by tests
[warning] 152-152: torch_sim/models/fairchem.py#L152
Added line #L152 was not covered by tests
[warning] 178-178: torch_sim/models/fairchem.py#L178
Added line #L178 was not covered by tests
[warning] 181-181: torch_sim/models/fairchem.py#L181
Added line #L181 was not covered by tests
[warning] 187-189: torch_sim/models/fairchem.py#L187-L189
Added lines #L187 - L189 were not covered by tests
[warning] 196-196: torch_sim/models/fairchem.py#L196
Added line #L196 was not covered by tests
[warning] 204-205: torch_sim/models/fairchem.py#L204-L205
Added lines #L204 - L205 were not covered by tests
[warning] 207-208: torch_sim/models/fairchem.py#L207-L208
Added lines #L207 - L208 were not covered by tests
[warning] 211-212: torch_sim/models/fairchem.py#L211-L212
Added lines #L211 - L212 were not covered by tests
[warning] 215-215: torch_sim/models/fairchem.py#L215
Added line #L215 was not covered by tests
[warning] 219-220: torch_sim/models/fairchem.py#L219-L220
Added lines #L219 - L220 were not covered by tests
[warning] 223-224: torch_sim/models/fairchem.py#L223-L224
Added lines #L223 - L224 were not covered by tests
[warning] 226-228: torch_sim/models/fairchem.py#L226-L228
Added lines #L226 - L228 were not covered by tests
🪛 Pylint (3.3.7)
torch_sim/models/fairchem.py
[refactor] 71-71: Too many arguments (8/5)
(R0913)
[refactor] 154-154: Too many local variables (17/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (42)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.3_fairchem.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/models/test_mace.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: build-docs
🔇 Additional comments (4)
torch_sim/models/fairchem.py (1)
177-230: Forward method correctly implements the new predictor API.

The conversion pipeline from SimState → ASE Atoms → AtomicData → batch prediction → results is well-implemented. The conditional handling of task_name and stress tensor shape correction are particularly well done.

🧰 Tools
🪛 GitHub Check: codecov/patch
[warning] 178-178: torch_sim/models/fairchem.py#L178
Added line #L178 was not covered by tests
[warning] 181-181: torch_sim/models/fairchem.py#L181
Added line #L181 was not covered by tests
[warning] 187-189: torch_sim/models/fairchem.py#L187-L189
Added lines #L187 - L189 were not covered by tests
[warning] 196-196: torch_sim/models/fairchem.py#L196
Added line #L196 was not covered by tests
[warning] 204-205: torch_sim/models/fairchem.py#L204-L205
Added lines #L204 - L205 were not covered by tests
[warning] 207-208: torch_sim/models/fairchem.py#L207-L208
Added lines #L207 - L208 were not covered by tests
[warning] 211-212: torch_sim/models/fairchem.py#L211-L212
Added lines #L211 - L212 were not covered by tests
[warning] 215-215: torch_sim/models/fairchem.py#L215
Added line #L215 was not covered by tests
[warning] 219-220: torch_sim/models/fairchem.py#L219-L220
Added lines #L219 - L220 were not covered by tests
[warning] 223-224: torch_sim/models/fairchem.py#L223-L224
Added lines #L223 - L224 were not covered by tests
[warning] 226-228: torch_sim/models/fairchem.py#L226-L228
Added lines #L226 - L228 were not covered by tests

tests/models/test_fairchem.py (3)
41-89: Well-designed test for homogeneous batching scenarios.

The parameterized test effectively covers both material and molecular systems, properly handles molecular properties, and includes a good assertion that different systems produce different energies.

122-165: Comprehensive batch size testing with good edge case coverage.

The test effectively covers single systems, mixed sizes, and large batches (15 systems). The finiteness checks for both energies and forces are particularly valuable for catching numerical issues.

211-219: Good defensive programming with empty batch test.

Testing error handling for edge cases like empty batches is excellent practice. The test appropriately expects multiple possible exception types that could be raised in this scenario.
What's the status on this? Does this work?
Afaik there are a few people using this branch happily and so I think we can just wait until it can land properly rather than creating additional maintenance workload. What is the expected timeframe and source of delay?
Walkthrough

Refactors FairChem integration to UMA predictor-based API compatible with fairchem-core v2, updates example to use static UMA model name, migrates tests from OCPCalculator to UMA with HF-gated flows, and adjusts CI to install fairchem from PyPI with conditional HuggingFace login in tests/examples.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant TorchSim as TorchSim.FairChemModel
    participant UMA as fairchem-core Predictor
    participant ASE as ASE/AtomicData
    User->>TorchSim: predict(state, task_name)
    TorchSim->>ASE: convert SimState -> Atoms -> AtomicData (task_name)
    ASE-->>TorchSim: atomicdata_list
    TorchSim->>UMA: atomicdata_list_to_batch -> predict(batch)
    UMA-->>TorchSim: outputs {energy, forces[, stress]}
    TorchSim-->>User: dict(energy, forces[, stress])
```
```mermaid
sequenceDiagram
    participant CI as CI Job
    participant PyPI as PyPI
    participant HF as HuggingFace
    CI->>PyPI: pip install torch cpu, fairchem-core, huggingface_hub, pkg[test]
    alt HF_TOKEN set
        CI->>HF: huggingface-cli login
    end
    CI->>CI: run tests/examples (UMA)
```
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Assessment against linked issues

Assessment against linked issues: Out-of-scope changes

Possibly related PRs
…els/fairchem.py

- Modernize FairChemModel to use new FAIRChemCalculator.from_model_checkpoint() API
- Replace deprecated imports (load_config, update_config, model_registry)
- Simplify model loading with direct pretrained model name
- Add required task_name parameter for model initialization
- Remove unused imports and parameters (available_models, local_cache)
- Replace OCPCalculator with FAIRChemCalculator in test fixtures
- Remove unused model path fixtures and simplify model initialization
- Update test parameters for UMA model tolerances
using my personal HF token for now
- Parameterized tests for different UMA task names and system configs
- Tests for homogeneous and heterogeneous batching, ensuring correct energy and force outputs
- Stress tensor computation tests with conditional checks
- Test error handling for empty batches
Actionable comments posted: 8
🔭 Outside diff range comments (3)
.github/workflows/test.yml (1)
158-166: Mirror HF login guard in examples job.

Same issue as in tests: avoid failing CI when HF_TOKEN isn't present on forks or local runs.

```diff
-if [[ "${{ matrix.example }}" == *"fairchem"* ]]; then
-  uv pip install "huggingface_hub[cli]" --system
-  huggingface-cli login --token "$HF_TOKEN"
-fi
+if [[ "${{ matrix.example }}" == *"fairchem"* ]]; then
+  uv pip install "huggingface_hub[cli]" --system
+  if [ -n "$HF_TOKEN" ]; then
+    huggingface-cli login --token "$HF_TOKEN"
+  else
+    echo "Skipping HuggingFace login for example (no token provided); example may skip FairChem calls."
+  fi
+fi
```

torch_sim/models/fairchem.py (2)
154-170: StateDict conversion bug: wrong masses shape and potential duplicate kwarg.

- Passing masses=torch.ones_like(state["positions"]) creates an [N, 3] tensor, but masses must be length-N.
- If the incoming dict already contains "masses", this will cause multiple values for the argument.

```diff
-if isinstance(state, dict):
-    state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))
+if isinstance(state, dict):
+    sd = dict(state)
+    if "masses" not in sd:
+        pos = sd["positions"]
+        sd["masses"] = torch.ones(pos.shape[0], device=pos.device, dtype=pos.dtype)
+    state = ts.SimState(**sd)
```
171-176: system_idx initialization: ensure correct dtype/device.

torch.bincount expects int64 (long). Also keep the tensor on the same device as positions.

```diff
-if state.system_idx is None:
-    state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int)
+if state.system_idx is None:
+    state.system_idx = torch.zeros(
+        state.positions.shape[0], dtype=torch.long, device=state.positions.device
+    )
```
🧹 Nitpick comments (4)
torch_sim/models/fairchem.py (3)
121-129: Nit: Simplify argument validation messaging.

Minor clarity: error message mixes "checkpoint_path" vs "model". For consistency with the parameter names, refer to them uniformly.

```diff
-raise RuntimeError(
-    "model_name and checkpoint_path were both specified, "
-    "please use only one at a time"
-)
+raise RuntimeError(
+    "Both 'model' and 'model_name' were specified; please provide only one."
+)
```
130-136: Avoid side effects from setup_logging in a library context.

Calling setup_logging() in library code changes global logging configuration and can interfere with downstream consumers/tests. Prefer deferring to application code or making this opt-in.

```diff
-setup_logging()
+# Avoid global logging side-effects in a library. Let application code configure logging.
+# setup_logging()
```
195-202: Ensure PBC type is compatible with ASE.

If state.pbc is a torch tensor, convert to a Python bool or tuple of bools before passing to ASE.

```diff
+pbc_value = False
+if cell is not None:
+    if isinstance(state.pbc, torch.Tensor):
+        pbc_value = tuple(bool(b) for b in state.pbc.flatten().tolist())
+    else:
+        pbc_value = state.pbc
+
 atoms = Atoms(
     numbers=atomic_numbers,
     positions=positions,
     cell=cell,
-    pbc=state.pbc if cell is not None else False,
+    pbc=pbc_value,
 )
```

examples/scripts/1_Introduction/1.3_fairchem.py (1)
26-31: Remove unused variable.

```diff
-atomic_numbers = si_dc.get_atomic_numbers()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
.github/workflows/test.yml (3 hunks)
examples/scripts/1_Introduction/1.3_fairchem.py (2 hunks)
tests/models/test_fairchem.py (1 hunks)
torch_sim/models/fairchem.py (6 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
examples/scripts/1_Introduction/1.3_fairchem.py (2)
torch_sim/models/fairchem.py (4)
  FairChemModel (31-40), FairChemModel (50-230), device (150-152), dtype (145-147)
torch_sim/io.py (1)
  atoms_to_state (180-245)
torch_sim/models/fairchem.py (2)
torch_sim/state.py (7)
  device (180-182), SimState (29-493), row_vector_cell (286-288), row_vector_cell (291-297), batch (218-229), batch (232-243), to (381-395)
torch_sim/models/interface.py (5)
  compute_stress (102-104), compute_stress (107-111), device (78-80), device (83-87), forward (136-170)
tests/models/test_fairchem.py (5)
tests/models/conftest.py (2)
  make_validate_model_outputs_test (125-230), test_model_output_validation (135-226)
torch_sim/models/fairchem.py (4)
  FairChemModel (31-40), FairChemModel (50-230), device (150-152), dtype (145-147)
torch_sim/models/interface.py (6)
  device (78-80), device (83-87), dtype (90-92), dtype (95-99), compute_stress (102-104), compute_stress (107-111)
tests/conftest.py (2)
  device (24-25), dtype (29-30)
torch_sim/io.py (1)
  atoms_to_state (180-245)
🪛 Ruff (0.12.2)
torch_sim/models/fairchem.py
178-178: import should be at the top-level of a file (PLC0415)
🔇 Additional comments (6)
torch_sim/models/fairchem.py (1)
222-229: Stress handling: clarify shape normalization and availability.

Does the current reshape cover [B, 9] or [B, 3, 3]? If the backend returns Voigt (6) or flat (9) formats, consider normalizing more explicitly or documenting assumptions.
Would you confirm the predictor returns a [B, 3, 3] or [B, 9] tensor for stress? If it’s Voigt (B, 6), we should convert to a symmetric 3x3. I can add a robust converter if needed.
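If the backend does emit Voigt-ordered stress, a converter along these lines would do; the (xx, yy, zz, yz, xz, xy) ordering follows the common ASE/Voigt convention and is an assumption about the predictor's output, not something confirmed in this PR:

```python
import torch


def voigt_to_full_stress(voigt: torch.Tensor) -> torch.Tensor:
    """Convert [B, 6] Voigt stress (xx, yy, zz, yz, xz, xy) to [B, 3, 3] symmetric tensors."""
    xx, yy, zz, yz, xz, xy = voigt.unbind(dim=-1)
    row0 = torch.stack([xx, xy, xz], dim=-1)
    row1 = torch.stack([xy, yy, yz], dim=-1)
    row2 = torch.stack([xz, yz, zz], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)


# Example: a batch of two Voigt-form stresses becomes two symmetric 3x3 matrices.
voigt = torch.tensor([[1.0, 2.0, 3.0, 0.1, 0.2, 0.3],
                      [4.0, 5.0, 6.0, 0.4, 0.5, 0.6]])
print(voigt_to_full_stress(voigt).shape)  # torch.Size([2, 3, 3])
```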
tests/models/test_fairchem.py (5)
21-25: Fixture LGTM: UMA PBC model aligned with device.

Good use of the device fixture to set the cpu flag and the UMA task name.

30-39: Test gating and task parametrization look solid.

Skip-if on the HF token avoids CI flakiness, and checking the UMATask enum value via .value is robust.

169-183: Stress toggle coverage is appropriate.

Validates both presence and shape under compute_stress True/False. This will catch regressions in implemented_properties and forward's stress handling.

197-209: Device consistency check is valuable.

Confirms outputs reside on the model's device; pairs well with atoms_to_state's device parameter.

214-219: Empty batch error test is pragmatic.

Covers edge cases upstream and in forward. The exception tuple is reasonable given multiple potential fail points.
```python
from ase import Atoms
```
🛠️ Refactor suggestion
Lazy import ASE with helpful error and silence Ruff PLC0415.
Keep the lazy import but handle missing ASE explicitly and silence the linter.
- from ase import Atoms
+ try:
+ from ase import Atoms # noqa: PLC0415
+ except Exception as e: # pragma: no cover - environment-dependent
+ raise ImportError(
+ "ASE is required for FairChemModel.forward(). Please `pip install ase`."
+ ) from e
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    from ase import Atoms  # noqa: PLC0415
except Exception as e:  # pragma: no cover - environment-dependent
    raise ImportError(
        "ASE is required for FairChemModel.forward(). Please `pip install ase`."
    ) from e
```
🧰 Tools
🪛 Ruff (0.12.2)
178-178: import should be at the top-level of a file (PLC0415)
🤖 Prompt for AI Agents
In torch_sim/models/fairchem.py at around lines 178-179 the ASE import should
remain lazy but handle missing dependency and silence Ruff PLC0415: move the
import inside the function where Atoms is used (keep it lazy), wrap it in a
try/except ImportError and raise a clear error message that instructs the user
to pip install ase (or add it to requirements), and add a linter suppression
comment (e.g. "# noqa: PLC0415") on the import line so Ruff stops flagging the
dynamic import.
closes #193

Enables running the UMA class of models with torch-sim. Props to the fairchem team. Our internal implementation is now simpler. In particular, I dropped load_config, update_config, model_registry, available_models, and local_cache.
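For reference, a minimal usage sketch of the new API along the lines of the updated example script; the model name is a placeholder (the example pins its own UMA checkpoint name), and running it requires fairchem-core v2 plus Hugging Face access to the gated weights:

```python
import torch
from ase.build import bulk

from torch_sim.io import atoms_to_state
from torch_sim.models.fairchem import FairChemModel

MODEL_NAME = "uma-s-1"  # placeholder; use the UMA checkpoint name from the example

model = FairChemModel(
    model=None,
    model_name=MODEL_NAME,
    task_name="omat",      # Open Materials task for crystalline systems
    compute_stress=True,   # per the review suggestion, so stress is actually returned
    cpu=not torch.cuda.is_available(),
)

si_dc = bulk("Si", "diamond", a=5.43)
state = atoms_to_state([si_dc], device=model.device, dtype=model.dtype)

results = model(state)
print(results["energy"], results["forces"].shape)
print(results.get("stress"))  # None unless compute_stress=True was requested
```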
Summary by CodeRabbit
New Features
Refactor
Tests
Chores