4242import pytensor .tensor as at
4343
4444from pytensor .gradient import DisconnectedType , jacobian
45- from pytensor .graph .basic import Apply , Node , Variable
45+ from pytensor .graph .basic import Apply , Node , Variable , clone_replace
4646from pytensor .graph .features import AlreadyThere , Feature
4747from pytensor .graph .fg import FunctionGraph
4848from pytensor .graph .op import Op
4949from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
5050from pytensor .scalar import Add , Exp , Log , Mul
51+ from pytensor .scan .op import Scan
5152from pytensor .tensor .math import add , exp , log , mul
5253from pytensor .tensor .rewriting .basic import (
5354 register_specialize ,
@@ -186,11 +187,94 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
186187 return trans_node .outputs
187188
188189
190+ @node_rewriter (tracks = [Scan ])
191+ def transform_scan_values (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
192+ """Apply transforms to Scan value variables.
193+
194+ This specialized rewrite is needed because Scan replaces the original value variables
195+ by a more complex graph. We want to apply the transform to the original value variable
196+ in this subgraph, leaving the rest intact
197+ """
198+
199+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
200+ values_to_transforms : Optional [TransformValuesMapping ] = getattr (
201+ fgraph , "values_to_transforms" , None
202+ )
203+
204+ if rv_map_feature is None or values_to_transforms is None :
205+ return None # pragma: no cover
206+
207+ rv_vars = []
208+ value_vars = []
209+
210+ for out in node .outputs :
211+ value = rv_map_feature .rv_values .get (out , None )
212+ if value is None :
213+ continue
214+ rv_vars .append (out )
215+ value_vars .append (value )
216+
217+ if not value_vars :
218+ return None
219+
220+ transforms = [
221+ values_to_transforms .get (rv_map_feature .original_values [value ], None )
222+ for value_var in value_vars
223+ ]
224+
225+ if all (transform is None for transform in transforms ):
226+ return None
227+
228+ new_op = _create_transformed_rv_op (node .op , transforms )
229+ trans_node = node .clone ()
230+ trans_node .op = new_op
231+
232+ # We now assume that the old value variable represents the *transformed space*.
233+ # This means that we need to replace all instance of the old value variable
234+ # with "inversely/un-" transformed versions of itself.
235+ for rv_var , value_var , transform in zip (rv_vars , value_vars , transforms ):
236+ rv_var_out_idx = node .outputs .index (rv_var )
237+ trans_node .outputs [rv_var_out_idx ].name = rv_var .name
238+
239+ if transform is None :
240+ continue
241+
242+ # We access the original value variable and apply the transform to that
243+ original_value_var = rv_map_feature .original_values [value_var ]
244+ trans_original_value_var = transform .backward (original_value_var , * trans_node .inputs )
245+
246+ # We then replace the reference to the original value variable in the scan value
247+ # variable by the back-transform projection computed above
248+
249+ # The first input corresponds to the original value variable. We are careful to
250+ # only clone_replace that part of the graph, as we don't want to break the
251+ # mappings between other rvs that are likely to be present in the rest of the
252+ # scan value variable graph
253+ # TODO: Is it true that the original value only appears in the first input
254+ # and that no other RV can appear there?
255+ (trans_original_value_var ,) = clone_replace (
256+ (value_var .owner .inputs [0 ],),
257+ replace = {original_value_var : trans_original_value_var },
258+ )
259+ trans_value_var = value_var .owner .clone_with_new_inputs (
260+ inputs = [trans_original_value_var ] + value_var .owner .inputs [1 :]
261+ ).default_output ()
262+
263+ new_value_var = transformed_variable (trans_value_var , original_value_var )
264+
265+ if value_var .name and getattr (transform , "name" , None ):
266+ new_value_var .name = f"{ value_var .name } _{ transform .name } "
267+
268+ rv_map_feature .update_rv_maps (rv_var , new_value_var , trans_node .outputs [rv_var_out_idx ])
269+
270+ return trans_node .outputs
271+
272+
189273class TransformValuesMapping (Feature ):
190274 r"""A `Feature` that maintains a map between value variables and their transforms."""
191275
192276 def __init__ (self , values_to_transforms ):
193- self .values_to_transforms = values_to_transforms
277+ self .values_to_transforms = values_to_transforms . copy ()
194278
195279 def on_attach (self , fgraph ):
196280 if hasattr (fgraph , "values_to_transforms" ):
@@ -203,6 +287,7 @@ class TransformValuesRewrite(GraphRewriter):
203287 r"""Transforms value variables according to a map."""
204288
205289 transform_rewrite = in2out (transform_values , ignore_newtrees = True )
290+ scan_transform_rewrite = in2out (transform_scan_values , ignore_newtrees = True )
206291
207292 def __init__ (
208293 self ,
@@ -226,7 +311,8 @@ def add_requirements(self, fgraph):
226311 fgraph .attach_feature (values_transforms_feature )
227312
228313 def apply (self , fgraph : FunctionGraph ):
229- return self .transform_rewrite .rewrite (fgraph )
314+ self .transform_rewrite .rewrite (fgraph )
315+ self .scan_transform_rewrite .rewrite (fgraph )
230316
231317
232318class MeasurableTransform (MeasurableElemwise ):
0 commit comments