From 72f16ee19d6e41cd94e152dbfe67a237cd13621c Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Wed, 8 Mar 2023 00:26:32 +0530 Subject: [PATCH 1/3] Adding dims parameter in Potential --- pymc/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index a9e55ea031..2d12ca745d 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -2041,7 +2041,7 @@ def Deterministic(name, var, model=None, dims=None): return var -def Potential(name, var, model=None): +def Potential(name, var, model=None, dims=None): """ Add an arbitrary factor potential to the model likelihood @@ -2135,7 +2135,7 @@ def Potential(name, var, model=None): model = modelcontext(model) var.name = model.name_for(name) model.potentials.append(var) - model.add_named_variable(var) + model.add_named_variable(var, dims) from pymc.printing import str_for_potential_or_deterministic From a942b6ea34d1d6f5f6a84cf41fb4d6ca3d99d0a9 Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Wed, 8 Mar 2023 23:42:45 +0530 Subject: [PATCH 2/3] Test case to check the dims in potential --- tests/test_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 2589e23c4c..9cef68feb5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1062,6 +1062,16 @@ def test_deterministic(): assert model["y"] == y +def test_potential_with_dims(): + """ + Test to check the passing of dims to the potential + """ + with pm.Model(coords={"observed": range(10)}) as model: + x = pm.Normal("x", 0, 1) + y = pm.Deterministic("y", x**2, dims=("observed",)) + assert model.named_vars_to_dims == {"y": ("observed",)} + + def test_empty_model_representation(): assert pm.Model().str_repr() == "" From 152626119bedbbf9d19c1ff9c0347de1d3e2b7de Mon Sep 17 00:00:00 2001 From: Raj-Parekh24 Date: Thu, 9 Mar 2023 07:30:43 +0530 Subject: [PATCH 3/3] Fixes in test caas --- tests/test_model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 9cef68feb5..a9e6aeda52 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1062,7 +1062,7 @@ def test_deterministic(): assert model["y"] == y -def test_potential_with_dims(): +def test_determinsitic_with_dims(): """ Test to check the passing of dims to the potential """ @@ -1072,6 +1072,16 @@ def test_potential_with_dims(): assert model.named_vars_to_dims == {"y": ("observed",)} +def test_potential_with_dims(): + """ + Test to check the passing of dims to the potential + """ + with pm.Model(coords={"observed": range(10)}) as model: + x = pm.Normal("x", 0, 1) + y = pm.Potential("y", x**2, dims=("observed",)) + assert model.named_vars_to_dims == {"y": ("observed",)} + + def test_empty_model_representation(): assert pm.Model().str_repr() == ""