4444from pytensor .graph .rewriting .basic import node_rewriter
4545from pytensor .scalar .basic import Ceil , Clip , Floor , RoundHalfToEven
4646from pytensor .scalar .basic import clip as scalar_clip
47- from pytensor .tensor .elemwise import Elemwise
47+ from pytensor .tensor .math import ceil , clip , floor , round_half_to_even
4848from pytensor .tensor .var import TensorConstant
4949
5050from pymc .logprob .abstract import (
@@ -67,7 +67,7 @@ class MeasurableClip(MeasurableElemwise):
6767measurable_clip = MeasurableClip (scalar_clip )
6868
6969
70- @node_rewriter (tracks = [Elemwise ])
70+ @node_rewriter (tracks = [clip ])
7171def find_measurable_clips (fgraph : FunctionGraph , node : Node ) -> Optional [List [MeasurableClip ]]:
7272 # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
7373
@@ -78,9 +78,6 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
7878 if isinstance (node .op , MeasurableClip ):
7979 return None # pragma: no cover
8080
81- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Clip )):
82- return None
83-
8481 clipped_var = node .outputs [0 ]
8582 base_var , lower_bound , upper_bound = node .inputs
8683
@@ -179,7 +176,7 @@ class MeasurableRound(MeasurableElemwise):
179176 valid_scalar_types = (RoundHalfToEven , Floor , Ceil )
180177
181178
182- @node_rewriter (tracks = [Elemwise ])
179+ @node_rewriter (tracks = [ceil , floor , round_half_to_even ])
183180def find_measurable_roundings (fgraph : FunctionGraph , node : Node ) -> Optional [List [MeasurableRound ]]:
184181
185182 rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
@@ -189,12 +186,6 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
189186 if isinstance (node .op , MeasurableRound ):
190187 return None # pragma: no cover
191188
192- if not (
193- isinstance (node .op , Elemwise )
194- and isinstance (node .op .scalar_op , MeasurableRound .valid_scalar_types )
195- ):
196- return None
197-
198189 (rounded_var ,) = node .outputs
199190 (base_var ,) = node .inputs
200191
0 commit comments