Skip to content

Commit 73c0c91

Browse files
twieckimichaelosthege
authored andcommitted
Merge conflict.
Add test that was erroneously removed.
1 parent edc99b6 commit 73c0c91

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

pymc/tests/test_sampling.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import pymc as pm
3333

3434
from pymc.aesaraf import compile_rv_inplace
35+
from pymc.backends.base import MultiTrace
3536
from pymc.backends.ndarray import NDArray
3637
from pymc.exceptions import IncorrectArgumentsError, SamplingError
3738
from pymc.tests.helpers import SeededTest
@@ -438,14 +439,29 @@ def test_constant_named(self):
438439
class TestChooseBackend:
439440
def test_choose_backend_none(self):
440441
with mock.patch("pymc.sampling.NDArray") as nd:
441-
pm.sampling._choose_backend(None, "chain")
442+
pm.sampling._choose_backend(None)
442443
assert nd.called
443444

444445
def test_choose_backend_list_of_variables(self):
445446
with mock.patch("pymc.sampling.NDArray") as nd:
446-
pm.sampling._choose_backend(["var1", "var2"], "chain")
447+
pm.sampling._choose_backend(["var1", "var2"])
447448
nd.assert_called_with(vars=["var1", "var2"])
448449

450+
def test_errors_and_warnings(self):
451+
with pm.Model():
452+
A = pm.Normal("A")
453+
B = pm.Uniform("B")
454+
strace = pm.sampling.NDArray(vars=[A, B])
455+
strace.setup(10, 0)
456+
457+
with pytest.raises(ValueError, match="from existing MultiTrace"):
458+
pm.sampling._choose_backend(trace=MultiTrace([strace]))
459+
460+
strace.record({"A": 2, "B_interval__": 0.1})
461+
assert len(strace) == 1
462+
with pytest.raises(ValueError, match="Continuation of traces"):
463+
pm.sampling._choose_backend(trace=strace)
464+
449465

450466
class TestSamplePPC(SeededTest):
451467
def test_normal_scalar(self):

0 commit comments

Comments
 (0)