|
1 | 1 | import logging |
2 | 2 | from typing import cast |
3 | 3 |
|
| 4 | +from pytensor.graph import FunctionGraph, Apply |
4 | 5 | from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter |
5 | 6 | from pytensor.tensor.basic import TensorVariable, diagonal |
6 | 7 | from pytensor.tensor.blas import Dot22 |
@@ -317,37 +318,32 @@ def local_log_prod_sqr(fgraph, node): |
317 | 318 | # returns the sign of the prod multiplication. |
318 | 319 |
|
319 | 320 |
|
320 | | -def local_inv_kron_to_kron_inv(fgraph, node): |
321 | | - # check if we have a kron |
322 | | - # check if parent node is an inv |
323 | | - # if yes, replace with kron(inv, inv) |
324 | | - |
325 | | - pass |
326 | | - |
327 | | - |
328 | | -def local_chol_kron_to_kron_chol(fgraph, node): |
329 | | - # check if we have a kron |
330 | | - # check if parent node is a cholesky |
331 | | - # if yes, replace with kron(cholesky, cholesky) |
332 | | - |
333 | | - pass |
334 | | - |
335 | | - |
336 | 321 | @register_specialize |
337 | 322 | @node_rewriter([Blockwise]) |
338 | | -def local_lift_through_linalg(fgraph, node): |
| 323 | +def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply): |
339 | 324 | """ |
340 | | - Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)]) |
| 325 | + Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops |
| 326 | + that join matrices (KroneckerProduct, BlockDiagonal). |
| 327 | +
|
| 328 | + This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix |
| 329 | + operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker |
| 330 | + product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This |
| 331 | + reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component |
| 332 | + matrices. |
341 | 333 |
|
342 | 334 | Parameters |
343 | 335 | ---------- |
344 | | - fgraph |
345 | | - node |
| 336 | + fgraph: FunctionGraph |
| 337 | + Function graph being optimized |
| 338 | + node: Apply |
| 339 | + Node of the function graph to be optimized |
346 | 340 |
|
347 | 341 | Returns |
348 | 342 | ------- |
349 | | -
|
| 343 | + res: list of Variable, optional |
| 344 | + List of optimized variables, or None if no optimization was performed |
350 | 345 | """ |
| 346 | + |
351 | 347 | # TODO: Simplify this if we end up Blockwising KroneckerProduct |
352 | 348 | if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): |
353 | 349 | y = node.inputs[0] |
|
0 commit comments