-
Notifications
You must be signed in to change notification settings - Fork 39
fix condition half supercell #248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
We require contributors to sign our Contributor License Agreement (CLA), and we don't have record of your signature. In order for us to review and merge your code, please sign the CLA. |
WalkthroughThe forward path in torch_sim/models/orb.py changes the half_supercell activation condition from using the maximum determinant of batched cell matrices to the minimum determinant. No APIs or signatures change; all other logic and call paths remain the same. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Caller
participant OrbModel
participant ORBBackend as ORB Backend
Caller->>OrbModel: forward(state)
OrbModel->>OrbModel: dets = det(state.cell) // batched
alt half_supercell enabled AND min(dets) > 1000
Note right of OrbModel: Use half_supercell path
OrbModel->>ORBBackend: compute(..., half_supercell=True)
else
Note right of OrbModel: Use regular path
OrbModel->>ORBBackend: compute(..., half_supercell=False)
end
ORBBackend-->>OrbModel: results
OrbModel-->>Caller: outputs
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Pre-merge checks (5 passed)✅ Passed checks (5 passed)
Poem
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal). Please share your feedback with us on this Discord post. ✨ Finishing touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/models/orb.py
(1 hunks)
🔇 Additional comments (1)
torch_sim/models/orb.py (1)
389-393
: Fix gating: use absolute determinants when checking supercell volumeCurrent expression uses
torch.min(torch.det(state.cell)) > 1000
— a single negative determinant will make the min negative and the gate false even if all volumes' magnitudes exceed 1000. Replace depending on intended semantics:
- Smallest magnitude must exceed threshold:
half_supercell = (torch.det(state.cell).abs().min() > 1000 if self._half_supercell is None else self._half_supercell)
- Every cell must exceed threshold:
half_supercell = (torch.all(torch.abs(torch.det(state.cell)) > 1000) if self._half_supercell is None else self._half_supercell)
Location: torch_sim/models/orb.py (around lines 389–393).
half_supercell = ( | ||
torch.max(torch.det(state.cell)) > 1000 | ||
torch.min(torch.det(state.cell)) > 1000 | ||
if self._half_supercell is None | ||
else self._half_supercell | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Good fix switching max→min; also ensure true volume and a Python bool.
To match the upstream “cell volume > 1000” rule per crystal, take abs(det) (orientation-invariant) and cast to a Python bool to avoid passing a device tensor boolean downstream.
- half_supercell = (
- torch.min(torch.det(state.cell)) > 1000
- if self._half_supercell is None
- else self._half_supercell
- )
+ if self._half_supercell is None:
+ with torch.no_grad():
+ vols = torch.abs(torch.linalg.det(state.cell))
+ # works for both batched [B,3,3] and single [3,3]
+ half_supercell = (vols.min() > 1000).item()
+ else:
+ half_supercell = self._half_supercell
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
half_supercell = ( | |
torch.max(torch.det(state.cell)) > 1000 | |
torch.min(torch.det(state.cell)) > 1000 | |
if self._half_supercell is None | |
else self._half_supercell | |
) | |
if self._half_supercell is None: | |
with torch.no_grad(): | |
vols = torch.abs(torch.linalg.det(state.cell)) | |
# works for both batched [B,3,3] and single [3,3] | |
half_supercell = (vols.min() > 1000).item() | |
else: | |
half_supercell = self._half_supercell |
🤖 Prompt for AI Agents
In torch_sim/models/orb.py around lines 389–393, the check uses
torch.min(torch.det(state.cell)) > 1000 but should use the orientation-invariant
absolute cell volume and return a plain Python bool; compute volumes =
torch.abs(torch.det(state.cell)), evaluate the condition torch.min(volumes) >
1000, and convert that scalar tensor to a Python bool (e.g., using .item() or
bool(...)) when self._half_supercell is None so you don't propagate a device
tensor boolean downstream; keep using self._half_supercell when it is not None.
Thanks for the catch! It would also be good to add a sufficiently big structure as a test case and test against this criteria for orb. |
This Fix #247
Summary by CodeRabbit