@@ -1,22 +1,35 @@
 import logging
+from collections.abc import Callable
 from typing import cast
 
+from pytensor import Variable
+from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import TensorVariable, diagonal
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
-from pytensor.tensor.nlinalg import MatrixInverse, det
+from pytensor.tensor.nlinalg import (
+    KroneckerProduct,
+    MatrixInverse,
+    MatrixPinv,
+    det,
+    inv,
+    kron,
+    pinv,
+)
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
 from pytensor.tensor.slinalg import (
+    BlockDiagonal,
     Cholesky,
     Solve,
     SolveBase,
+    block_diag,
     cholesky,
     solve,
     solve_triangular,
@@ -305,3 +318,62 @@ def local_log_prod_sqr(fgraph, node): |
 
 # TODO: have a reduction like prod and sum that simply
 # returns the sign of the prod multiplication.
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def local_lift_through_linalg(
+    fgraph: FunctionGraph, node: Apply
+) -> list[Variable] | None:
+    """
+    Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
+    that join matrices (KroneckerProduct, BlockDiagonal).
+
+    This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
+    operations on component matrices instead of one large one. For example, when taking the inverse of a Kronecker
+    product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
+    reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3), where n and m are the dimensions of the component
+    matrices.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+
+    # TODO: Simplify this if we end up Blockwising KroneckerProduct
+    if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
+        y = node.inputs[0]
+        outer_op = node.op
+
+        if y.owner and (
+            isinstance(y.owner.op, Blockwise)
+            and isinstance(y.owner.op.core_op, BlockDiagonal)
+            or isinstance(y.owner.op, KroneckerProduct)
+        ):
+            input_matrices = y.owner.inputs
+
+            if isinstance(outer_op.core_op, MatrixInverse):
+                outer_f = cast(Callable, inv)
+            elif isinstance(outer_op.core_op, Cholesky):
+                outer_f = cast(Callable, cholesky)
+            elif isinstance(outer_op.core_op, MatrixPinv):
+                outer_f = cast(Callable, pinv)
+            else:
+                raise NotImplementedError
+
+            inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
+
+            if isinstance(y.owner.op, KroneckerProduct):
+                return [kron(*inner_matrices)]
+            elif isinstance(y.owner.op.core_op, BlockDiagonal):
+                return [block_diag(*inner_matrices)]
+            else:
+                raise NotImplementedError
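Not part of the commit, but for context: a minimal sketch of the identity this rewrite exploits and how to watch it fire, assuming a PyTensor build that includes the rewrite (all variable names below are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import inv, kron

# The identity being exploited: inv(kron(A, B)) == kron(inv(A), inv(B)),
# replacing one O((n*m)^3) inverse with two smaller ones, O(n^3 + m^3).
# (Analogous identities hold for Cholesky/pinv and for block_diag.)
rng = np.random.default_rng(0)
a, b = rng.normal(size=(3, 3)), rng.normal(size=(4, 4))
np.testing.assert_allclose(
    np.linalg.inv(np.kron(a, b)),
    np.kron(np.linalg.inv(a), np.linalg.inv(b)),
    rtol=1e-7, atol=1e-9,
)

# After compilation, the specialize-phase rewrite should have moved the
# inverse inside the Kronecker product; dprint shows the rewritten graph.
A = pt.matrix("A")
B = pt.matrix("B")
f = pytensor.function([A, B], inv(kron(A, B)))
pytensor.dprint(f)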