3737import abc
3838
3939from copy import copy
40- from typing import Callable , Dict , List , Optional , Tuple , Union
40+ from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
4141
4242import pytensor .tensor as at
4343
@@ -133,43 +133,55 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
133133 ``Y`` on the natural scale.
134134 """
135135
136- rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
137- values_to_transforms = getattr (fgraph , "values_to_transforms" , None )
136+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
137+ values_to_transforms : Optional [TransformValuesMapping ] = getattr (
138+ fgraph , "values_to_transforms" , None
139+ )
138140
139141 if rv_map_feature is None or values_to_transforms is None :
140142 return None # pragma: no cover
141143
142- try :
143- rv_var = node .default_output ()
144- rv_var_out_idx = node .outputs .index (rv_var )
145- except ValueError :
146- return None
144+ rv_vars = []
145+ value_vars = []
147146
148- value_var = rv_map_feature .rv_values .get (rv_var , None )
149- if value_var is None :
147+ for out in node .outputs :
148+ value = rv_map_feature .rv_values .get (out , None )
149+ if value is None :
150+ continue
151+ rv_vars .append (out )
152+ value_vars .append (value )
153+
154+ if not value_vars :
150155 return None
151156
152- transform = values_to_transforms .get (value_var , None )
157+ transforms = [ values_to_transforms .get (value_var , None ) for value_var in value_vars ]
153158
154- if transform is None :
159+ if all ( transform is None for transform in transforms ) :
155160 return None
156161
157- new_op = _create_transformed_rv_op (node .op , transform )
162+ new_op = _create_transformed_rv_op (node .op , transforms )
158163 # Create a new `Apply` node and outputs
159164 trans_node = node .clone ()
160165 trans_node .op = new_op
161- trans_node .outputs [rv_var_out_idx ].name = node .outputs [rv_var_out_idx ].name
162166
163167 # We now assume that the old value variable represents the *transformed space*.
164168 # This means that we need to replace all instance of the old value variable
165169 # with "inversely/un-" transformed versions of itself.
166- new_value_var = transformed_variable (
167- transform .backward (value_var , * trans_node .inputs ), value_var
168- )
169- if value_var .name and getattr (transform , "name" , None ):
170- new_value_var .name = f"{ value_var .name } _{ transform .name } "
170+ for rv_var , value_var , transform in zip (rv_vars , value_vars , transforms ):
171+ rv_var_out_idx = node .outputs .index (rv_var )
172+ trans_node .outputs [rv_var_out_idx ].name = rv_var .name
171173
172- rv_map_feature .update_rv_maps (rv_var , new_value_var , trans_node .outputs [rv_var_out_idx ])
174+ if transform is None :
175+ continue
176+
177+ new_value_var = transformed_variable (
178+ transform .backward (value_var , * trans_node .inputs ), value_var
179+ )
180+
181+ if value_var .name and getattr (transform , "name" , None ):
182+ new_value_var .name = f"{ value_var .name } _{ transform .name } "
183+
184+ rv_map_feature .update_rv_maps (rv_var , new_value_var , trans_node .outputs [rv_var_out_idx ])
173185
174186 return trans_node .outputs
175187
@@ -549,7 +561,7 @@ def log_jac_det(self, value, *inputs):
549561
550562def _create_transformed_rv_op (
551563 rv_op : Op ,
552- transform : RVTransform ,
564+ transforms : Union [ RVTransform , Sequence [ Union [ None , RVTransform ]]] ,
553565 * ,
554566 cls_dict_extra : Optional [Dict ] = None ,
555567) -> Op :
@@ -572,14 +584,20 @@ def _create_transformed_rv_op(
572584
573585 """
574586
575- trans_name = getattr (transform , "name" , "transformed" )
587+ if not isinstance (transforms , Sequence ):
588+ transforms = (transforms ,)
589+
590+ trans_names = [
591+ getattr (transform , "name" , "transformed" ) if transform is not None else "None"
592+ for transform in transforms
593+ ]
576594 rv_op_type = type (rv_op )
577595 rv_type_name = rv_op_type .__name__
578596 cls_dict = rv_op_type .__dict__ .copy ()
579597 rv_name = cls_dict .get ("name" , "" )
580598 if rv_name :
581- cls_dict ["name" ] = f"{ rv_name } _{ trans_name } "
582- cls_dict ["transform " ] = transform
599+ cls_dict ["name" ] = f"{ rv_name } _{ '_' . join ( trans_names ) } "
600+ cls_dict ["transforms " ] = transforms
583601
584602 if cls_dict_extra is not None :
585603 cls_dict .update (cls_dict_extra )
@@ -595,17 +613,27 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
595613 We assume that the value variable was back-transformed to be on the natural
596614 support of the respective random variable.
597615 """
598- ( value ,) = values
616+ logprobs = _logprob ( rv_op , values , * inputs , ** kwargs )
599617
600- logprob = _logprob (rv_op , values , * inputs , ** kwargs )
618+ if not isinstance (logprobs , Sequence ):
619+ logprobs = [logprobs ]
601620
602621 if use_jacobian :
603- assert isinstance (value .owner .op , TransformedVariable )
604- original_forward_value = value .owner .inputs [1 ]
605- jacobian = op .transform .log_jac_det (original_forward_value , * inputs )
606- logprob += jacobian
607-
608- return logprob
622+ assert len (values ) == len (logprobs ) == len (op .transforms )
623+ logprobs_jac = []
624+ for value , transform , logprob in zip (values , op .transforms , logprobs ):
625+ if transform is None :
626+ logprobs_jac .append (logprob )
627+ continue
628+ assert isinstance (value .owner .op , TransformedVariable )
629+ original_forward_value = value .owner .inputs [1 ]
630+ jacobian = transform .log_jac_det (original_forward_value , * inputs ).copy ()
631+ if value .name :
632+ jacobian .name = f"{ value .name } _jacobian"
633+ logprobs_jac .append (logprob + jacobian )
634+ logprobs = logprobs_jac
635+
636+ return logprobs
609637
610638 new_op = copy (rv_op )
611639 new_op .__class__ = new_op_type
0 commit comments