Skip to content

Conversation

@Fyrebright
Copy link
Contributor

@Fyrebright Fyrebright commented Sep 2, 2025

Description

Directly use LAPACK functions in the perform method of the following classes, removing some checks from scipy.linalg (e.g. _datacopied and asarray):

  • CholeskySolve (potrs)
  • LUFactor (getrf)
  • SolveTriangular (trtrs)

Add coverage for empty case in each.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify): performance

📚 Documentation preview 📚: https://pytensor--1605.org.readthedocs.build/en/1605/

@codecov
Copy link

codecov bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 60.86957% with 18 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.62%. Comparing base (46f8227) to head (1d3e180).
⚠️ Report is 43 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/slinalg.py 60.86% 9 Missing and 9 partials ⚠️

❌ Your patch status has failed because the patch coverage (60.86%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1605      +/-   ##
==========================================
- Coverage   81.64%   81.62%   -0.03%     
==========================================
  Files         231      231              
  Lines       52952    52992      +40     
  Branches     9388     9404      +16     
==========================================
+ Hits        43235    43257      +22     
- Misses       7273     7282       +9     
- Partials     2444     2453       +9     
Files with missing lines Coverage Δ
pytensor/tensor/slinalg.py 91.71% <60.86%> (-1.70%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Fyrebright Fyrebright marked this pull request as ready for review September 2, 2025 16:24
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really great, thanks for taking this on!

I have one requested change (the one about shape checking) and the rest is related to what we should do if the algorithm fails. Please don't make those changes until the other devs weigh in, since it's a bit of an API break.

Comment on lines 395 to 398
if c.ndim != 2 or c.shape[0] != c.shape[1]:
raise ValueError("The factored matrix c is not square.")
if c.shape[1] != b.shape[0]:
raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to do shape checking in perform, that is handled by make_node

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not true, shapes may not be static

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not true, shapes may not be static

FWIW; we'll deprecate in-place modifications of the shape (also dtype and strides) modifications in numpy 2.4

Copy link
Member

@ricardoV94 ricardoV94 Sep 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static here means we don't know the shape until runtime, as in the following graph:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))
out = pt.exp(x)

fn = pytensor.function([x], out)
fn([1, 2, 3])
fn([1, 2, 3, 4])  # Allowed to call with different input lengths each time

Comment on lines +406 to +408
if info != 0:
raise ValueError(f"illegal value in {-info}th argument of internal potrs")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer if we returned a matrix of np.nan if info !=0 rather than erroring out. This is what jax does, and it makes it a lot more ergonomic to work with in iterative algorithms.

This might be out of scope for this PR; asking @ricardoV94 for a 2nd opinion

Comment on lines +724 to +733
if info < 0:
raise ValueError(
f"illegal value in {-info}th argument of internal getrf (lu_factor)"
)
if info > 0:
warnings.warn(
f"Diagonal number {info} is exactly zero. Singular matrix.",
LinAlgWarning,
stacklevel=2,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above

Comment on lines +937 to +942
if info > 0:
raise LinAlgError(
f"singular matrix: resolution failed at diagonal {info-1}"
)
elif info < 0:
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above

@Fyrebright
Copy link
Contributor Author

I have one requested change (the one about shape checking) and the rest is related to what we should do if the algorithm fails. Please don't make those changes until the other devs weigh in, since it's a bit of an API break.

Just wanted to check if there was any further thoughts on this?

@jessegrabowski
Copy link
Member

Just wanted to check if there was any further thoughts on this?

Thanks for poking us on this. I had a conversation with @ricardoV94 today about it, and asked him to weigh in here

@ricardoV94
Copy link
Member

ricardoV94 commented Sep 30, 2025

We should check shape in the perform unless the code fails gracefully if the shape is wrong. If you get wrong results silently or segfaults otherwise, that's no good.

Returning nan is fine by me, but perhaps better addressed in a separate PR where we do that for all linalg Ops.

@Fyrebright Fyrebright force-pushed the 1468-chosolve-lufact-triangularsolve branch from a536532 to cebc540 Compare September 30, 2025 14:48
@Fyrebright Fyrebright force-pushed the 1468-chosolve-lufact-triangularsolve branch from cebc540 to 334a44e Compare September 30, 2025 14:57
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

There was a failed test, but it didn't look like it was related to your changes (it was something about MinMax on integers). I'm rerunning it to see if it magically goes away, otherwise will need to look more closely.

if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
raise ValueError("array must not contain infs or NaNs")

if c.ndim != 2 or c.shape[0] != c.shape[1]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be precise, we don't need to check ndims, but I guess it's fine.

@Fyrebright
Copy link
Contributor Author

Fyrebright commented Oct 1, 2025

There was a failed test, but it didn't look like it was related to your changes (it was something about MinMax on integers). I'm rerunning it to see if it magically goes away, otherwise will need to look more closely.

I was unable to replicate it locally. It should have xfailed (fast compile = 0) if I understand correctly.

@jessegrabowski
Copy link
Member

I was unable to replicate it locally. It should have xfailed (fast compile = 0) if I understand correctly.

Looks like it magically fixed itself :)

@jessegrabowski
Copy link
Member

@ricardoV94 any objections to merging this

@ricardoV94
Copy link
Member

Nope

@jessegrabowski jessegrabowski merged commit f48068a into pymc-devs:main Oct 12, 2025
63 of 64 checks passed
@jessegrabowski
Copy link
Member

@Fyrebright Thanks for work on this!

Copilot AI pushed a commit that referenced this pull request Oct 13, 2025
…#1605)

* Use lapack instead of `scipy_linalg.cho_solve`

* Use lapack instead of `scipy_linalg.lu_factor`

* Use lapack instead of `scipy_linalg.solve_triangular`

* Add empty test for lu_factor

* Tidy imports

* remove ndim check
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants