Skip to content

Commit f2bc9ba

Browse files
michaelraczyckimichaelosthege
authored andcommitted
added test checking if variable not being in model.value_vars will trigger Value error
1 parent 933f524 commit f2bc9ba

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/sampling/test_mcmc.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
Metropolis,
4545
Slice,
4646
)
47+
from step_methods.compound import CompoundStep
4748
from tests.helpers import SeededTest, fast_unstable_sampling_mode
4849
from tests.models import simple_init
4950

@@ -817,6 +818,22 @@ def test_modify_step_methods(self):
817818
steps = assign_step_methods(model, [])
818819
assert isinstance(steps, NUTS)
819820

821+
def test_step_vars_in_model(self):
822+
"""Test if error is raised if step variable is not found in model.value_vars"""
823+
with pm.Model() as model:
824+
c1 = pm.HalfNormal("c1")
825+
c2 = pm.HalfNormal("c2")
826+
827+
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
828+
step1 = NUTS([c1])
829+
step2 = NUTS([c2])
830+
step2.vars = [
831+
at.dscalar("x"),
832+
]
833+
step = CompoundStep([step1, step2])
834+
with pytest.raises(ValueError):
835+
assign_step_methods(model, step)
836+
820837

821838
class TestType:
822839
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)

0 commit comments

Comments
 (0)