3737import abc
3838
3939from copy import copy
40- from functools import partial , singledispatch
4140from typing import Callable , Dict , List , Optional , Tuple , Union
4241
4342import pytensor .tensor as at
6968from pymc .logprob .utils import walk_model
7069
7170
72- @singledispatch
73- def _default_transformed_rv (
74- op : Op ,
75- node : Node ,
76- ) -> Optional [Apply ]:
77- """Create a node for a transformed log-probability of a `MeasurableVariable`.
78-
79- This function dispatches on the type of `op`. If you want to implement
80- new transforms for a `MeasurableVariable`, register a function on this
81- dispatcher.
82-
83- """
84- return None
85-
86-
8771class TransformedVariable (Op ):
8872 """A no-op that identifies a transform and its un-transformed input."""
8973
@@ -136,13 +120,6 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
136120 return at .log (at .abs (at .nlinalg .det (at .atleast_2d (jacobian (phi_inv , [value ])[0 ]))))
137121
138122
139- class DefaultTransformSentinel :
140- pass
141-
142-
143- DEFAULT_TRANSFORM = DefaultTransformSentinel ()
144-
145-
146123@node_rewriter (tracks = None )
147124def transform_values (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
148125 """Apply transforms to value variables.
@@ -176,17 +153,12 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
176153
177154 if transform is None :
178155 return None
179- elif transform is DEFAULT_TRANSFORM :
180- trans_node = _default_transformed_rv (node .op , node )
181- if trans_node is None :
182- return None
183- transform = trans_node .op .transform
184- else :
185- new_op = _create_transformed_rv_op (node .op , transform )
186- # Create a new `Apply` node and outputs
187- trans_node = node .clone ()
188- trans_node .op = new_op
189- trans_node .outputs [rv_var_out_idx ].name = node .outputs [rv_var_out_idx ].name
156+
157+ new_op = _create_transformed_rv_op (node .op , transform )
158+ # Create a new `Apply` node and outputs
159+ trans_node = node .clone ()
160+ trans_node .op = new_op
161+ trans_node .outputs [rv_var_out_idx ].name = node .outputs [rv_var_out_idx ].name
190162
191163 # We now assume that the old value variable represents the *transformed space*.
192164 # This means that we need to replace all instance of the old value variable
@@ -216,24 +188,22 @@ def on_attach(self, fgraph):
216188
217189
218190class TransformValuesRewrite (GraphRewriter ):
219- r"""Transforms value variables according to a map and/or per-`RandomVariable` defaults ."""
191+ r"""Transforms value variables according to a map."""
220192
221- default_transform_rewrite = in2out (transform_values , ignore_newtrees = True )
193+ transform_rewrite = in2out (transform_values , ignore_newtrees = True )
222194
223195 def __init__ (
224196 self ,
225- values_to_transforms : Dict [
226- TensorVariable , Union [RVTransform , DefaultTransformSentinel , None ]
227- ],
197+ values_to_transforms : Dict [TensorVariable , Union [RVTransform , None ]],
228198 ):
229199 """
230200 Parameters
231201 ==========
232202 values_to_transforms
233203 Mapping between value variables and their transformations. Each
234- value variable can be assigned one of `RVTransform`,
235- ``DEFAULT_TRANSFORM``, or ``None``. If a transform is not specified
236- for a specific value variable it will not be transformed.
204+ value variable can be assigned one of `RVTransform`, or ``None``.
205+ If a transform is not specified for a specific value variable it will
206+ not be transformed.
237207
238208 """
239209
@@ -244,7 +214,7 @@ def add_requirements(self, fgraph):
244214 fgraph .attach_feature (values_transforms_feature )
245215
246216 def apply (self , fgraph : FunctionGraph ):
247- return self .default_transform_rewrite .rewrite (fgraph )
217+ return self .transform_rewrite .rewrite (fgraph )
248218
249219
250220class MeasurableTransform (MeasurableElemwise ):
@@ -583,7 +553,6 @@ def _create_transformed_rv_op(
583553 rv_op : Op ,
584554 transform : RVTransform ,
585555 * ,
586- default : bool = False ,
587556 cls_dict_extra : Optional [Dict ] = None ,
588557) -> Op :
589558 """Create a new transformed variable instance given a base `RandomVariable` `Op`.
@@ -600,8 +569,6 @@ def _create_transformed_rv_op(
600569 The `RandomVariable` for which we want to construct a `TransformedRV`.
601570 transform
602571 The `RVTransform` for `rv_op`.
603- default
604- If ``False`` do not make `transform` the default transform for `rv_op`.
605572 cls_dict_extra
606573 Additional class members to add to the constructed `TransformedRV`.
607574
@@ -642,85 +609,7 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
642609
643610 return logprob
644611
645- transform_op = rv_op_type if default else new_op_type
646-
647- @_default_transformed_rv .register (transform_op )
648- def class_transformed_rv (op , node ):
649- new_op = new_op_type ()
650- res = new_op .make_node (* node .inputs )
651- res .outputs [1 ].name = node .outputs [1 ].name
652- return res
653-
654612 new_op = copy (rv_op )
655613 new_op .__class__ = new_op_type
656614
657615 return new_op
658-
659-
660- create_default_transformed_rv_op = partial (_create_transformed_rv_op , default = True )
661-
662-
663- TransformedUniformRV = create_default_transformed_rv_op (
664- at .random .uniform ,
665- # inputs[3] = lower; inputs[4] = upper
666- IntervalTransform (lambda * inputs : (inputs [3 ], inputs [4 ])),
667- )
668- TransformedParetoRV = create_default_transformed_rv_op (
669- at .random .pareto ,
670- # inputs[3] = alpha
671- IntervalTransform (lambda * inputs : (inputs [3 ], None )),
672- )
673- TransformedTriangularRV = create_default_transformed_rv_op (
674- at .random .triangular ,
675- # inputs[3] = lower; inputs[5] = upper
676- IntervalTransform (lambda * inputs : (inputs [3 ], inputs [5 ])),
677- )
678- TransformedHalfNormalRV = create_default_transformed_rv_op (
679- at .random .halfnormal ,
680- # inputs[3] = loc
681- IntervalTransform (lambda * inputs : (inputs [3 ], None )),
682- )
683- TransformedWaldRV = create_default_transformed_rv_op (
684- at .random .wald ,
685- LogTransform (),
686- )
687- TransformedExponentialRV = create_default_transformed_rv_op (
688- at .random .exponential ,
689- LogTransform (),
690- )
691- TransformedLognormalRV = create_default_transformed_rv_op (
692- at .random .lognormal ,
693- LogTransform (),
694- )
695- TransformedHalfCauchyRV = create_default_transformed_rv_op (
696- at .random .halfcauchy ,
697- LogTransform (),
698- )
699- TransformedGammaRV = create_default_transformed_rv_op (
700- at .random .gamma ,
701- LogTransform (),
702- )
703- TransformedInvGammaRV = create_default_transformed_rv_op (
704- at .random .invgamma ,
705- LogTransform (),
706- )
707- TransformedChiSquareRV = create_default_transformed_rv_op (
708- at .random .chisquare ,
709- LogTransform (),
710- )
711- TransformedWeibullRV = create_default_transformed_rv_op (
712- at .random .weibull ,
713- LogTransform (),
714- )
715- TransformedBetaRV = create_default_transformed_rv_op (
716- at .random .beta ,
717- LogOddsTransform (),
718- )
719- TransformedVonMisesRV = create_default_transformed_rv_op (
720- at .random .vonmises ,
721- CircularTransform (),
722- )
723- TransformedDirichletRV = create_default_transformed_rv_op (
724- at .random .dirichlet ,
725- SimplexTransform (),
726- )
0 commit comments