@@ -526,6 +526,7 @@ def check_icdf(
526526 pymc_dist : Distribution ,
527527 paramdomains : Dict [str , Domain ],
528528 scipy_icdf : Callable ,
529+ skip_paramdomain_outside_edge_test = False ,
529530 decimal : Optional [int ] = None ,
530531 n_samples : int = 100 ,
531532) -> None :
@@ -548,7 +549,7 @@ def check_icdf(
548549 paramdomains : Dictionary of Parameter : Domain pairs
549550 Supported domains of distribution parameters
550551 scipy_icdf : Scipy icdf method
551- Scipy icdf (ppp ) method of equivalent pymc_dist distribution
552+ Scipy icdf (ppf ) method of equivalent pymc_dist distribution
552553 decimal : int, optional
553554 Level of precision with which pymc_dist and scipy_icdf are compared.
554555 Defaults to 6 for float64 and 3 for float32
@@ -557,6 +558,9 @@ def check_icdf(
557558 are compared between pymc and scipy methods. If n_samples is below the
558559 total number of combinations, a random subset is evaluated. Setting
559560 n_samples = -1, will return all possible combinations. Defaults to 100
561+ skip_paradomain_outside_edge_test : Bool
562+ Whether to run test 2., which checks that pymc distribution icdf
563+ returns nan for invalid parameter values outside the supported domain edge
560564
561565 """
562566 if decimal is None :
@@ -586,19 +590,20 @@ def check_icdf(
586590 valid_params = {param : paramdomain .vals [0 ] for param , paramdomain in paramdomains .items ()}
587591 valid_params ["q" ] = valid_value
588592
589- # Test pymc distribution raises ParameterValueError for parameters outside the
590- # supported domain edges (excluding edges)
591- invalid_params = find_invalid_scalar_params (paramdomains )
592- for invalid_param , invalid_edges in invalid_params .items ():
593- for invalid_edge in invalid_edges :
594- if invalid_edge is None :
595- continue
593+ if not skip_paramdomain_outside_edge_test :
594+ # Test pymc distribution raises ParameterValueError for parameters outside the
595+ # supported domain edges (excluding edges)
596+ invalid_params = find_invalid_scalar_params (paramdomains )
597+ for invalid_param , invalid_edges in invalid_params .items ():
598+ for invalid_edge in invalid_edges :
599+ if invalid_edge is None :
600+ continue
596601
597- point = valid_params .copy ()
598- point [invalid_param ] = invalid_edge
599- with pytest .raises (ParameterValueError ):
600- pymc_icdf (** point )
601- pytest .fail (f"test_params={ point } " )
602+ point = valid_params .copy ()
603+ point [invalid_param ] = invalid_edge
604+ with pytest .raises (ParameterValueError ):
605+ pymc_icdf (** point )
606+ pytest .fail (f"test_params={ point } " )
602607
603608 # Test that values below 0 or above 1 evaluate to nan
604609 invalid_values = find_invalid_scalar_params ({"q" : domain })["q" ]
0 commit comments