Use LAPACK functions for cho_solve, lu_factor, solve_triangular
#1605
Conversation
Codecov Report
❌ Your patch status has failed because the patch coverage (60.86%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #1605      +/-   ##
==========================================
- Coverage   81.64%   81.62%    -0.03%
==========================================
  Files         231      231
  Lines       52952    52992       +40
  Branches     9388     9404       +16
==========================================
+ Hits        43235    43257       +22
- Misses       7273     7282        +9
- Partials     2444     2453        +9
```
jessegrabowski left a comment:
This is really great, thanks for taking this on!
I have one requested change (the one about shape checking) and the rest is related to what we should do if the algorithm fails. Please don't make those changes until the other devs weigh in, since it's a bit of an API break.
pytensor/tensor/slinalg.py
Outdated
```python
if c.ndim != 2 or c.shape[0] != c.shape[1]:
    raise ValueError("The factored matrix c is not square.")
if c.shape[1] != b.shape[0]:
    raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
```
We don't need to do shape checking in perform, that is handled by make_node
not true, shapes may not be static
FWIW, we'll deprecate in-place modifications of the shape (also dtype and strides) in numpy 2.4
Non-static here means we don't know the shape until runtime, as in the following graph:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))
out = pt.exp(x)
fn = pytensor.function([x], out)
fn([1, 2, 3])
fn([1, 2, 3, 4])  # Allowed to call with different input lengths each time
```

```python
if info != 0:
    raise ValueError(f"illegal value in {-info}th argument of internal potrs")
```
I'd prefer if we returned a matrix of np.nan if info != 0 rather than erroring out. This is what JAX does, and it makes it a lot more ergonomic to work with in iterative algorithms.
This might be out of scope for this PR; asking @ricardoV94 for a 2nd opinion.
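A minimal sketch of that NaN-on-failure pattern, using SciPy's LAPACK wrappers directly (the variable names and the `dpotrf`/`dpotrs` pairing here are illustrative, not the PR's actual code):

```python
import numpy as np
from scipy.linalg.lapack import dpotrf, dpotrs

a = np.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
b = np.array([1.0, 2.0])

c, info = dpotrf(a, lower=1)        # Cholesky factor of a
x, info = dpotrs(c, b, lower=1)     # solve a @ x = b using the factor
if info != 0:
    # Instead of raising, propagate NaNs (the JAX-style behavior)
    x = np.full_like(b, np.nan)
```

In an iterative algorithm (e.g. a sampler proposing occasionally non-PSD matrices), the NaNs flow through and can be rejected downstream rather than aborting the whole loop.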
```python
if info < 0:
    raise ValueError(
        f"illegal value in {-info}th argument of internal getrf (lu_factor)"
    )
if info > 0:
    warnings.warn(
        f"Diagonal number {info} is exactly zero. Singular matrix.",
        LinAlgWarning,
        stacklevel=2,
    )
```
As above
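For context, the `info` conventions that the `getrf` branch above relies on can be observed directly with SciPy's LAPACK wrapper; a small sketch with illustrative matrices:

```python
import numpy as np
from scipy.linalg.lapack import dgetrf

# Nonsingular matrix: factorization succeeds, info == 0
lu, piv, info_ok = dgetrf(np.array([[2.0, 1.0], [4.0, 3.0]]))

# Singular matrix: info > 0 gives the (1-based) diagonal of U that is exactly zero
lu_s, piv_s, info_sing = dgetrf(np.array([[1.0, 2.0], [2.0, 4.0]]))
```

Note that `info > 0` is only a warning condition for `getrf` itself: the factorization is still returned, and the failure surfaces when the factors are later used to solve.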
```python
if info > 0:
    raise LinAlgError(
        f"singular matrix: resolution failed at diagonal {info-1}"
    )
elif info < 0:
    raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
```
As above
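Similarly, `trtrs` reports failures only through `info`; a minimal sketch with SciPy's wrapper (example values are illustrative):

```python
import numpy as np
from scipy.linalg.lapack import dtrtrs

a_lower = np.array([[2.0, 0.0], [1.0, 3.0]])  # lower-triangular system
b = np.array([2.0, 5.0])

x, info = dtrtrs(a_lower, b, lower=1)  # solves a_lower @ x = b; info == 0 on success
```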
Just wanted to check if there were any further thoughts on this?
Thanks for poking us on this. I had a conversation with @ricardoV94 today about it, and asked him to weigh in here.
We should check shapes in perform unless the code fails gracefully when the shape is wrong. If you get wrong results silently or segfaults otherwise, that's no good. Returning …
Force-pushed from a536532 to cebc540
Force-pushed from cebc540 to 334a44e
LGTM
There was a failed test, but it didn't look like it was related to your changes (it was something about MinMax on integers). I'm rerunning it to see if it magically goes away, otherwise will need to look more closely.
pytensor/tensor/slinalg.py
Outdated
```python
if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
    raise ValueError("array must not contain infs or NaNs")

if c.ndim != 2 or c.shape[0] != c.shape[1]:
```
To be precise, we don't need to check ndims, but I guess it's fine.
I was unable to replicate it locally. It should have xfailed (fast compile = 0) if I understand correctly.
Looks like it magically fixed itself :)
@ricardoV94 any objections to merging this?
Nope
@Fyrebright Thanks for your work on this!
…#1605)

* Use lapack instead of `scipy_linalg.cho_solve`
* Use lapack instead of `scipy_linalg.lu_factor`
* Use lapack instead of `scipy_linalg.solve_triangular`
* Add empty test for lu_factor
* Tidy imports
* remove ndim check
Description
Directly use LAPACK functions in the `perform` method of the following classes, removing some checks from `scipy.linalg` (e.g. `_datacopied` and `asarray`). Add coverage for the empty case in each.
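As a sketch of what "directly use LAPACK functions" means here (illustrative, not the PR's exact code): `scipy.linalg.get_lapack_funcs` selects the dtype-appropriate routine, e.g. `potrs` for `cho_solve`, which can then be called without SciPy's wrapper-level validation overhead:

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

c = np.array([[2.0, 0.0], [1.0, 1.0]])  # assume this is a lower Cholesky factor
b = np.ones(2)

# Pick the LAPACK routine matching the dtypes of c and b (dpotrs here)
potrs, = get_lapack_funcs(("potrs",), (c, b))
x, info = potrs(c, b, lower=1)  # solves (c @ c.T) @ x = b
```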
📚 Documentation preview 📚: https://pytensor--1605.org.readthedocs.build/en/1605/