From 2e1714f103785abb41e1d6465de41eb58ad97ae7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Jul 2021 23:58:22 +0100 Subject: [PATCH 1/4] added the possibility of forcing a tilde-statement to be an observe-statement --- src/DynamicPPL.jl | 3 ++- src/compiler.jl | 59 +++++++++++++++++++++++++---------------------- src/utils.jl | 38 ++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..d7e87f50d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -101,7 +101,8 @@ export AbstractVarInfo, pointwise_loglikelihoods, # Convenience macros @addlogprob!, - @submodel + @submodel, + @observe # Reexport using Distributions: loglikelihood diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..caae37b02 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -233,10 +233,10 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +generate_mainbody(mod, expr, warn; kwargs...) = generate_mainbody!(mod, Symbol[], expr, warn; kwargs...) -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found, x, warn; kwargs...) = x +function generate_mainbody!(mod, found, sym::Symbol, warn; kwargs...) if sym in DEPRECATED_INTERNALNAMES newsym = Symbol(:_, sym, :__) Base.depwarn( @@ -253,13 +253,13 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn; tilde_kwargs...) end # Modify dotted tilde operators. @@ -268,8 +268,9 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_dottilde return Base.remove_linenums!( generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn; tilde_kwargs...), + generate_mainbody!(mod, found, R, warn; tilde_kwargs...); + tilde_kwargs... ), ) end @@ -280,13 +281,27 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn; tilde_kwargs...), + generate_mainbody!(mod, found, R, warn; tilde_kwargs...); + tilde_kwargs... ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn; tilde_kwargs...), expr.args)...) +end + +""" + generate_tilde(left, right) + +Generate an `observe` expression for literals, e.g. `1.0` and `[1.0, ]`. +""" +function generate_tilde_literal(left, right) + return quote + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + ) + end end """ @@ -295,15 +310,9 @@ end Generate an `observe` expression for data variables and `assume` expression for parameter variables. """ -function generate_tilde(left, right) +function generate_tilde(left, right; force_observe=false) # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation @@ -311,7 +320,7 @@ function generate_tilde(left, right) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(force_observe ? false : DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( __context__, @@ -339,15 +348,9 @@ end Generate the expression that replaces `left .~ right` in the model body. """ -function generate_dot_tilde(left, right) +function generate_dot_tilde(left, right; force_observe=true) # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation @@ -355,7 +358,7 @@ function generate_dot_tilde(left, right) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(force_observe ? false : DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..62f5927e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,6 +13,44 @@ macro addlogprob!(ex) end end +""" + @observe x ~ dist + @observe x .~ dist + +Force this `~` statement to always be an observe-statement. + +Only usable within the body of [@model](@ref). + +# Examples + +```jldoctest +julia> @model function demo() + x = 1.0 + @observe x ~ Normal() + + return getlogp(__varinfo__) + end +demo (generic function with 1 method) + +julia> demo()() == logpdf(Normal(), 1.0) +true + +julia> @model function demo() + x = [1.0, ] + @observe x .~ Normal() + + return getlogp(__varinfo__) + end +demo (generic function with 1 method) + +julia> demo()() == logpdf(Normal(), 1.0) +true +``` +""" +macro observe(ex) + return esc(generate_mainbody(__module__, ex, false; force_observe=true)) +end + """ getargs_dottilde(x) From 22a053112889a5d09d57fa09191a2fbabed1b484 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jul 2021 00:31:09 +0100 Subject: [PATCH 2/4] allow left to be missing even when forcing observe --- src/compiler.jl | 29 ++++++++++++++++++----------- src/utils.jl | 21 +++++++++++++++------ 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index caae37b02..b7146bded 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,14 +15,14 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol,Expr}) +function isassumption(expr::Union{Symbol,Expr}; check_inargs=true) vn = gensym(:vn) return quote let $vn = $(varname(expr)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || + if ($check_inargs && !$(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) true else @@ -233,7 +233,9 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn; kwargs...) = generate_mainbody!(mod, Symbol[], expr, warn; kwargs...) +function generate_mainbody(mod, expr, warn; kwargs...) + return generate_mainbody!(mod, Symbol[], expr, warn; kwargs...) +end generate_mainbody!(mod, found, x, warn; kwargs...) = x function generate_mainbody!(mod, found, sym::Symbol, warn; kwargs...) @@ -259,7 +261,9 @@ function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn; tilde_kwargs...) + return generate_mainbody!( + mod, found, macroexpand(mod, expr; recursive=true), warn; tilde_kwargs... + ) end # Modify dotted tilde operators. @@ -270,7 +274,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...) generate_dot_tilde( generate_mainbody!(mod, found, L, warn; tilde_kwargs...), generate_mainbody!(mod, found, R, warn; tilde_kwargs...); - tilde_kwargs... + tilde_kwargs..., ), ) end @@ -283,12 +287,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...) generate_tilde( generate_mainbody!(mod, found, L, warn; tilde_kwargs...), generate_mainbody!(mod, found, R, warn; tilde_kwargs...); - tilde_kwargs... + tilde_kwargs..., ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn; tilde_kwargs...), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found, x, warn; tilde_kwargs...), expr.args)..., + ) end """ @@ -310,7 +317,7 @@ end Generate an `observe` expression for data variables and `assume` expression for parameter variables. """ -function generate_tilde(left, right; force_observe=false) +function generate_tilde(left, right; check_inargs=true) # If the LHS is a literal, it is always an observation isliteral(left) && return generate_tilde_literal(left, right) @@ -320,7 +327,7 @@ function generate_tilde(left, right; force_observe=false) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(force_observe ? false : DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( __context__, @@ -348,7 +355,7 @@ end Generate the expression that replaces `left .~ right` in the model body. """ -function generate_dot_tilde(left, right; force_observe=true) +function generate_dot_tilde(left, right; check_inargs=true) # If the LHS is a literal, it is always an observation isliteral(left) && return generate_tilde_literal(left, right) @@ -358,7 +365,7 @@ function generate_dot_tilde(left, right; force_observe=true) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(force_observe ? false : DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, diff --git a/src/utils.jl b/src/utils.jl index 62f5927e0..7b1aeda4f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -17,7 +17,8 @@ end @observe x ~ dist @observe x .~ dist -Force this `~` statement to always be an observe-statement. +Force this `~` statement to be an observe-statement unless `x` is `missing`, +effectively circumventing the check as to whether `x` is the arguments. Only usable within the body of [@model](@ref). @@ -29,8 +30,7 @@ julia> @model function demo() @observe x ~ Normal() return getlogp(__varinfo__) - end -demo (generic function with 1 method) + end; julia> demo()() == logpdf(Normal(), 1.0) true @@ -40,15 +40,24 @@ julia> @model function demo() @observe x .~ Normal() return getlogp(__varinfo__) - end -demo (generic function with 1 method) + end; julia> demo()() == logpdf(Normal(), 1.0) true + +julia> @model function demo(args) + x = args.x + @observe x ~ Normal() + + return getlogp(__varinfo__) + end; + +julia> demo((x = 1.0, ))() == logpdf(Normal(), 1.0) +true ``` """ macro observe(ex) - return esc(generate_mainbody(__module__, ex, false; force_observe=true)) + return esc(generate_mainbody(__module__, ex, false; check_inargs=false)) end """ From 8a341d6f8e7e9707e59497e18cdbecfe14412fb4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 11:47:44 +0100 Subject: [PATCH 3/4] fix doctests --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 7b1aeda4f..4729fcacb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,7 +24,7 @@ Only usable within the body of [@model](@ref). # Examples -```jldoctest +```jldoctest; setup = :(using Distributions) julia> @model function demo() x = 1.0 @observe x ~ Normal() From b822fbd3d6e4d8e9e37bfeebc1b5a5653c754bb3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 12:52:33 +0100 Subject: [PATCH 4/4] added another doctest to observe-macro --- src/utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 4729fcacb..80f6674e1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -54,6 +54,9 @@ julia> @model function demo(args) julia> demo((x = 1.0, ))() == logpdf(Normal(), 1.0) true + +julia> VarInfo(demo((x = missing, )))[@varname(x)] !== missing +true ``` """ macro observe(ex)