Skip to content

Commit a7e0ea6

Browse files
Improve docstring
1 parent aac3305 commit a7e0ea6

File tree

3 files changed

+15
-24
lines changed

3 files changed

+15
-24
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,8 +1017,6 @@ class KroneckerProduct(OpFromGraph):
10171017
Wrapper Op for Kronecker graphs
10181018
"""
10191019

1020-
...
1021-
10221020

10231021
def kron(a, b):
10241022
"""Kronecker product.

pytensor/tensor/rewriting/linalg.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -322,37 +322,32 @@ def local_log_prod_sqr(fgraph, node):
322322
# returns the sign of the prod multiplication.
323323

324324

325-
def local_inv_kron_to_kron_inv(fgraph, node):
326-
# check if we have a kron
327-
# check if parent node is an inv
328-
# if yes, replace with kron(inv, inv)
329-
330-
pass
331-
332-
333-
def local_chol_kron_to_kron_chol(fgraph, node):
334-
# check if we have a kron
335-
# check if parent node is a cholesky
336-
# if yes, replace with kron(cholesky, cholesky)
337-
338-
pass
339-
340-
341325
@register_specialize
342326
@node_rewriter([Blockwise])
343327
def local_lift_through_linalg(fgraph, node):
344328
"""
345-
Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)])
329+
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
330+
that join matrices (KroneckerProduct, BlockDiagonal).
331+
332+
This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
333+
operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker
334+
product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
335+
reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
336+
matrices.
346337
347338
Parameters
348339
----------
349-
fgraph
350-
node
340+
fgraph: FunctionGraph
341+
Function graph being optimized
342+
node: Apply
343+
Node of the function graph to be optimized
351344
352345
Returns
353346
-------
354-
347+
res: Optional[list[Variable]]
348+
List of optimized variables, or None if no optimization was performed
355349
"""
350+
356351
# TODO: Simplify this if we end up Blockwising KroneckerProduct
357352
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
358353
y = node.inputs[0]

tests/tensor/rewriting/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,4 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
386386
]
387387
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
388388

389-
f2(*test_vals)
390-
391389
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)

0 commit comments

Comments
 (0)