From 23a72bc9a454398cd819d4ecfb164b6ea12948d7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 1 Apr 2021 10:41:47 +0200 Subject: [PATCH 01/11] dont touch input of macros unless dot --- src/compiler.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 73ae707c5..c5888ea4f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -177,10 +177,15 @@ function generate_mainbody!(found, expr::Expr, args, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] - # Apply the `@.` macro first. - if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 && - expr.args[1] === Symbol("@__dot__") - return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn) + # Don't touch input of macros, unless it's dot + if Meta.isexpr(expr, :macrocall) + # Apply the `@.` macro first. + if length(expr.args) > 1 && + expr.args[1] === Symbol("@__dot__") + return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn) + else + return expr + end end # Modify dotted tilde operators. From 41cfe7e6c00c7ac0ef92f4725050f53d5bcaf825 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 1 Apr 2021 10:50:47 +0200 Subject: [PATCH 02/11] fixed comment --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index c5888ea4f..4d56c62ba 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -177,7 +177,7 @@ function generate_mainbody!(found, expr::Expr, args, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] - # Don't touch input of macros, unless it's dot + # Don't touch input of macros, unless it's dot. if Meta.isexpr(expr, :macrocall) # Apply the `@.` macro first. if length(expr.args) > 1 && From c420d83ea7b47bbef29cc6844f5f122835bd2d53 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 1 Apr 2021 12:47:10 +0200 Subject: [PATCH 03/11] macros will now be expanded --- src/compiler.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4d56c62ba..4e3aad282 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -64,15 +64,15 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`. macro model(expr, warn=true) # include `LineNumberNode` with information about the call site in the # generated function for easier debugging and interpretation of error messages - esc(model(expr, __source__, warn)) + esc(model(expr, __source__, warn, __module__)) end -function model(expr, linenumbernode, warn) +function model(expr, linenumbernode, warn, mod) modelinfo = build_model_info(expr) # Generate main body modelinfo[:body] = generate_mainbody( - modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn + modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn, mod ) return build_output(modelinfo, linenumbernode) @@ -163,17 +163,17 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn) +generate_mainbody(expr, args, warn, mod) = generate_mainbody!(Symbol[], expr, args, warn, mod) -generate_mainbody!(found, x, args, warn) = x -function generate_mainbody!(found, sym::Symbol, args, warn) +generate_mainbody!(found, x, args, warn, mod) = x +function generate_mainbody!(found, sym::Symbol, args, warn, mod) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$(sym)`" push!(found, sym) end return sym end -function generate_mainbody!(found, expr::Expr, args, warn) +function generate_mainbody!(found, expr::Expr, args, warn, mod) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] @@ -182,9 +182,9 @@ function generate_mainbody!(found, expr::Expr, args, warn) # Apply the `@.` macro first. if length(expr.args) > 1 && expr.args[1] === Symbol("@__dot__") - return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn) + return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn, mod) else - return expr + return generate_mainbody!(found, macroexpand(mod, expr; recursive=true), args, warn, mod) end end @@ -192,8 +192,8 @@ function generate_mainbody!(found, expr::Expr, args, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return generate_dot_tilde(generate_mainbody!(found, L, args, warn), - generate_mainbody!(found, R, args, warn), + return generate_dot_tilde(generate_mainbody!(found, L, args, warn, mod), + generate_mainbody!(found, R, args, warn, mod), args) |> Base.remove_linenums! end @@ -201,12 +201,12 @@ function generate_mainbody!(found, expr::Expr, args, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - return generate_tilde(generate_mainbody!(found, L, args, warn), - generate_mainbody!(found, R, args, warn), + return generate_tilde(generate_mainbody!(found, L, args, warn, mod), + generate_mainbody!(found, R, args, warn, mod), args) |> Base.remove_linenums! end - return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...) + return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn, mod), expr.args)...) end From 5912bc14fa4f156679ba94d0a3a39c6fa1ecc786 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 1 Apr 2021 12:54:31 +0200 Subject: [PATCH 04/11] removed now redundant handling of dot-macro --- src/compiler.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4e3aad282..0820d909d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -177,15 +177,9 @@ function generate_mainbody!(found, expr::Expr, args, warn, mod) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] - # Don't touch input of macros, unless it's dot. + # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - # Apply the `@.` macro first. - if length(expr.args) > 1 && - expr.args[1] === Symbol("@__dot__") - return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn, mod) - else - return generate_mainbody!(found, macroexpand(mod, expr; recursive=true), args, warn, mod) - end + return generate_mainbody!(found, macroexpand(mod, expr; recursive=true), args, warn, mod) end # Modify dotted tilde operators. From 4247a4e8ec522a619e562d18c7f3793964b0e460 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 3 Apr 2021 23:51:27 +0200 Subject: [PATCH 05/11] added tests thanks to @devmotion --- test/compiler.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 55ad0c706..44cf17b90 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -269,4 +269,29 @@ end end @test isempty(VarInfo(demo_with(0.0))) end + + @testset "macros within model" begin + # Macro expansion + macro mymodel() + return esc(:(x ~ Normal())) + end + + @model function demo() + @mymodel() + end + + @test haskey(VarInfo(demo()), @varname(x)) + + # Interpolation + macro mymodel() + return esc(:(return 42)) + end + + @model function demo() + x ~ Normal() + $(@mymodel()) + end + + @test demo()() == 42 + end end From a3c7f801690b1a9d2f952f401c793cc3b8df567a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 3 Apr 2021 23:52:48 +0200 Subject: [PATCH 06/11] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 232c03392..597246535 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.9" +version = "0.10.10" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 33457365feaa4b5f11c3c93e2da71a90c2979052 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 4 Apr 2021 11:35:38 +0200 Subject: [PATCH 07/11] moved the module and source argument as discussed --- src/compiler.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0820d909d..1de89ef3d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -64,15 +64,15 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`. macro model(expr, warn=true) # include `LineNumberNode` with information about the call site in the # generated function for easier debugging and interpretation of error messages - esc(model(expr, __source__, warn, __module__)) + esc(model(__module__, __source__, expr, warn)) end -function model(expr, linenumbernode, warn, mod) +function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) # Generate main body modelinfo[:body] = generate_mainbody( - modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn, mod + mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn ) return build_output(modelinfo, linenumbernode) @@ -155,7 +155,7 @@ function build_model_info(input_expr) end """ - generate_mainbody(expr, args, warn) + generate_mainbody(mod, expr, args, warn) Generate the body of the main evaluation function from expression `expr` and arguments `args`. @@ -163,31 +163,31 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(expr, args, warn, mod) = generate_mainbody!(Symbol[], expr, args, warn, mod) +generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn) -generate_mainbody!(found, x, args, warn, mod) = x -function generate_mainbody!(found, sym::Symbol, args, warn, mod) +generate_mainbody!(mod, found, x, args, warn) = x +function generate_mainbody!(mod, found, sym::Symbol, args, warn) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$(sym)`" push!(found, sym) end return sym end -function generate_mainbody!(found, expr::Expr, args, warn, mod) +function generate_mainbody!(mod, found, expr::Expr, args, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(found, macroexpand(mod, expr; recursive=true), args, warn, mod) + return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn) end # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return generate_dot_tilde(generate_mainbody!(found, L, args, warn, mod), - generate_mainbody!(found, R, args, warn, mod), + return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn), + generate_mainbody!(mod, found, R, args, warn), args) |> Base.remove_linenums! end @@ -195,12 +195,12 @@ function generate_mainbody!(found, expr::Expr, args, warn, mod) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - return generate_tilde(generate_mainbody!(found, L, args, warn, mod), - generate_mainbody!(found, R, args, warn, mod), + return generate_tilde(generate_mainbody!(mod, found, L, args, warn), + generate_mainbody!(mod, found, R, args, warn), args) |> Base.remove_linenums! end - return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn, mod), expr.args)...) + return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...) end From 1406cd07ac7e50e0b69082e30c64bf65be9a40d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 4 Apr 2021 11:37:08 +0200 Subject: [PATCH 08/11] Update test/compiler.jl Co-authored-by: David Widmann --- test/compiler.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 44cf17b90..1e72f7e3a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -272,12 +272,17 @@ end @testset "macros within model" begin # Macro expansion - macro mymodel() - return esc(:(x ~ Normal())) + macro mymodel(ex) + # check if expression was modified by the DynamicPPL "compiler" + if ex == :(y ~ Uniform()) + return esc(:(x ~ Normal())) + else + return esc(:(z ~ Exponential())) + end end @model function demo() - @mymodel() + @mymodel(y ~ Uniform()) end @test haskey(VarInfo(demo()), @varname(x)) From 6bb810424ec242102eeeda71d1f491ccb5248ecd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 4 Apr 2021 11:49:03 +0200 Subject: [PATCH 09/11] improved testing according to @devmotion suggestion --- test/compiler.jl | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 1e72f7e3a..45eb40c82 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -275,10 +275,10 @@ end macro mymodel(ex) # check if expression was modified by the DynamicPPL "compiler" if ex == :(y ~ Uniform()) - return esc(:(x ~ Normal())) - else - return esc(:(z ~ Exponential())) - end + return esc(:(x ~ Normal())) + else + return esc(:(z ~ Exponential())) + end end @model function demo() @@ -288,15 +288,28 @@ end @test haskey(VarInfo(demo()), @varname(x)) # Interpolation - macro mymodel() - return esc(:(return 42)) + # Will fail if: + # 1. Compiler expands `y ~ Uniform()` before expanding the macros + # => returns -1. + # 2. `@mymodel` is expanded before entire `@model` has been + # expanded => errors since `MyModelStruct` is not a distribution, + # and hence `tilde_observe` errors. + struct MyModelStruct{T} + x::T + end + Base.:~(x, y::MyModelStruct) = y.x + macro mymodel(ex) + # check if expression was modified by the DynamicPPL "compiler" + if ex == :(y ~ Uniform()) + # Just returns 42 + return :(4 ~ MyModelStruct(42)) + else + return :(return -1) + end end - @model function demo() - x ~ Normal() - $(@mymodel()) + $(@mymodel(y ~ Uniform())) end - @test demo()() == 42 end end From b9a4840d9df91b1d30de09be4ad979a21326c8e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 4 Apr 2021 11:59:04 +0200 Subject: [PATCH 10/11] fixed indentation --- test/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 45eb40c82..da065da53 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -275,10 +275,10 @@ end macro mymodel(ex) # check if expression was modified by the DynamicPPL "compiler" if ex == :(y ~ Uniform()) - return esc(:(x ~ Normal())) - else - return esc(:(z ~ Exponential())) - end + return esc(:(x ~ Normal())) + else + return esc(:(z ~ Exponential())) + end end @model function demo() From 57d7fa8147a61f91753b79aa2916794ef00f1c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 4 Apr 2021 12:20:21 +0200 Subject: [PATCH 11/11] fixed tests for macro in model --- test/compiler.jl | 49 ++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index da065da53..1fffd50ce 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -6,6 +6,29 @@ macro custom(expr) end end +macro mymodel1(ex) + # check if expression was modified by the DynamicPPL "compiler" + if ex == :(y ~ Uniform()) + return esc(:(x ~ Normal())) + else + return esc(:(z ~ Exponential())) + end +end + +struct MyModelStruct{T} + x::T +end +Base.:~(x, y::MyModelStruct) = y.x +macro mymodel2(ex) + # check if expression was modified by the DynamicPPL "compiler" + if ex == :(y ~ Uniform()) + # Just returns 42 + return :(4 ~ MyModelStruct(42)) + else + return :(return -1) + end +end + @testset "compiler.jl" begin @testset "model macro" begin @model function testmodel_comp(x, y) @@ -272,17 +295,8 @@ end @testset "macros within model" begin # Macro expansion - macro mymodel(ex) - # check if expression was modified by the DynamicPPL "compiler" - if ex == :(y ~ Uniform()) - return esc(:(x ~ Normal())) - else - return esc(:(z ~ Exponential())) - end - end - @model function demo() - @mymodel(y ~ Uniform()) + @mymodel1(y ~ Uniform()) end @test haskey(VarInfo(demo()), @varname(x)) @@ -294,21 +308,8 @@ end # 2. `@mymodel` is expanded before entire `@model` has been # expanded => errors since `MyModelStruct` is not a distribution, # and hence `tilde_observe` errors. - struct MyModelStruct{T} - x::T - end - Base.:~(x, y::MyModelStruct) = y.x - macro mymodel(ex) - # check if expression was modified by the DynamicPPL "compiler" - if ex == :(y ~ Uniform()) - # Just returns 42 - return :(4 ~ MyModelStruct(42)) - else - return :(return -1) - end - end @model function demo() - $(@mymodel(y ~ Uniform())) + $(@mymodel2(y ~ Uniform())) end @test demo()() == 42 end