Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,10 @@ def __init__(
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.observed_rvs_to_total_sizes = treedict(
parent=self.parent.observed_rvs_to_total_sizes
)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.deterministics = treelist(parent=self.parent.deterministics)
Expand All @@ -570,8 +572,8 @@ def __init__(
self.values_to_rvs = treedict()
self.rvs_to_values = treedict()
self.rvs_to_transforms = treedict()
self.rvs_to_total_sizes = treedict()
self.rvs_to_initial_values = treedict()
self.observed_rvs_to_total_sizes = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
self.deterministics = treelist()
Expand Down Expand Up @@ -751,7 +753,7 @@ def logp(
rvs=rvs,
rvs_to_values=self.rvs_to_values,
rvs_to_transforms=self.rvs_to_transforms,
rvs_to_total_sizes=self.rvs_to_total_sizes,
rvs_to_total_sizes=self.observed_rvs_to_total_sizes,
jacobian=jacobian,
)
assert isinstance(rv_logps, list)
Expand Down Expand Up @@ -1289,8 +1291,6 @@ def register_rv(
name = self.name_for(name)
rv_var.name = name
_add_future_warning_tag(rv_var)
rv_var.tag.total_size = total_size
self.rvs_to_total_sizes[rv_var] = total_size

# Associate previously unknown dimension names with
# the length of the corresponding RV dimension.
Expand All @@ -1300,6 +1300,8 @@ def register_rv(
self.add_coord(dname, values=None, length=rv_var.shape[d])

if observed is None:
if total_size is not None:
raise ValueError("total_size can only be used for observed RVs")
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.add_named_variable(rv_var, dims)
Expand All @@ -1323,12 +1325,20 @@ def register_rv(

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, observed, dims, transform)
rv_var = self.make_obs_var(
rv_var, observed, total_size=total_size, dims=dims, transform=transform
)

return rv_var

def make_obs_var(
self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any]
self,
rv_var: TensorVariable,
data: np.ndarray,
*,
total_size,
dims,
transform: Optional[Any],
) -> TensorVariable:
"""Create a `TensorVariable` for an observed random variable.

Expand Down Expand Up @@ -1364,19 +1374,16 @@ def make_obs_var(

mask = getattr(data, "mask", None)
if mask is not None:

if mask.all():
# If there are no observed values, this variable isn't really
# observed.
return rv_var

impute_message = (
f"Data in {rv_var} contains missing values and"
" will be automatically imputed from the"
" sampling distribution."
)
warnings.warn(impute_message, ImputationWarning)

if total_size is not None:
raise NotImplementedError("total_size cannot be used with imputation")

if not isinstance(rv_var.owner.op, RandomVariable):
raise NotImplementedError(
"Automatic inputation is only supported for univariate RandomVariables."
Expand Down Expand Up @@ -1431,6 +1438,8 @@ def make_obs_var(
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
self.add_named_variable(observed_rv_var)
self.observed_RVs.append(observed_rv_var)
self.observed_rvs_to_total_sizes[observed_rv_var] = None
observed_rv_var.tag.total_size = None

# Create deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
Expand All @@ -1448,6 +1457,8 @@ def make_obs_var(
self.create_value_var(rv_var, transform=None, value_var=data)
self.add_named_variable(rv_var, dims)
self.observed_RVs.append(rv_var)
self.observed_rvs_to_total_sizes[rv_var] = total_size
rv_var.tag.total_size = total_size

return rv_var

Expand Down
63 changes: 47 additions & 16 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def test_model_value_vars():
def test_model_var_maps():
with pm.Model() as model:
a = pm.Uniform("a")
x = pm.Normal("x", a, total_size=5)
x = pm.Normal("x", a)

assert set(model.rvs_to_values.keys()) == {a, x}
a_value = model.rvs_to_values[a]
Expand All @@ -512,10 +512,7 @@ def test_model_var_maps():
assert set(model.rvs_to_transforms.keys()) == {a, x}
assert isinstance(model.rvs_to_transforms[a], IntervalTransform)
assert model.rvs_to_transforms[x] is None

assert set(model.rvs_to_total_sizes.keys()) == {a, x}
assert model.rvs_to_total_sizes[a] is None
assert model.rvs_to_total_sizes[x] == 5
assert model.observed_rvs_to_total_sizes == {}


def test_make_obs_var():
Expand All @@ -538,27 +535,28 @@ def test_make_obs_var():
# Create the testval attribute simply for the sake of model testing
fake_distribution.name = input_name

kwargs = dict(total_size=None, dims=None, transform=None)
# The function requires data and RV dimensionality to be compatible
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None)
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), **kwargs)

# Check function behavior using the various inputs
# dense, sparse: Ensure that the missing values are appropriately set to None
# masked: a deterministic variable is returned

dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None)
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, **kwargs)
assert dense_output == fake_distribution
assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
del fake_model.named_vars[fake_distribution.name]

sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None)
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, **kwargs)
assert sparse_output == fake_distribution
assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
del fake_model.named_vars[fake_distribution.name]

# Here the RandomVariable is split into observed/imputed and a Deterministic is returned
with pytest.warns(ImputationWarning):
masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, None, None)
masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, **kwargs)
assert masked_output != fake_distribution
assert not isinstance(masked_output, RandomVariable)
# Ensure it has missing values
Expand Down Expand Up @@ -701,6 +699,29 @@ def test_set_initval():
assert y in model.initial_values


class TestTotalSize:
    """Tests for the ``total_size`` argument of observed random variables."""

    def test_total_size_univariate(self):
        """A univariate observed RV records its total_size and rescales its logp."""
        with pm.Model() as m:
            x = pm.Normal("x", observed=[0, 0], total_size=7)
        assert m.observed_rvs_to_total_sizes[x] == 7

        # The original compared with a bare `==` and discarded the result, so
        # nothing was checked. Compare with a tolerance because this is a
        # floating-point result.
        np.testing.assert_allclose(
            m.compile_logp()({}),
            st.norm().logpdf(0) * 7,
            rtol=1e-5,
        )

    def test_total_size_multivariate(self):
        """A multivariate observed RV records its total_size and rescales its logp."""
        with pm.Model() as m:
            x = pm.MvNormal("x", np.ones(3), np.eye(3), observed=np.zeros((2, 3)), total_size=7)
        assert m.observed_rvs_to_total_sizes[x] == 7

        # Same fix as the univariate case: actually assert the comparison.
        np.testing.assert_allclose(
            m.compile_logp()({}),
            st.multivariate_normal.logpdf(np.zeros(3), np.ones(3), np.eye(3)) * 7,
            rtol=1e-5,
        )

    def test_total_size_error(self):
        """Passing total_size to an unobserved RV raises a ValueError."""
        with pm.Model():
            with pytest.raises(ValueError, match="total_size can only be used for observed RVs"):
                pm.Normal("x", total_size=7)


def test_datalogp_multiple_shapes():
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
Expand Down Expand Up @@ -1425,6 +1446,18 @@ def test_error_non_random_variable(self):
observed=data,
)

def test_rvs_to_total_sizes(self):
    """Check which variables of an imputed RV appear in the total-size map.

    Of the variables looked up as "x", "x_missing" and "x_observed", only
    "x_observed" is a key of ``observed_rvs_to_total_sizes``, and its
    entry is ``None``.
    """
    with pm.Model() as m:
        pm.Normal("x", observed=[np.nan, 0, 1])

    total_sizes = m.observed_rvs_to_total_sizes
    for absent_name in ("x", "x_missing"):
        assert m[absent_name] not in total_sizes
    assert total_sizes[m["x_observed"]] is None

def test_total_size_not_supported(self):
    """Combining total_size with data containing missing values is rejected."""
    with pm.Model():
        # Match on the message so an unrelated NotImplementedError cannot
        # make this test pass by accident.
        with pytest.raises(
            NotImplementedError, match="total_size cannot be used with imputation"
        ):
            pm.Normal("x", observed=[np.nan, 0, 1], total_size=5)


class TestShared(SeededTest):
def test_deterministic(self):
Expand Down Expand Up @@ -1467,16 +1500,15 @@ def test_tag_future_warning_model():
with pytest.raises(AttributeError):
x.tag.observations

with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
with pytest.raises(AttributeError):
total_size = x.tag.total_size
assert total_size is None

# Cloning a node will keep the same tag type and contents
y = x.owner.clone().default_output()
assert y is not x
assert y.tag is not x.tag
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)
y = model.register_rv(y, name="y", observed=5)
y = model.register_rv(y, name="y", observed=5, total_size=7)
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)

# Test expected warnings
Expand All @@ -1486,8 +1518,7 @@ def test_tag_future_warning_model():
y_obs = y.tag.observations
assert y_value is y_obs
assert y_value.eval() == 5

with pytest.warns(FutureWarning, match="model.observed_rvs_to_total_sizes"):
y_total_size = y.tag.total_size
assert y_total_size == 7
assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)
with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
total_size = y.tag.total_size
assert total_size is None
2 changes: 1 addition & 1 deletion pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def __getattribute__(self, name):
for deprecated_names, alternative in (
(("value_var", "observations"), "model.rvs_to_values[rv]"),
(("transform",), "model.rvs_to_transforms[rv]"),
(("total_size",), "model.rvs_to_total_sizes[rv]"),
(("total_size",), "model.observed_rvs_to_total_sizes[rv]"),
):
if name in deprecated_names:
try:
Expand Down
6 changes: 4 additions & 2 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,9 @@ def symbolic_normalizing_constant(self):
t = self.to_flat_input(
at.max(
[
_get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim)
_get_scaling(
self.model.observed_rvs_to_total_sizes.get(v, None), v.shape, v.ndim
)
for v in self.group
]
)
Expand Down Expand Up @@ -1184,7 +1186,7 @@ def symbolic_normalizing_constant(self):
self.collect("symbolic_normalizing_constant")
+ [
_get_scaling(
self.model.rvs_to_total_sizes.get(obs, None),
self.model.observed_rvs_to_total_sizes.get(obs, None),
obs.shape,
obs.ndim,
)
Expand Down