Skip to content

Commit 5468499

Browse files
committed
Only add total_size mapping to observed RVs
* Add direct test
* Add checks for when it's not applicable
* Rename mapping to `observed_rvs_to_total_sizes`
1 parent eb16ce6 commit 5468499

File tree

4 files changed

+76
-32
lines changed

4 files changed

+76
-32
lines changed

pymc/model.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,10 @@ def __init__(
556556
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
557557
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
558558
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
559-
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
560559
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
560+
self.observed_rvs_to_total_sizes = treedict(
561+
parent=self.parent.observed_rvs_to_total_sizes
562+
)
561563
self.free_RVs = treelist(parent=self.parent.free_RVs)
562564
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
563565
self.deterministics = treelist(parent=self.parent.deterministics)
@@ -570,8 +572,8 @@ def __init__(
570572
self.values_to_rvs = treedict()
571573
self.rvs_to_values = treedict()
572574
self.rvs_to_transforms = treedict()
573-
self.rvs_to_total_sizes = treedict()
574575
self.rvs_to_initial_values = treedict()
576+
self.observed_rvs_to_total_sizes = treedict()
575577
self.free_RVs = treelist()
576578
self.observed_RVs = treelist()
577579
self.deterministics = treelist()
@@ -751,7 +753,7 @@ def logp(
751753
rvs=rvs,
752754
rvs_to_values=self.rvs_to_values,
753755
rvs_to_transforms=self.rvs_to_transforms,
754-
rvs_to_total_sizes=self.rvs_to_total_sizes,
756+
rvs_to_total_sizes=self.observed_rvs_to_total_sizes,
755757
jacobian=jacobian,
756758
)
757759
assert isinstance(rv_logps, list)
@@ -1289,8 +1291,6 @@ def register_rv(
12891291
name = self.name_for(name)
12901292
rv_var.name = name
12911293
_add_future_warning_tag(rv_var)
1292-
rv_var.tag.total_size = total_size
1293-
self.rvs_to_total_sizes[rv_var] = total_size
12941294

12951295
# Associate previously unknown dimension names with
12961296
# the length of the corresponding RV dimension.
@@ -1300,6 +1300,8 @@ def register_rv(
13001300
self.add_coord(dname, values=None, length=rv_var.shape[d])
13011301

13021302
if observed is None:
1303+
if total_size is not None:
1304+
raise ValueError("total_size can only be used for observed RVs")
13031305
self.free_RVs.append(rv_var)
13041306
self.create_value_var(rv_var, transform)
13051307
self.add_named_variable(rv_var, dims)
@@ -1323,12 +1325,20 @@ def register_rv(
13231325

13241326
# `rv_var` is potentially changed by `make_obs_var`,
13251327
# for example into a new graph for imputation of missing data.
1326-
rv_var = self.make_obs_var(rv_var, observed, dims, transform)
1328+
rv_var = self.make_obs_var(
1329+
rv_var, observed, total_size=total_size, dims=dims, transform=transform
1330+
)
13271331

13281332
return rv_var
13291333

13301334
def make_obs_var(
1331-
self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any]
1335+
self,
1336+
rv_var: TensorVariable,
1337+
data: np.ndarray,
1338+
*,
1339+
total_size,
1340+
dims,
1341+
transform: Optional[Any],
13321342
) -> TensorVariable:
13331343
"""Create a `TensorVariable` for an observed random variable.
13341344
@@ -1364,19 +1374,16 @@ def make_obs_var(
13641374

13651375
mask = getattr(data, "mask", None)
13661376
if mask is not None:
1367-
1368-
if mask.all():
1369-
# If there are no observed values, this variable isn't really
1370-
# observed.
1371-
return rv_var
1372-
13731377
impute_message = (
13741378
f"Data in {rv_var} contains missing values and"
13751379
" will be automatically imputed from the"
13761380
" sampling distribution."
13771381
)
13781382
warnings.warn(impute_message, ImputationWarning)
13791383

1384+
if total_size is not None:
1385+
raise NotImplementedError("total_size cannot be used with imputation")
1386+
13801387
if not isinstance(rv_var.owner.op, RandomVariable):
13811388
raise NotImplementedError(
13821389
"Automatic inputation is only supported for univariate RandomVariables."
@@ -1431,6 +1438,8 @@ def make_obs_var(
14311438
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
14321439
self.add_named_variable(observed_rv_var)
14331440
self.observed_RVs.append(observed_rv_var)
1441+
self.observed_rvs_to_total_sizes[observed_rv_var] = None
1442+
observed_rv_var.tag.total_size = None
14341443

14351444
# Create deterministic that combines observed and missing
14361445
# Note: This can widely increase memory consumption during sampling for large datasets
@@ -1448,6 +1457,8 @@ def make_obs_var(
14481457
self.create_value_var(rv_var, transform=None, value_var=data)
14491458
self.add_named_variable(rv_var, dims)
14501459
self.observed_RVs.append(rv_var)
1460+
self.observed_rvs_to_total_sizes[rv_var] = total_size
1461+
rv_var.tag.total_size = total_size
14511462

14521463
return rv_var
14531464

pymc/tests/test_model.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def test_model_value_vars():
500500
def test_model_var_maps():
501501
with pm.Model() as model:
502502
a = pm.Uniform("a")
503-
x = pm.Normal("x", a, total_size=5)
503+
x = pm.Normal("x", a)
504504

505505
assert set(model.rvs_to_values.keys()) == {a, x}
506506
a_value = model.rvs_to_values[a]
@@ -512,10 +512,7 @@ def test_model_var_maps():
512512
assert set(model.rvs_to_transforms.keys()) == {a, x}
513513
assert isinstance(model.rvs_to_transforms[a], IntervalTransform)
514514
assert model.rvs_to_transforms[x] is None
515-
516-
assert set(model.rvs_to_total_sizes.keys()) == {a, x}
517-
assert model.rvs_to_total_sizes[a] is None
518-
assert model.rvs_to_total_sizes[x] == 5
515+
assert model.observed_rvs_to_total_sizes == {}
519516

520517

521518
def test_make_obs_var():
@@ -538,27 +535,28 @@ def test_make_obs_var():
538535
# Create the testval attribute simply for the sake of model testing
539536
fake_distribution.name = input_name
540537

538+
kwargs = dict(total_size=None, dims=None, transform=None)
541539
# The function requires data and RV dimensionality to be compatible
542540
with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."):
543-
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None)
541+
fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), **kwargs)
544542

545543
# Check function behavior using the various inputs
546544
# dense, sparse: Ensure that the missing values are appropriately set to None
547545
# masked: a deterministic variable is returned
548546

549-
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None)
547+
dense_output = fake_model.make_obs_var(fake_distribution, dense_input, **kwargs)
550548
assert dense_output == fake_distribution
551549
assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant)
552550
del fake_model.named_vars[fake_distribution.name]
553551

554-
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None)
552+
sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, **kwargs)
555553
assert sparse_output == fake_distribution
556554
assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output])
557555
del fake_model.named_vars[fake_distribution.name]
558556

559557
# Here the RandomVariable is split into observed/imputed and a Deterministic is returned
560558
with pytest.warns(ImputationWarning):
561-
masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, None, None)
559+
masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, **kwargs)
562560
assert masked_output != fake_distribution
563561
assert not isinstance(masked_output, RandomVariable)
564562
# Ensure it has missing values
@@ -701,6 +699,29 @@ def test_set_initval():
701699
assert y in model.initial_values
702700

703701

702+
class TestTotalSize:
703+
def test_total_size_univariate(self):
704+
with pm.Model() as m:
705+
x = pm.Normal("x", observed=[0, 0], total_size=7)
706+
assert m.observed_rvs_to_total_sizes[x] == 7
707+
708+
m.compile_logp()({}) == st.norm().logpdf(0) * 7
709+
710+
def test_total_size_multivariate(self):
711+
with pm.Model() as m:
712+
x = pm.MvNormal("x", np.ones(3), np.eye(3), observed=np.zeros((2, 3)), total_size=7)
713+
assert m.observed_rvs_to_total_sizes[x] == 7
714+
715+
m.compile_logp()({}) == st.multivariate_normal.logpdf(
716+
np.zeros(3), np.ones(3), np.eye(3)
717+
) * 7
718+
719+
def test_total_size_error(self):
720+
with pm.Model():
721+
with pytest.raises(ValueError, match="total_size can only be used for observed RVs"):
722+
pm.Normal("x", total_size=7)
723+
724+
704725
def test_datalogp_multiple_shapes():
705726
with pm.Model() as m:
706727
x = pm.Normal("x", 0, 1)
@@ -1425,6 +1446,18 @@ def test_error_non_random_variable(self):
14251446
observed=data,
14261447
)
14271448

1449+
def test_rvs_to_total_sizes(self):
1450+
with pm.Model() as m:
1451+
x = pm.Normal("x", observed=[np.nan, 0, 1])
1452+
assert m["x"] not in m.observed_rvs_to_total_sizes
1453+
assert m["x_missing"] not in m.observed_rvs_to_total_sizes
1454+
assert m.observed_rvs_to_total_sizes[m["x_observed"]] is None
1455+
1456+
def test_total_size_not_supported(self):
1457+
with pm.Model() as m:
1458+
with pytest.raises(NotImplementedError):
1459+
x = pm.Normal("x", observed=[np.nan, 0, 1], total_size=5)
1460+
14281461

14291462
class TestShared(SeededTest):
14301463
def test_deterministic(self):
@@ -1467,16 +1500,15 @@ def test_tag_future_warning_model():
14671500
with pytest.raises(AttributeError):
14681501
x.tag.observations
14691502

1470-
with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
1503+
with pytest.raises(AttributeError):
14711504
total_size = x.tag.total_size
1472-
assert total_size is None
14731505

14741506
# Cloning a node will keep the same tag type and contents
14751507
y = x.owner.clone().default_output()
14761508
assert y is not x
14771509
assert y.tag is not x.tag
14781510
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)
1479-
y = model.register_rv(y, name="y", observed=5)
1511+
y = model.register_rv(y, name="y", observed=5, total_size=7)
14801512
assert isinstance(y.tag, _FutureWarningValidatingScratchpad)
14811513

14821514
# Test expected warnings
@@ -1486,8 +1518,7 @@ def test_tag_future_warning_model():
14861518
y_obs = y.tag.observations
14871519
assert y_value is y_obs
14881520
assert y_value.eval() == 5
1489-
1521+
with pytest.warns(FutureWarning, match="model.observed_rvs_to_total_sizes"):
1522+
y_total_size = y.tag.total_size
1523+
assert y_total_size == 7
14901524
assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)
1491-
with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"):
1492-
total_size = y.tag.total_size
1493-
assert total_size is None

pymc/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def __getattribute__(self, name):
490490
for deprecated_names, alternative in (
491491
(("value_var", "observations"), "model.rvs_to_values[rv]"),
492492
(("transform",), "model.rvs_to_transforms[rv]"),
493-
(("total_size",), "model.rvs_to_total_sizes[rv]"),
493+
(("total_size",), "model.observed_rvs_to_total_sizes[rv]"),
494494
):
495495
if name in deprecated_names:
496496
try:

pymc/variational/opvi.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,9 @@ def symbolic_normalizing_constant(self):
10461046
t = self.to_flat_input(
10471047
at.max(
10481048
[
1049-
_get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim)
1049+
_get_scaling(
1050+
self.model.observed_rvs_to_total_sizes.get(v, None), v.shape, v.ndim
1051+
)
10501052
for v in self.group
10511053
]
10521054
)
@@ -1184,7 +1186,7 @@ def symbolic_normalizing_constant(self):
11841186
self.collect("symbolic_normalizing_constant")
11851187
+ [
11861188
_get_scaling(
1187-
self.model.rvs_to_total_sizes.get(obs, None),
1189+
self.model.observed_rvs_to_total_sizes.get(obs, None),
11881190
obs.shape,
11891191
obs.ndim,
11901192
)

0 commit comments

Comments
 (0)