@@ -500,7 +500,7 @@ def test_model_value_vars():
500500def test_model_var_maps ():
501501 with pm .Model () as model :
502502 a = pm .Uniform ("a" )
503- x = pm .Normal ("x" , a , total_size = 5 )
503+ x = pm .Normal ("x" , a )
504504
505505 assert set (model .rvs_to_values .keys ()) == {a , x }
506506 a_value = model .rvs_to_values [a ]
@@ -512,10 +512,7 @@ def test_model_var_maps():
512512 assert set (model .rvs_to_transforms .keys ()) == {a , x }
513513 assert isinstance (model .rvs_to_transforms [a ], IntervalTransform )
514514 assert model .rvs_to_transforms [x ] is None
515-
516- assert set (model .rvs_to_total_sizes .keys ()) == {a , x }
517- assert model .rvs_to_total_sizes [a ] is None
518- assert model .rvs_to_total_sizes [x ] == 5
515+ assert model .observed_rvs_to_total_sizes == {}
519516
520517
521518def test_make_obs_var ():
@@ -538,27 +535,28 @@ def test_make_obs_var():
538535 # Create the testval attribute simply for the sake of model testing
539536 fake_distribution .name = input_name
540537
538+ kwargs = dict (total_size = None , dims = None , transform = None )
541539 # The function requires data and RV dimensionality to be compatible
542540 with pytest .raises (ShapeError , match = "Dimensionality of data and RV don't match." ):
543- fake_model .make_obs_var (fake_distribution , np .ones ((3 , 3 , 1 )), None , None )
541+ fake_model .make_obs_var (fake_distribution , np .ones ((3 , 3 , 1 )), ** kwargs )
544542
545543 # Check function behavior using the various inputs
546544 # dense, sparse: Ensure that the missing values are appropriately set to None
547545 # masked: a deterministic variable is returned
548546
549- dense_output = fake_model .make_obs_var (fake_distribution , dense_input , None , None )
547+ dense_output = fake_model .make_obs_var (fake_distribution , dense_input , ** kwargs )
550548 assert dense_output == fake_distribution
551549 assert isinstance (fake_model .rvs_to_values [dense_output ], TensorConstant )
552550 del fake_model .named_vars [fake_distribution .name ]
553551
554- sparse_output = fake_model .make_obs_var (fake_distribution , sparse_input , None , None )
552+ sparse_output = fake_model .make_obs_var (fake_distribution , sparse_input , ** kwargs )
555553 assert sparse_output == fake_distribution
556554 assert sparse .basic ._is_sparse_variable (fake_model .rvs_to_values [sparse_output ])
557555 del fake_model .named_vars [fake_distribution .name ]
558556
559557 # Here the RandomVariable is split into observed/imputed and a Deterministic is returned
560558 with pytest .warns (ImputationWarning ):
561- masked_output = fake_model .make_obs_var (fake_distribution , masked_array_input , None , None )
559+ masked_output = fake_model .make_obs_var (fake_distribution , masked_array_input , ** kwargs )
562560 assert masked_output != fake_distribution
563561 assert not isinstance (masked_output , RandomVariable )
564562 # Ensure it has missing values
@@ -701,6 +699,29 @@ def test_set_initval():
701699 assert y in model .initial_values
702700
703701
702+ class TestTotalSize :
703+ def test_total_size_univariate (self ):
704+ with pm .Model () as m :
705+ x = pm .Normal ("x" , observed = [0 , 0 ], total_size = 7 )
706+ assert m .observed_rvs_to_total_sizes [x ] == 7
707+
708+ m .compile_logp ()({}) == st .norm ().logpdf (0 ) * 7
709+
710+ def test_total_size_multivariate (self ):
711+ with pm .Model () as m :
712+ x = pm .MvNormal ("x" , np .ones (3 ), np .eye (3 ), observed = np .zeros ((2 , 3 )), total_size = 7 )
713+ assert m .observed_rvs_to_total_sizes [x ] == 7
714+
715+ m .compile_logp ()({}) == st .multivariate_normal .logpdf (
716+ np .zeros (3 ), np .ones (3 ), np .eye (3 )
717+ ) * 7
718+
719+ def test_total_size_error (self ):
720+ with pm .Model ():
721+ with pytest .raises (ValueError , match = "total_size can only be used for observed RVs" ):
722+ pm .Normal ("x" , total_size = 7 )
723+
724+
704725def test_datalogp_multiple_shapes ():
705726 with pm .Model () as m :
706727 x = pm .Normal ("x" , 0 , 1 )
@@ -1425,6 +1446,18 @@ def test_error_non_random_variable(self):
14251446 observed = data ,
14261447 )
14271448
1449+ def test_rvs_to_total_sizes (self ):
1450+ with pm .Model () as m :
1451+ x = pm .Normal ("x" , observed = [np .nan , 0 , 1 ])
1452+ assert m ["x" ] not in m .observed_rvs_to_total_sizes
1453+ assert m ["x_missing" ] not in m .observed_rvs_to_total_sizes
1454+ assert m .observed_rvs_to_total_sizes [m ["x_observed" ]] is None
1455+
1456+ def test_total_size_not_supported (self ):
1457+ with pm .Model () as m :
1458+ with pytest .raises (NotImplementedError ):
1459+ x = pm .Normal ("x" , observed = [np .nan , 0 , 1 ], total_size = 5 )
1460+
14281461
14291462class TestShared (SeededTest ):
14301463 def test_deterministic (self ):
@@ -1467,16 +1500,15 @@ def test_tag_future_warning_model():
14671500 with pytest .raises (AttributeError ):
14681501 x .tag .observations
14691502
1470- with pytest .warns ( FutureWarning , match = "model.rvs_to_total_sizes" ):
1503+ with pytest .raises ( AttributeError ):
14711504 total_size = x .tag .total_size
1472- assert total_size is None
14731505
14741506 # Cloning a node will keep the same tag type and contents
14751507 y = x .owner .clone ().default_output ()
14761508 assert y is not x
14771509 assert y .tag is not x .tag
14781510 assert isinstance (y .tag , _FutureWarningValidatingScratchpad )
1479- y = model .register_rv (y , name = "y" , observed = 5 )
1511+ y = model .register_rv (y , name = "y" , observed = 5 , total_size = 7 )
14801512 assert isinstance (y .tag , _FutureWarningValidatingScratchpad )
14811513
14821514 # Test expected warnings
@@ -1486,8 +1518,7 @@ def test_tag_future_warning_model():
14861518 y_obs = y .tag .observations
14871519 assert y_value is y_obs
14881520 assert y_value .eval () == 5
1489-
1521+ with pytest .warns (FutureWarning , match = "model.observed_rvs_to_total_sizes" ):
1522+ y_total_size = y .tag .total_size
1523+ assert y_total_size == 7
14901524 assert isinstance (y_value .tag , _FutureWarningValidatingScratchpad )
1491- with pytest .warns (FutureWarning , match = "model.rvs_to_total_sizes" ):
1492- total_size = y .tag .total_size
1493- assert total_size is None
0 commit comments