11import logging
2+ from collections .abc import Callable
23from typing import cast
34
4- from pytensor .graph import FunctionGraph , Apply
5+ from pytensor import Variable
6+ from pytensor .graph import Apply , FunctionGraph
57from pytensor .graph .rewriting .basic import copy_stack_trace , node_rewriter
68from pytensor .tensor .basic import TensorVariable , diagonal
79from pytensor .tensor .blas import Dot22
@@ -320,7 +322,9 @@ def local_log_prod_sqr(fgraph, node):
320322
321323@register_specialize
322324@node_rewriter ([Blockwise ])
323- def local_lift_through_linalg (fgraph : FunctionGraph , node : Apply ):
325+ def local_lift_through_linalg (
326+ fgraph : FunctionGraph , node : Apply
327+ ) -> list [Variable ] | None :
324328 """
325329 Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
326330 that join matrices (KroneckerProduct, BlockDiagonal).
@@ -340,7 +344,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
340344
341345 Returns
342346 -------
343- res: list of Variable, optional
347+ list of Variable, optional
344348 List of optimized variables, or None if no optimization was performed
345349 """
346350
@@ -357,15 +361,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
357361 input_matrices = y .owner .inputs
358362
359363 if isinstance (outer_op .core_op , MatrixInverse ):
360- outer_f = inv
364+ outer_f = cast ( Callable , inv )
361365 elif isinstance (outer_op .core_op , Cholesky ):
362- outer_f = cholesky
366+ outer_f = cast ( Callable , cholesky )
363367 elif isinstance (outer_op .core_op , MatrixPinv ):
364- outer_f = pinv
368+ outer_f = cast ( Callable , pinv )
365369 else :
366370 raise NotImplementedError
367371
368- inner_matrices = [outer_f (m ) for m in input_matrices ]
372+ inner_matrices = [cast ( TensorVariable , outer_f (m ) ) for m in input_matrices ]
369373
370374 if isinstance (y .owner .op , KroneckerProduct ):
371375 return [kron (* inner_matrices )]
0 commit comments