Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.9"
version = "0.10.10"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
33 changes: 16 additions & 17 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
macro model(expr, warn=true)
# include `LineNumberNode` with information about the call site in the
# generated function for easier debugging and interpretation of error messages
esc(model(expr, __source__, warn))
esc(model(__module__, __source__, expr, warn))
end

function model(expr, linenumbernode, warn)
function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:body] = generate_mainbody(
modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
)

return build_output(modelinfo, linenumbernode)
Expand Down Expand Up @@ -155,53 +155,52 @@ function build_model_info(input_expr)
end

"""
generate_mainbody(expr, args, warn)
generate_mainbody(mod, expr, args, warn)

Generate the body of the main evaluation function from expression `expr` and arguments
`args`.

If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(expr, args, warn) = generate_mainbody!(Symbol[], expr, args, warn)
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)

generate_mainbody!(found, x, args, warn) = x
function generate_mainbody!(found, sym::Symbol, args, warn)
generate_mainbody!(mod, found, x, args, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$(sym)`"
push!(found, sym)
end
return sym
end
function generate_mainbody!(found, expr::Expr, args, warn)
function generate_mainbody!(mod, found, expr::Expr, args, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Apply the `@.` macro first.
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
expr.args[1] === Symbol("@__dot__")
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_dot_tilde(generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return generate_tilde(generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
end

return Expr(expr.head, map(x -> generate_mainbody!(found, x, args, warn), expr.args)...)
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
end


Expand Down
44 changes: 44 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@ macro custom(expr)
end
end

macro mymodel1(ex)
# check if expression was modified by the DynamicPPL "compiler"
if ex == :(y ~ Uniform())
return esc(:(x ~ Normal()))
else
return esc(:(z ~ Exponential()))
end
end

struct MyModelStruct{T}
x::T
end
Base.:~(x, y::MyModelStruct) = y.x
macro mymodel2(ex)
# check if expression was modified by the DynamicPPL "compiler"
if ex == :(y ~ Uniform())
# Just returns 42
return :(4 ~ MyModelStruct(42))
else
return :(return -1)
end
end

@testset "compiler.jl" begin
@testset "model macro" begin
@model function testmodel_comp(x, y)
Expand Down Expand Up @@ -269,4 +292,25 @@ end
end
@test isempty(VarInfo(demo_with(0.0)))
end

@testset "macros within model" begin
# Macro expansion
@model function demo()
@mymodel1(y ~ Uniform())
end

@test haskey(VarInfo(demo()), @varname(x))

# Interpolation
# Will fail if:
# 1. Compiler expands `y ~ Uniform()` before expanding the macros
# => returns -1.
# 2. `@mymodel` is expanded before entire `@model` has been
# expanded => errors since `MyModelStruct` is not a distribution,
# and hence `tilde_observe` errors.
@model function demo()
$(@mymodel2(y ~ Uniform()))
end
@test demo()() == 42
end
end