@@ -322,37 +322,32 @@ def local_log_prod_sqr(fgraph, node):
322322 # returns the sign of the prod multiplication.
323323
324324
325- def local_inv_kron_to_kron_inv (fgraph , node ):
326- # check if we have a kron
327- # check if parent node is an inv
328- # if yes, replace with kron(inv, inv)
329-
330- pass
331-
332-
333- def local_chol_kron_to_kron_chol (fgraph , node ):
334- # check if we have a kron
335- # check if parent node is a cholesky
336- # if yes, replace with kron(cholesky, cholesky)
337-
338- pass
339-
340-
341325@register_specialize
342326@node_rewriter ([Blockwise ])
343327def local_lift_through_linalg (fgraph , node ):
344328 """
345- Rewrite a graph like Inv(BlockDiag([A, B, C])) to BlockDiag([Inv(A), Inv(B), Inv(C)])
329+ Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
330+ that join matrices (KroneckerProduct, BlockDiagonal).
331+
332+ This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix
333+ operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker
334+ product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This
335+ reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component
336+ matrices.
346337
347338 Parameters
348339 ----------
349- fgraph
350- node
340+ fgraph: FunctionGraph
341+ Function graph being optimized
342+ node: Apply
343+ Node of the function graph to be optimized
351344
352345 Returns
353346 -------
354-
347+ res: Optional[list[Variable]]
348+ List of optimized variables, or None if no optimization was performed
355349 """
350+
356351 # TODO: Simplify this if we end up Blockwising KroneckerProduct
357352 if isinstance (node .op .core_op , MatrixInverse | Cholesky | MatrixPinv ):
358353 y = node .inputs [0 ]
0 commit comments