|
47 | 47 | from pymc.distributions.distribution import SymbolicRandomVariable |
48 | 48 | from pymc.distributions.transforms import Interval |
49 | 49 | from pymc.exceptions import NotConstantValueError |
| 50 | +from pymc.tests.helpers import assert_no_rvs |
50 | 51 | from pymc.vartypes import int_types |
51 | 52 |
|
52 | 53 |
|
@@ -632,3 +633,56 @@ def test_no_change_inplace(self): |
632 | 633 | after = aesara.clone_replace(m.free_RVs) |
633 | 634 |
|
634 | 635 | assert equal_computations(before, after) |
| 636 | + |
| 637 | + @pytest.mark.parametrize("reversed", (False, True)) |
| 638 | + def test_interdependent_transformed_rvs(self, reversed): |
| 639 | + # Test that nested transformed variables, whose transformed values depend on other |
| 640 | + # RVs are properly replaced |
| 641 | + with pm.Model() as m: |
| 642 | + transform = pm.distributions.transforms.Interval( |
| 643 | + bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) |
| 644 | + ) |
| 645 | + x = pm.Uniform("x", lower=0, upper=1, transform=transform) |
| 646 | + y = pm.Uniform("y", lower=0, upper=x, transform=transform) |
| 647 | + z = pm.Uniform("z", lower=0, upper=y, transform=transform) |
| 648 | + w = pm.Uniform("w", lower=0, upper=z, transform=transform) |
| 649 | + |
| 650 | + rvs = [x, y, z, w] |
| 651 | + if reversed: |
| 652 | + rvs = rvs[::-1] |
| 653 | + |
| 654 | + transform_values = rvs_to_value_vars(rvs) |
| 655 | + |
| 656 | + for transform_value in transform_values: |
| 657 | + assert_no_rvs(transform_value) |
| 658 | + |
| 659 | + if reversed: |
| 660 | + transform_values = transform_values[::-1] |
| 661 | + transform_values_fn = m.compile_fn(transform_values, point_fn=False) |
| 662 | + |
| 663 | + x_interval_test_value = np.random.rand() |
| 664 | + y_interval_test_value = np.random.rand() |
| 665 | + z_interval_test_value = np.random.rand() |
| 666 | + w_interval_test_value = np.random.rand() |
| 667 | + |
| 668 | + # The 3 Nones correspond to unused rng, dtype and size arguments |
| 669 | + expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval() |
| 670 | + expected_y = transform.backward( |
| 671 | + y_interval_test_value, None, None, None, 0, expected_x |
| 672 | + ).eval() |
| 673 | + expected_z = transform.backward( |
| 674 | + z_interval_test_value, None, None, None, 0, expected_y |
| 675 | + ).eval() |
| 676 | + expected_w = transform.backward( |
| 677 | + w_interval_test_value, None, None, None, 0, expected_z |
| 678 | + ).eval() |
| 679 | + |
| 680 | + np.testing.assert_allclose( |
| 681 | + transform_values_fn( |
| 682 | + x_interval__=x_interval_test_value, |
| 683 | + y_interval__=y_interval_test_value, |
| 684 | + z_interval__=z_interval_test_value, |
| 685 | + w_interval__=w_interval_test_value, |
| 686 | + ), |
| 687 | + [expected_x, expected_y, expected_z, expected_w], |
| 688 | + ) |
0 commit comments