Skip to content

Commit b6c9692

Browse files
Improve docstring
1 parent aac3305 commit b6c9692

File tree

3 files changed

+17
-25
lines changed

3 files changed

+17
-25
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,8 +1017,6 @@ class KroneckerProduct(OpFromGraph):
10171017
Wrapper Op for Kronecker graphs
10181018
"""
10191019

1020-
...
1021-
10221020

10231021
def kron(a, b):
10241022
"""Kronecker product.

pytensor/tensor/rewriting/linalg.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import cast
33

4+
from pytensor.graph import FunctionGraph, Apply
45
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
56
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
67
from pytensor.tensor.blas import Dot22
@@ -322,37 +323,32 @@ def local_log_prod_sqr(fgraph, node):
322323
# returns the sign of the prod multiplication.
323324

324325

325-
def local_inv_kron_to_kron_inv(fgraph, node):
326-
# check if we have a kron
327-
# check if parent node is an inv
328-
# if yes, replace with kron(inv, inv)
329-
330-
pass
331-
332-
333-
def local_chol_kron_to_kron_chol(fgraph, node):
334-
# check if we have a kron
335-
# check if parent node is a cholesky
336-
# if yes, replace with kron(cholesky, cholesky)
337-
338-
pass
339-
340-
341326
@register_specialize
342327
@node_rewriter([Blockwise])
343-
def local_lift_through_linalg(fgraph, node):
328+
def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
344329
"""
345-
Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)])
330+
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
331+
that join matrices (KroneckerProduct, BlockDiagonal).
332+
333+
This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
334+
operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker
335+
product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
336+
reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
337+
matrices.
346338
347339
Parameters
348340
----------
349-
fgraph
350-
node
341+
fgraph: FunctionGraph
342+
Function graph being optimized
343+
node: Apply
344+
Node of the function graph to be optimized
351345
352346
Returns
353347
-------
354-
348+
res: list of Variable, optional
349+
List of optimized variables, or None if no optimization was performed
355350
"""
351+
356352
# TODO: Simplify this if we end up Blockwising KroneckerProduct
357353
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
358354
y = node.inputs[0]

tests/tensor/rewriting/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,4 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
386386
]
387387
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
388388

389-
f2(*test_vals)
390-
391389
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)

0 commit comments

Comments
 (0)