Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ def Deterministic(name, var, model=None, dims=None):
return var


def Potential(name, var, model=None):
def Potential(name, var, model=None, dims=None):
"""
Add an arbitrary factor potential to the model likelihood

Expand Down Expand Up @@ -2135,7 +2135,7 @@ def Potential(name, var, model=None):
model = modelcontext(model)
var.name = model.name_for(name)
model.potentials.append(var)
model.add_named_variable(var)
model.add_named_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic

Expand Down
20 changes: 20 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,26 @@ def test_deterministic():
assert model["y"] == y


def test_determinsitic_with_dims():
"""
Test to check the passing of dims to the potential
"""
with pm.Model(coords={"observed": range(10)}) as model:
x = pm.Normal("x", 0, 1)
y = pm.Deterministic("y", x**2, dims=("observed",))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does not use Potential

Copy link
Contributor Author

@Raj-Parekh24 Raj-Parekh24 Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review @michaelosthege

Sorry for this,

I found that dims test case was not available for deterministic as well so added that test case along with potential.

I might forget to update the test case for potential.

I have updated the code.

Please review it.

assert model.named_vars_to_dims == {"y": ("observed",)}


def test_potential_with_dims():
"""
Test to check the passing of dims to the potential
"""
with pm.Model(coords={"observed": range(10)}) as model:
x = pm.Normal("x", 0, 1)
y = pm.Potential("y", x**2, dims=("observed",))
assert model.named_vars_to_dims == {"y": ("observed",)}


def test_empty_model_representation():
assert pm.Model().str_repr() == ""

Expand Down