File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change 4444 Metropolis ,
4545 Slice ,
4646)
47+ from step_methods .compound import CompoundStep
4748from tests .helpers import SeededTest , fast_unstable_sampling_mode
4849from tests .models import simple_init
4950
@@ -817,6 +818,22 @@ def test_modify_step_methods(self):
817818 steps = assign_step_methods (model , [])
818819 assert isinstance (steps , NUTS )
819820
821+ def test_step_vars_in_model (self ):
822+ """Test if error is raised if step variable is not found in model.value_vars"""
823+ with pm .Model () as model :
824+ c1 = pm .HalfNormal ("c1" )
825+ c2 = pm .HalfNormal ("c2" )
826+
827+ with pytensor .config .change_flags (mode = fast_unstable_sampling_mode ):
828+ step1 = NUTS ([c1 ])
829+ step2 = NUTS ([c2 ])
830+ step2 .vars = [
831+ at .dscalar ("x" ),
832+ ]
833+ step = CompoundStep ([step1 , step2 ])
834+ with pytest .raises (ValueError ):
835+ assign_step_methods (model , step )
836+
820837
821838class TestType :
822839 samplers = (Metropolis , Slice , HamiltonianMC , NUTS )
You can’t perform that action at this time.
0 commit comments