@@ -596,91 +596,88 @@ def test_rvs_to_value_vars_nested():
596596 assert equal_computations (before , after )
597597
598598
599- def test_check_bounds_flag ():
600- """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
601- logp = at .ones (3 )
602- cond = np .array ([1 , 0 , 1 ])
603- bound = check_parameters (logp , cond )
604-
605- with pm .Model () as m :
606- pass
607-
608- with pytest .raises (ParameterValueError ):
609- aesara .function ([], bound )()
610-
611- m .check_bounds = False
612- with m :
613- assert np .all (compile_pymc ([], bound )() == 1 )
614-
615- m .check_bounds = True
616- with m :
617- assert np .all (compile_pymc ([], bound )() == - np .inf )
618-
619-
620- def test_compile_pymc_sets_rng_updates ():
621- rng = aesara .shared (np .random .default_rng (0 ))
622- x = pm .Normal .dist (rng = rng )
623- assert x .owner .inputs [0 ] is rng
624- f = compile_pymc ([], x )
625- assert not np .isclose (f (), f ())
626-
627- # Check that update was not done inplace
628- assert not hasattr (rng , "default_update" )
629- f = aesara .function ([], x )
630- assert f () == f ()
631-
632-
633- def test_compile_pymc_with_updates ():
634- x = aesara .shared (0 )
635- f = compile_pymc ([], x , updates = {x : x + 1 })
636- assert f () == 0
637- assert f () == 1
638-
639-
640- def test_compile_pymc_missing_default_explicit_updates ():
641- rng = aesara .shared (np .random .default_rng (0 ))
642- x = pm .Normal .dist (rng = rng )
643-
644- # By default, compile_pymc should update the rng of x
645- f = compile_pymc ([], x )
646- assert f () != f ()
647-
648- # An explicit update should override the default_update, like aesara.function does
649- # For testing purposes, we use an update that leaves the rng unchanged
650- f = compile_pymc ([], x , updates = {rng : rng })
651- assert f () == f ()
652-
653- # If we specify a custom default_update directly it should use that instead.
654- rng .default_update = rng
655- f = compile_pymc ([], x )
656- assert f () == f ()
657-
658- # And again, it should be overridden by an explicit update
659- f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
660- assert f () != f ()
661-
662-
663- def test_compile_pymc_updates_inputs ():
664- """Test that compile_pymc does not include rngs updates of variables that are inputs
665- or ancestors to inputs
666- """
667- x = at .random .normal ()
668- y = at .random .normal (x )
669- z = at .random .normal (y )
670-
671- for inputs , rvs_in_graph in (
672- ([], 3 ),
673- ([x ], 2 ),
674- ([y ], 1 ),
675- ([z ], 0 ),
676- ([x , y ], 1 ),
677- ([x , y , z ], 0 ),
678- ):
679- fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
680- fn_fgraph = fn .maker .fgraph
681- # Each RV adds a shared input for its rng
682- assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
683- # If the output is an input, the graph has a DeepCopyOp
684- assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
685- # Each RV adds a shared output for its rng
686- assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
599+ class TestCompilePyMC :
600+ def test_check_bounds_flag (self ):
601+ """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
602+ logp = at .ones (3 )
603+ cond = np .array ([1 , 0 , 1 ])
604+ bound = check_parameters (logp , cond )
605+
606+ with pm .Model () as m :
607+ pass
608+
609+ with pytest .raises (ParameterValueError ):
610+ aesara .function ([], bound )()
611+
612+ m .check_bounds = False
613+ with m :
614+ assert np .all (compile_pymc ([], bound )() == 1 )
615+
616+ m .check_bounds = True
617+ with m :
618+ assert np .all (compile_pymc ([], bound )() == - np .inf )
619+
620+ def test_compile_pymc_sets_rng_updates (self ):
621+ rng = aesara .shared (np .random .default_rng (0 ))
622+ x = pm .Normal .dist (rng = rng )
623+ assert x .owner .inputs [0 ] is rng
624+ f = compile_pymc ([], x )
625+ assert not np .isclose (f (), f ())
626+
627+ # Check that update was not done inplace
628+ assert not hasattr (rng , "default_update" )
629+ f = aesara .function ([], x )
630+ assert f () == f ()
631+
632+ def test_compile_pymc_with_updates (self ):
633+ x = aesara .shared (0 )
634+ f = compile_pymc ([], x , updates = {x : x + 1 })
635+ assert f () == 0
636+ assert f () == 1
637+
638+ def test_compile_pymc_missing_default_explicit_updates (self ):
639+ rng = aesara .shared (np .random .default_rng (0 ))
640+ x = pm .Normal .dist (rng = rng )
641+
642+ # By default, compile_pymc should update the rng of x
643+ f = compile_pymc ([], x )
644+ assert f () != f ()
645+
646+ # An explicit update should override the default_update, like aesara.function does
647+ # For testing purposes, we use an update that leaves the rng unchanged
648+ f = compile_pymc ([], x , updates = {rng : rng })
649+ assert f () == f ()
650+
651+ # If we specify a custom default_update directly it should use that instead.
652+ rng .default_update = rng
653+ f = compile_pymc ([], x )
654+ assert f () == f ()
655+
656+ # And again, it should be overridden by an explicit update
657+ f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
658+ assert f () != f ()
659+
660+ def test_compile_pymc_updates_inputs (self ):
661+ """Test that compile_pymc does not include rngs updates of variables that are inputs
662+ or ancestors to inputs
663+ """
664+ x = at .random .normal ()
665+ y = at .random .normal (x )
666+ z = at .random .normal (y )
667+
668+ for inputs , rvs_in_graph in (
669+ ([], 3 ),
670+ ([x ], 2 ),
671+ ([y ], 1 ),
672+ ([z ], 0 ),
673+ ([x , y ], 1 ),
674+ ([x , y , z ], 0 ),
675+ ):
676+ fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
677+ fn_fgraph = fn .maker .fgraph
678+ # Each RV adds a shared input for its rng
679+ assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
680+ # If the output is an input, the graph has a DeepCopyOp
681+ assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
682+ # Each RV adds a shared output for its rng
683+ assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
0 commit comments