diff --git a/pymc/model.py b/pymc/model.py index efb50b32bc..156284f04d 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -556,8 +556,10 @@ def __init__( self.values_to_rvs = treedict(parent=self.parent.values_to_rvs) self.rvs_to_values = treedict(parent=self.parent.rvs_to_values) self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms) - self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes) self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values) + self.observed_rvs_to_total_sizes = treedict( + parent=self.parent.observed_rvs_to_total_sizes + ) self.free_RVs = treelist(parent=self.parent.free_RVs) self.observed_RVs = treelist(parent=self.parent.observed_RVs) self.deterministics = treelist(parent=self.parent.deterministics) @@ -570,8 +572,8 @@ def __init__( self.values_to_rvs = treedict() self.rvs_to_values = treedict() self.rvs_to_transforms = treedict() - self.rvs_to_total_sizes = treedict() self.rvs_to_initial_values = treedict() + self.observed_rvs_to_total_sizes = treedict() self.free_RVs = treelist() self.observed_RVs = treelist() self.deterministics = treelist() @@ -751,7 +753,7 @@ def logp( rvs=rvs, rvs_to_values=self.rvs_to_values, rvs_to_transforms=self.rvs_to_transforms, - rvs_to_total_sizes=self.rvs_to_total_sizes, + rvs_to_total_sizes=self.observed_rvs_to_total_sizes, jacobian=jacobian, ) assert isinstance(rv_logps, list) @@ -1289,8 +1291,6 @@ def register_rv( name = self.name_for(name) rv_var.name = name _add_future_warning_tag(rv_var) - rv_var.tag.total_size = total_size - self.rvs_to_total_sizes[rv_var] = total_size # Associate previously unknown dimension names with # the length of the corresponding RV dimension. @@ -1300,6 +1300,8 @@ def register_rv( self.add_coord(dname, values=None, length=rv_var.shape[d]) if observed is None: + if total_size is not None: + raise ValueError("total_size can only be used for observed RVs") self.free_RVs.append(rv_var) self.create_value_var(rv_var, transform) self.add_named_variable(rv_var, dims) @@ -1323,12 +1325,20 @@ def register_rv( # `rv_var` is potentially changed by `make_obs_var`, # for example into a new graph for imputation of missing data. - rv_var = self.make_obs_var(rv_var, observed, dims, transform) + rv_var = self.make_obs_var( + rv_var, observed, total_size=total_size, dims=dims, transform=transform + ) return rv_var def make_obs_var( - self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any] + self, + rv_var: TensorVariable, + data: np.ndarray, + *, + total_size, + dims, + transform: Optional[Any], ) -> TensorVariable: """Create a `TensorVariable` for an observed random variable. @@ -1364,12 +1374,6 @@ def make_obs_var( mask = getattr(data, "mask", None) if mask is not None: - - if mask.all(): - # If there are no observed values, this variable isn't really - # observed. - return rv_var - impute_message = ( f"Data in {rv_var} contains missing values and" " will be automatically imputed from the" @@ -1377,6 +1381,9 @@ def make_obs_var( ) warnings.warn(impute_message, ImputationWarning) + if total_size is not None: + raise NotImplementedError("total_size cannot be used with imputation") + if not isinstance(rv_var.owner.op, RandomVariable): raise NotImplementedError( "Automatic inputation is only supported for univariate RandomVariables." @@ -1431,6 +1438,8 @@ def make_obs_var( self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data) self.add_named_variable(observed_rv_var) self.observed_RVs.append(observed_rv_var) + self.observed_rvs_to_total_sizes[observed_rv_var] = None + observed_rv_var.tag.total_size = None # Create deterministic that combines observed and missing # Note: This can widely increase memory consumption during sampling for large datasets @@ -1448,6 +1457,8 @@ def make_obs_var( self.create_value_var(rv_var, transform=None, value_var=data) self.add_named_variable(rv_var, dims) self.observed_RVs.append(rv_var) + self.observed_rvs_to_total_sizes[rv_var] = total_size + rv_var.tag.total_size = total_size return rv_var diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index e03cd4507b..65c8aeb161 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -500,7 +500,7 @@ def test_model_value_vars(): def test_model_var_maps(): with pm.Model() as model: a = pm.Uniform("a") - x = pm.Normal("x", a, total_size=5) + x = pm.Normal("x", a) assert set(model.rvs_to_values.keys()) == {a, x} a_value = model.rvs_to_values[a] @@ -512,10 +512,7 @@ def test_model_var_maps(): assert set(model.rvs_to_transforms.keys()) == {a, x} assert isinstance(model.rvs_to_transforms[a], IntervalTransform) assert model.rvs_to_transforms[x] is None - - assert set(model.rvs_to_total_sizes.keys()) == {a, x} - assert model.rvs_to_total_sizes[a] is None - assert model.rvs_to_total_sizes[x] == 5 + assert model.observed_rvs_to_total_sizes == {} def test_make_obs_var(): @@ -538,27 +535,28 @@ def test_make_obs_var(): # Create the testval attribute simply for the sake of model testing fake_distribution.name = input_name + kwargs = dict(total_size=None, dims=None, transform=None) # The function requires data and RV dimensionality to be compatible with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."): - fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None) + fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), **kwargs) # Check function behavior using the various inputs # dense, sparse: Ensure that the missing values are appropriately set to None # masked: a deterministic variable is returned - dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None) + dense_output = fake_model.make_obs_var(fake_distribution, dense_input, **kwargs) assert dense_output == fake_distribution assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant) del fake_model.named_vars[fake_distribution.name] - sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None) + sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, **kwargs) assert sparse_output == fake_distribution assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output]) del fake_model.named_vars[fake_distribution.name] # Here the RandomVariable is split into observed/imputed and a Deterministic is returned with pytest.warns(ImputationWarning): - masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, None, None) + masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, **kwargs) assert masked_output != fake_distribution assert not isinstance(masked_output, RandomVariable) # Ensure it has missing values @@ -701,6 +699,29 @@ def test_set_initval(): assert y in model.initial_values +class TestTotalSize: + def test_total_size_univariate(self): + with pm.Model() as m: + x = pm.Normal("x", observed=[0, 0], total_size=7) + assert m.observed_rvs_to_total_sizes[x] == 7 + + m.compile_logp()({}) == st.norm().logpdf(0) * 7 + + def test_total_size_multivariate(self): + with pm.Model() as m: + x = pm.MvNormal("x", np.ones(3), np.eye(3), observed=np.zeros((2, 3)), total_size=7) + assert m.observed_rvs_to_total_sizes[x] == 7 + + m.compile_logp()({}) == st.multivariate_normal.logpdf( + np.zeros(3), np.ones(3), np.eye(3) + ) * 7 + + def test_total_size_error(self): + with pm.Model(): + with pytest.raises(ValueError, match="total_size can only be used for observed RVs"): + pm.Normal("x", total_size=7) + + def test_datalogp_multiple_shapes(): with pm.Model() as m: x = pm.Normal("x", 0, 1) @@ -1425,6 +1446,18 @@ def test_error_non_random_variable(self): observed=data, ) + def test_rvs_to_total_sizes(self): + with pm.Model() as m: + x = pm.Normal("x", observed=[np.nan, 0, 1]) + assert m["x"] not in m.observed_rvs_to_total_sizes + assert m["x_missing"] not in m.observed_rvs_to_total_sizes + assert m.observed_rvs_to_total_sizes[m["x_observed"]] is None + + def test_total_size_not_supported(self): + with pm.Model() as m: + with pytest.raises(NotImplementedError): + x = pm.Normal("x", observed=[np.nan, 0, 1], total_size=5) + class TestShared(SeededTest): def test_deterministic(self): @@ -1467,16 +1500,15 @@ def test_tag_future_warning_model(): with pytest.raises(AttributeError): x.tag.observations - with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): + with pytest.raises(AttributeError): total_size = x.tag.total_size - assert total_size is None # Cloning a node will keep the same tag type and contents y = x.owner.clone().default_output() assert y is not x assert y.tag is not x.tag assert isinstance(y.tag, _FutureWarningValidatingScratchpad) - y = model.register_rv(y, name="y", observed=5) + y = model.register_rv(y, name="y", observed=5, total_size=7) assert isinstance(y.tag, _FutureWarningValidatingScratchpad) # Test expected warnings @@ -1486,8 +1518,7 @@ def test_tag_future_warning_model(): y_obs = y.tag.observations assert y_value is y_obs assert y_value.eval() == 5 - + with pytest.warns(FutureWarning, match="model.observed_rvs_to_total_sizes"): + y_total_size = y.tag.total_size + assert y_total_size == 7 assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad) - with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): - total_size = y.tag.total_size - assert total_size is None diff --git a/pymc/util.py b/pymc/util.py index cbeceb34c5..4de18f94ae 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -490,7 +490,7 @@ def __getattribute__(self, name): for deprecated_names, alternative in ( (("value_var", "observations"), "model.rvs_to_values[rv]"), (("transform",), "model.rvs_to_transforms[rv]"), - (("total_size",), "model.rvs_to_total_sizes[rv]"), + (("total_size",), "model.observed_rvs_to_total_sizes[rv]"), ): if name in deprecated_names: try: diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index f0c854b66f..b4df9a72b0 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1046,7 +1046,9 @@ def symbolic_normalizing_constant(self): t = self.to_flat_input( at.max( [ - _get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim) + _get_scaling( + self.model.observed_rvs_to_total_sizes.get(v, None), v.shape, v.ndim + ) for v in self.group ] ) @@ -1184,7 +1186,7 @@ def symbolic_normalizing_constant(self): self.collect("symbolic_normalizing_constant") + [ _get_scaling( - self.model.rvs_to_total_sizes.get(obs, None), + self.model.observed_rvs_to_total_sizes.get(obs, None), obs.shape, obs.ndim, )