From 4138516b16bdfdbf6e0aa28502b1d75d0a527b20 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 08:44:40 +0100 Subject: [PATCH 01/10] treat vectors on LHS as literals --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bef7d11c2..5fcae7de5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -240,7 +240,7 @@ variables. """ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation - if !(left isa Symbol || left isa Expr) + if !(left isa Symbol || left isa Expr) || Meta.isexpr(left, :vect) return quote $(DynamicPPL.tilde_observe)( __context__, @@ -290,7 +290,7 @@ Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation - if !(left isa Symbol || left isa Expr) + if !(left isa Symbol || left isa Expr) || Meta.isexpr(left, :vect) return quote $(DynamicPPL.dot_tilde_observe)( __context__, From f9ae0a898df14599b2e529f859b2abb261751f0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 08:54:37 +0100 Subject: [PATCH 02/10] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7a5ba10d..798e29b71 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.11.2" +version = "0.11.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d76f60b301fa1aad37a49f08becea054967e04d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:12:55 +0100 Subject: [PATCH 03/10] added test for array literals --- test/compiler.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..822577200 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -423,4 +423,23 @@ end x = [Laplace(), Normal(), MvNormal(3, 1.0)] @test DynamicPPL.check_tilde_rhs(x) === x end + + @testset "array literals" begin + # Verify that we indeed can parse this. + @test @model( + function array_literal_model() + # `assume` and literal `observe` + m ~ MvNormal(2, 1.0) + [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + end + ) isa Function + + @model function array_literal_model() + # `assume` and literal `observe` + m ~ MvNormal(2, 1.0) + [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + end + + @test array_literal_model()() == [10.0, 10.0] + end end From d6a1e428d1fcccdb2679721e8287d44bb904cbab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:40:44 +0100 Subject: [PATCH 04/10] simplified the isliteral check --- src/compiler.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5fcae7de5..79941f88b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -36,6 +36,15 @@ end # failsafe: a literal is never an assumption isassumption(expr) = :(false) +""" + isliteral(expr) + +Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise. +""" +isliteral(e) = false +isliteral(::Number) = true +isliteral(e::Expr) = all(isliteral, e.args) + """ check_tilde_rhs(x) @@ -240,7 +249,7 @@ variables. """ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation - if !(left isa Symbol || left isa Expr) || Meta.isexpr(left, :vect) + if isliteral(left) return quote $(DynamicPPL.tilde_observe)( __context__, @@ -290,7 +299,7 @@ Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation - if !(left isa Symbol || left isa Expr) || Meta.isexpr(left, :vect) + if isliteral(left) return quote $(DynamicPPL.dot_tilde_observe)( __context__, From 412548d91bf4d5aee2b8d45fdf3aebbc3d693799 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:41:09 +0100 Subject: [PATCH 05/10] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/compiler.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 822577200..55caaeafc 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -426,18 +426,16 @@ end @testset "array literals" begin # Verify that we indeed can parse this. - @test @model( - function array_literal_model() + @test @model(function array_literal_model() # `assume` and literal `observe` m ~ MvNormal(2, 1.0) - [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) - end - ) isa Function + return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + end) isa Function @model function array_literal_model() # `assume` and literal `observe` m ~ MvNormal(2, 1.0) - [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) end @test array_literal_model()() == [10.0, 10.0] From b993b970d19f20daacbe9026926ac489034aa15a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:43:09 +0100 Subject: [PATCH 06/10] dont allow empty literals --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 79941f88b..2e7a802bb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -43,7 +43,7 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other """ isliteral(e) = false isliteral(::Number) = true -isliteral(e::Expr) = all(isliteral, e.args) +isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) """ check_tilde_rhs(x) From 6bb5412ce8be2fdffae70c00dfe0230d7926981e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 11:36:46 +0100 Subject: [PATCH 07/10] added some tests for isliteral --- test/compiler.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 55caaeafc..68f54c930 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -423,6 +423,16 @@ end x = [Laplace(), Normal(), MvNormal(3, 1.0)] @test DynamicPPL.check_tilde_rhs(x) === x end + + @testset "isliteral" begin + @test DynamicPPL.isliteral(:([1.0, ])) + @test DynamicPPL.isliteral(:([[1.0,], 1.0])) + @test !(DynamicPPL.isliteral(:((1.0, 1.0)))) + + @test !(DynamicPPL.isliteral(:([x, ]))) + @test !(DynamicPPL.isliteral(:([[x,], 1.0]))) + @test !(DynamicPPL.isliteral(:((x, 1.0)))) + end @testset "array literals" begin # Verify that we indeed can parse this. From d328583b0921869f1dae9d890175b08d2f415d5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 11:41:54 +0100 Subject: [PATCH 08/10] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/compiler.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 68f54c930..843f34027 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -423,14 +423,13 @@ end x = [Laplace(), Normal(), MvNormal(3, 1.0)] @test DynamicPPL.check_tilde_rhs(x) === x end - @testset "isliteral" begin - @test DynamicPPL.isliteral(:([1.0, ])) - @test DynamicPPL.isliteral(:([[1.0,], 1.0])) + @test DynamicPPL.isliteral(:([1.0])) + @test DynamicPPL.isliteral(:([[1.0], 1.0])) @test !(DynamicPPL.isliteral(:((1.0, 1.0)))) - @test !(DynamicPPL.isliteral(:([x, ]))) - @test !(DynamicPPL.isliteral(:([[x,], 1.0]))) + @test !(DynamicPPL.isliteral(:([x]))) + @test !(DynamicPPL.isliteral(:([[x], 1.0]))) @test !(DynamicPPL.isliteral(:((x, 1.0)))) end From f8593cb259c755c6f543ef272fd8608ea73003be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 11:42:32 +0100 Subject: [PATCH 09/10] fixed typo in tests --- test/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index 843f34027..d4518bc64 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -426,7 +426,7 @@ end @testset "isliteral" begin @test DynamicPPL.isliteral(:([1.0])) @test DynamicPPL.isliteral(:([[1.0], 1.0])) - @test !(DynamicPPL.isliteral(:((1.0, 1.0)))) + @test DynamicPPL.isliteral(:((1.0, 1.0))) @test !(DynamicPPL.isliteral(:([x]))) @test !(DynamicPPL.isliteral(:([[x], 1.0]))) From 3810effa591ad0c50d64936f2278e07c01b0bd5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:27:13 +0100 Subject: [PATCH 10/10] bump Bijectors for tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index d43991c20..afe6ac4cc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0" AbstractPPL = "0.1.3" -Bijectors = "0.8.2, 0.9" +Bijectors = "0.9.5" Distributions = "0.24, 0.25" DistributionsAD = "0.6.3" Documenter = "0.26.1"