|
32 | 32 | import pymc as pm |
33 | 33 |
|
34 | 34 | from pymc.aesaraf import compile_rv_inplace |
| 35 | +from pymc.backends.base import MultiTrace |
35 | 36 | from pymc.backends.ndarray import NDArray |
36 | 37 | from pymc.exceptions import IncorrectArgumentsError, SamplingError |
37 | 38 | from pymc.tests.helpers import SeededTest |
@@ -438,14 +439,29 @@ def test_constant_named(self): |
438 | 439 | class TestChooseBackend: |
439 | 440 | def test_choose_backend_none(self): |
440 | 441 | with mock.patch("pymc.sampling.NDArray") as nd: |
441 | | - pm.sampling._choose_backend(None, "chain") |
| 442 | + pm.sampling._choose_backend(None) |
442 | 443 | assert nd.called |
443 | 444 |
|
444 | 445 | def test_choose_backend_list_of_variables(self): |
445 | 446 | with mock.patch("pymc.sampling.NDArray") as nd: |
446 | | - pm.sampling._choose_backend(["var1", "var2"], "chain") |
| 447 | + pm.sampling._choose_backend(["var1", "var2"]) |
447 | 448 | nd.assert_called_with(vars=["var1", "var2"]) |
448 | 449 |
|
| 450 | + def test_errors_and_warnings(self): |
| 451 | + with pm.Model(): |
| 452 | + A = pm.Normal("A") |
| 453 | + B = pm.Uniform("B") |
| 454 | + strace = pm.sampling.NDArray(vars=[A, B]) |
| 455 | + strace.setup(10, 0) |
| 456 | + |
| 457 | + with pytest.raises(ValueError, match="from existing MultiTrace"): |
| 458 | + pm.sampling._choose_backend(trace=MultiTrace([strace])) |
| 459 | + |
| 460 | + strace.record({"A": 2, "B_interval__": 0.1}) |
| 461 | + assert len(strace) == 1 |
| 462 | + with pytest.raises(ValueError, match="Continuation of traces"): |
| 463 | + pm.sampling._choose_backend(trace=strace) |
| 464 | + |
449 | 465 |
|
450 | 466 | class TestSamplePPC(SeededTest): |
451 | 467 | def test_normal_scalar(self): |
|
0 commit comments