@@ -85,7 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
8585 normal_lcdf ,
8686 zvalue ,
8787)
88- from pymc .distributions .distribution import Continuous
88+ from pymc .distributions .distribution import DIST_PARAMETER_TYPES , Continuous
8989from pymc .distributions .shape_utils import rv_size_is_none
9090from pymc .math import invlogit , logdiffexp , logit
9191from pymc .util import UNSET
@@ -692,12 +692,12 @@ class TruncatedNormal(BoundedContinuous):
692692 @classmethod
693693 def dist (
694694 cls ,
695- mu : Optional [Union [ float , np . ndarray ] ] = None ,
696- sigma : Optional [Union [ float , np . ndarray ] ] = None ,
697- tau : Optional [Union [ float , np . ndarray ] ] = None ,
698- sd : Optional [Union [ float , np . ndarray ] ] = None ,
699- lower : Optional [Union [ float , np . ndarray ] ] = None ,
700- upper : Optional [Union [ float , np . ndarray ] ] = None ,
695+ mu : Optional [DIST_PARAMETER_TYPES ] = None ,
696+ sigma : Optional [DIST_PARAMETER_TYPES ] = None ,
697+ tau : Optional [DIST_PARAMETER_TYPES ] = None ,
698+ sd : Optional [DIST_PARAMETER_TYPES ] = None ,
699+ lower : Optional [DIST_PARAMETER_TYPES ] = None ,
700+ upper : Optional [DIST_PARAMETER_TYPES ] = None ,
701701 transform : str = "auto" ,
702702 * args ,
703703 ** kwargs ,
0 commit comments