Skip to content

Commit eb5941e

Browse files
Add rewrite to lift linear algebra through certain linalg ops
1 parent 30b760f commit eb5941e

File tree

4 files changed

+121
-4
lines changed

4 files changed

+121
-4
lines changed

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import cast
88

99
import pytensor.tensor as pt
10-
from pytensor import function
10+
from pytensor.compile import function
1111
from pytensor.compile.function.pfunc import rebuild_collect_shared
1212
from pytensor.compile.mode import optdb
1313
from pytensor.compile.sharedvalue import SharedVariable

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
614614
615615
Returns
616616
-------
617-
U, V, D : matrices
617+
U, V, D : matrices
618618
619619
"""
620620
return Blockwise(SVD(full_matrices, compute_uv))(a)

pytensor/tensor/rewriting/linalg.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.elemwise import DimShuffle
99
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
10-
from pytensor.tensor.nlinalg import MatrixInverse, det
10+
from pytensor.tensor.nlinalg import MatrixInverse, MatrixPinv, det, inv, pinv
1111
from pytensor.tensor.rewriting.basic import (
1212
register_canonicalize,
1313
register_specialize,
1414
register_stabilize,
1515
)
1616
from pytensor.tensor.slinalg import (
17+
BlockDiagonal,
1718
Cholesky,
19+
KroneckerProduct,
1820
Solve,
1921
SolveBase,
22+
block_diag,
2023
cholesky,
24+
kron,
2125
solve,
2226
solve_triangular,
2327
)
@@ -310,3 +314,65 @@ def local_log_prod_sqr(fgraph, node):
310314

311315
# TODO: have a reduction like prod and sum that simply
312316
# returns the sign of the prod multiplication.
317+
318+
319+
def local_inv_kron_to_kron_inv(fgraph, node):
320+
# check if we have a kron
321+
# check if parent node is an inv
322+
# if yes, replace with kron(inv, inv)
323+
324+
pass
325+
326+
327+
def local_chol_kron_to_kron_chol(fgraph, node):
328+
# check if we have a kron
329+
# check if parent node is a cholesky
330+
# if yes, replace with kron(cholesky, cholesky)
331+
332+
pass
333+
334+
335+
@register_specialize
336+
@node_rewriter([Blockwise])
337+
def local_lift_through_linalg(fgraph, node):
338+
"""
339+
Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)])
340+
341+
Parameters
342+
----------
343+
fgraph
344+
node
345+
346+
Returns
347+
-------
348+
349+
"""
350+
# TODO: Simplify this if we end up Blockwising KroneckerProduct
351+
if isinstance(node.op.core_op, (MatrixInverse, Cholesky, MatrixPinv)):
352+
y = node.inputs[0]
353+
outer_op = node.op
354+
355+
if y.owner and (
356+
isinstance(y.owner.op, Blockwise)
357+
and isinstance(y.owner.op.core_op, BlockDiagonal)
358+
or isinstance(y.owner.op, KroneckerProduct)
359+
):
360+
input_matrices = y.owner.inputs
361+
362+
if isinstance(outer_op.core_op, MatrixInverse):
363+
outer_f = inv
364+
elif isinstance(outer_op.core_op, Cholesky):
365+
outer_f = cholesky
366+
elif isinstance(outer_op.core_op, MatrixPinv):
367+
outer_f = pinv
368+
else:
369+
raise NotImplementedError
370+
371+
inner_matrices = [outer_f(m) for m in input_matrices]
372+
373+
if isinstance(y.owner.op, KroneckerProduct):
374+
return [kron(*inner_matrices)]
375+
elif isinstance(y.owner.op.core_op, BlockDiagonal):
376+
return [block_diag(*inner_matrices)]
377+
else:
378+
raise NotImplementedError

tests/tensor/rewriting/test_linalg.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from pytensor.tensor.blockwise import Blockwise
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.math import _allclose, dot, matmul
17-
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
17+
from pytensor.tensor.nlinalg import Det, MatrixInverse, MatrixPinv, matrix_inverse
1818
from pytensor.tensor.rewriting.linalg import inv_as_solve
1919
from pytensor.tensor.slinalg import (
20+
BlockDiagonal,
2021
Cholesky,
22+
KroneckerProduct,
2123
Solve,
2224
SolveBase,
2325
SolveTriangular,
@@ -333,3 +335,52 @@ def test_invalid_batched_a(self):
333335
ref_fn(test_a, test_b),
334336
rtol=1e-7 if config.floatX == "float64" else 1e-5,
335337
)
338+
339+
340+
@pytest.mark.parametrize(
341+
"constructor", [pt.dmatrix, pt.tensor3], ids=["not_batched", "batched"]
342+
)
343+
@pytest.mark.parametrize(
344+
"f_op, f",
345+
[
346+
(MatrixInverse, pt.linalg.inv),
347+
(Cholesky, pt.linalg.cholesky),
348+
(MatrixPinv, pt.linalg.pinv),
349+
],
350+
ids=["inv", "cholesky", "pinv"],
351+
)
352+
@pytest.mark.parametrize(
353+
"g_op, g",
354+
[(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
355+
ids=["block_diag", "kron"],
356+
)
357+
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
358+
A, B = list(map(constructor, "ab"))
359+
X = f(g(A, B))
360+
361+
f1 = pytensor.function(
362+
[A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
363+
)
364+
f2 = pytensor.function(
365+
[A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
366+
)
367+
368+
all_apply_nodes = f1.maker.fgraph.apply_nodes
369+
f_ops = [
370+
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), f_op)
371+
]
372+
g_ops = [
373+
x for x in all_apply_nodes if isinstance(getattr(x.op, "core_op", x.op), g_op)
374+
]
375+
376+
assert len(f_ops) == 2
377+
assert len(g_ops) == 1
378+
379+
test_vals = [
380+
np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
381+
]
382+
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
383+
384+
f2(*test_vals)
385+
386+
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals))

0 commit comments

Comments
 (0)