11import logging
2+ from collections .abc import Callable
23from typing import cast
34
4- from pytensor .graph import FunctionGraph , Apply
5+ from pytensor import Variable
6+ from pytensor .graph import Apply , FunctionGraph
57from pytensor .graph .rewriting .basic import copy_stack_trace , node_rewriter
68from pytensor .tensor .basic import TensorVariable , diagonal , swapaxes
79from pytensor .tensor .blas import Dot22
@@ -325,7 +327,9 @@ def local_log_prod_sqr(fgraph, node):
325327
326328@register_specialize
327329@node_rewriter ([Blockwise ])
328- def local_lift_through_linalg (fgraph : FunctionGraph , node : Apply ):
330+ def local_lift_through_linalg (
331+ fgraph : FunctionGraph , node : Apply
332+ ) -> list [Variable ] | None :
329333 """
330334 Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
331335 that join matrices (KroneckerProduct, BlockDiagonal).
@@ -345,7 +349,7 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
345349
346350 Returns
347351 -------
348- res: list of Variable, optional
352+ list of Variable, optional
349353 List of optimized variables, or None if no optimization was performed
350354 """
351355
@@ -362,15 +366,15 @@ def local_lift_through_linalg(fgraph: FunctionGraph, node: Apply):
362366 input_matrices = y .owner .inputs
363367
364368 if isinstance (outer_op .core_op , MatrixInverse ):
365- outer_f = inv
369+ outer_f = cast ( Callable , inv )
366370 elif isinstance (outer_op .core_op , Cholesky ):
367- outer_f = cholesky
371+ outer_f = cast ( Callable , cholesky )
368372 elif isinstance (outer_op .core_op , MatrixPinv ):
369- outer_f = pinv
373+ outer_f = cast ( Callable , pinv )
370374 else :
371375 raise NotImplementedError
372376
373- inner_matrices = [outer_f (m ) for m in input_matrices ]
377+ inner_matrices = [cast ( TensorVariable , outer_f (m ) ) for m in input_matrices ]
374378
375379 if isinstance (y .owner .op , KroneckerProduct ):
376380 return [kron (* inner_matrices )]
0 commit comments