|
1 | 1 | import logging |
2 | 2 | from typing import cast |
3 | 3 |
|
| 4 | +from pytensor.graph import FunctionGraph, Apply |
4 | 5 | from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter |
5 | 6 | from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes |
6 | 7 | from pytensor.tensor.blas import Dot22 |
@@ -322,37 +323,32 @@ def local_log_prod_sqr(fgraph, node): |
322 | 323 | # returns the sign of the prod multiplication. |
323 | 324 |
|
324 | 325 |
|
325 | | -def local_inv_kron_to_kron_inv(fgraph, node): |
326 | | - # check if we have a kron |
327 | | - # check if parent node is an inv |
328 | | - # if yes, replace with kron(inv, inv) |
329 | | - |
330 | | - pass |
331 | | - |
332 | | - |
333 | | -def local_chol_kron_to_kron_chol(fgraph, node): |
334 | | - # check if we have a kron |
335 | | - # check if parent node is a cholesky |
336 | | - # if yes, replace with kron(cholesky, cholesky) |
337 | | - |
338 | | - pass |
339 | | - |
340 | | - |
341 | 326 | @register_specialize |
342 | 327 | @node_rewriter([Blockwise]) |
343 | | -def local_lift_through_linalg(fgraph, node): |
| 328 | +def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply): |
344 | 329 | """ |
345 | | - Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)]) |
| 330 | + Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops |
| 331 | + that join matrices (KroneckerProduct, BlockDiagonal). |
| 332 | +
|
| 333 | + This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix |
| 334 | + operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker |
| 335 | + product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This |
| 336 | + reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component |
| 337 | + matrices. |
346 | 338 |
|
347 | 339 | Parameters |
348 | 340 | ---------- |
349 | | - fgraph |
350 | | - node |
| 341 | + fgraph: FunctionGraph |
| 342 | + Function graph being optimized |
| 343 | + node: Apply |
| 344 | + Node of the function graph to be optimized |
351 | 345 |
|
352 | 346 | Returns |
353 | 347 | ------- |
354 | | -
|
| 348 | + res: list of Variable, optional |
| 349 | + List of optimized variables, or None if no optimization was performed |
355 | 350 | """ |
| 351 | + |
356 | 352 | # TODO: Simplify this if we end up Blockwising KroneckerProduct |
357 | 353 | if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): |
358 | 354 | y = node.inputs[0] |
|
0 commit comments