Skip to content

Conversation

thomasloux
Copy link

@thomasloux thomasloux commented Sep 12, 2025

This Fix #247

Summary by CodeRabbit

  • Bug Fixes
    • Corrected cell size handling to ensure consistent behavior across multi-cell configurations, especially for large or heterogeneous cells.
    • Improves stability and accuracy in simulations that trigger half-supercell conditions, reducing edge-case inconsistencies.
    • No changes to the public API or user-facing configuration are required.

Copy link

cla-bot bot commented Sep 12, 2025

We require contributors to sign our Contributor License Agreement (CLA), and we don't have record of your signature. In order for us to review and merge your code, please sign the CLA.

Copy link

coderabbitai bot commented Sep 12, 2025

Walkthrough

The forward path in torch_sim/models/orb.py changes the half_supercell activation condition from using the maximum determinant of batched cell matrices to the minimum determinant. No APIs or signatures change; all other logic and call paths remain the same.

Changes

Cohort / File(s) Summary
Half-supercell gating fix
`torch_sim/models/orb.py`
In OrbModel.forward, gating condition changed from torch.max(torch.det(state.cell)) > 1000 to torch.min(torch.det(state.cell)) > 1000, maintaining self._half_supercell check. No other edits.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant OrbModel
  participant ORBBackend as ORB Backend

  Caller->>OrbModel: forward(state)
  OrbModel->>OrbModel: dets = det(state.cell)  // batched
  alt half_supercell enabled AND min(dets) > 1000
    Note right of OrbModel: Use half_supercell path
    OrbModel->>ORBBackend: compute(..., half_supercell=True)
  else
    Note right of OrbModel: Use regular path
    OrbModel->>ORBBackend: compute(..., half_supercell=False)
  end
  ORBBackend-->>OrbModel: results
  OrbModel-->>Caller: outputs
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks (5 passed)

✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "fix condition half supercell" succinctly and accurately names the primary change—fixing the half_supercell condition—and directly relates to the modifications in the changeset (switching the gating check); it is concise and focused enough for a teammate to understand the main intent.
Linked Issues Check ✅ Passed The PR changes the condition from torch.max(torch.det(state.cell)) > 1000 to torch.min(torch.det(state.cell)) > 1000, which enforces that every crystal in a batched input has volume > 1000 and thus directly satisfies the primary coding objective described in [#247]; the raw summary indicates no other logic or signature changes, so the linked-issue requirement is met.
Out of Scope Changes Check ✅ Passed The provided raw summary shows only the single conditional change inside OrbModel.forward (max → min determinant) with no other file, API, or unrelated logic modifications, so there are no apparent out-of-scope changes in this PR.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

Poem

A bunny hops through crystal halls,
Checks every cell—no one too small.
Not max but min, the gate now true,
Half-supercells when all accrue.
With tidy paws and volumes tall,
The batch runs smooth—no bumps at all. 🐇✨

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.

✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 01782f8 and 166819a.

📒 Files selected for processing (1)
  • torch_sim/models/orb.py (1 hunks)
🔇 Additional comments (1)
torch_sim/models/orb.py (1)

389-393: Fix gating: use absolute determinants when checking supercell volume

Current expression uses torch.min(torch.det(state.cell)) > 1000 — a single negative determinant will make the min negative and the gate false even if all volumes' magnitudes exceed 1000. Replace depending on intended semantics:

  • Smallest magnitude must exceed threshold:
    half_supercell = (torch.det(state.cell).abs().min() > 1000 if self._half_supercell is None else self._half_supercell)
  • Every cell must exceed threshold:
    half_supercell = (torch.all(torch.abs(torch.det(state.cell)) > 1000) if self._half_supercell is None else self._half_supercell)

Location: torch_sim/models/orb.py (around lines 389–393).

Comment on lines 389 to 393
half_supercell = (
torch.max(torch.det(state.cell)) > 1000
torch.min(torch.det(state.cell)) > 1000
if self._half_supercell is None
else self._half_supercell
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Good fix switching max→min; also ensure true volume and a Python bool.

To match the upstream “cell volume > 1000” rule per crystal, take abs(det) (orientation-invariant) and cast to a Python bool to avoid passing a device tensor boolean downstream.

-        half_supercell = (
-            torch.min(torch.det(state.cell)) > 1000
-            if self._half_supercell is None
-            else self._half_supercell
-        )
+        if self._half_supercell is None:
+            with torch.no_grad():
+                vols = torch.abs(torch.linalg.det(state.cell))
+                # works for both batched [B,3,3] and single [3,3]
+                half_supercell = (vols.min() > 1000).item()
+        else:
+            half_supercell = self._half_supercell
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
half_supercell = (
torch.max(torch.det(state.cell)) > 1000
torch.min(torch.det(state.cell)) > 1000
if self._half_supercell is None
else self._half_supercell
)
if self._half_supercell is None:
with torch.no_grad():
vols = torch.abs(torch.linalg.det(state.cell))
# works for both batched [B,3,3] and single [3,3]
half_supercell = (vols.min() > 1000).item()
else:
half_supercell = self._half_supercell
🤖 Prompt for AI Agents
In torch_sim/models/orb.py around lines 389–393, the check uses
torch.min(torch.det(state.cell)) > 1000 but should use the orientation-invariant
absolute cell volume and return a plain Python bool; compute volumes =
torch.abs(torch.det(state.cell)), evaluate the condition torch.min(volumes) >
1000, and convert that scalar tensor to a Python bool (e.g., using .item() or
bool(...)) when self._half_supercell is None so you don't propagate a device
tensor boolean downstream; keep using self._half_supercell when it is not None.

@CompRhys
Copy link
Collaborator

Thanks for the catch! It would also be good to add a sufficiently big structure as a test case and test against this criteria for orb.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Half supercell error when using OrbModel with batched inputs
2 participants