From 5741d6797e1e8a69c8d692a713bd2ba821c624ae Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sat, 21 Aug 2021 17:23:38 +0200 Subject: [PATCH 1/4] Get rid of repeated construction of varname lenses --- src/compiler.jl | 64 ++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 25530696d..1eb481cc7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,34 +15,30 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - +function isassumption(vn::Symbol, expr::Union{Expr,Symbol}) return quote - let $vn = $(AbstractPPL.drop_escape(varname(expr))) - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) - # Considered an assumption by `__context__` which means either: - # 1. We hit the default implementation, e.g. using `DefaultContext`, - # which in turn means that we haven't considered if it's one of - # the model arguments, hence we need to check this. - # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, - # i.e. we're trying to condition one of the latent variables. - # In this case, the below will return `true` since the first branch - # will be hit. - # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, - # i.e. we're trying to override the value. This is currently NOT supported. - # TODO: Support by adding context to model, and use `model.args` - # as the default conditioning. Then we no longer need to check `inargnames` - # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - $(maybe_view(expr)) === missing - end + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) + true else - false + $(maybe_view(expr)) === missing end + else + false end end end @@ -81,7 +77,7 @@ function contextual_isassumption(context::PrefixContext, vn) end # failsafe: a literal is never an assumption -isassumption(expr) = :(false) +isassumption(expr, vn) = :(false) # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x @@ -396,7 +392,7 @@ function generate_tilde(left, right) # more selective with our escape. Until that's the case, we remove them all. return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_tilde_assume(left, right, vn)) else @@ -417,8 +413,8 @@ function generate_tilde(left, right) end function generate_tilde_assume(left, right, vn) - expr = :( - $left = $(DynamicPPL.tilde_assume!)( + new_right = :( + $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, @@ -426,11 +422,15 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr + @gensym lens AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + quote + $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn)) + $left = $(Setfield.set)($left, $lens, $new_right) + end ) else - return expr + return :($left = $new_right) end end @@ -447,7 +447,7 @@ function generate_dot_tilde(left, right) @gensym vn isassumption return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) else From 47679d4269bd8471f26586faeb847d31e2d09294 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 30 Aug 2021 10:50:11 +0200 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde --- src/compiler.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1eb481cc7..11cd70a36 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -32,7 +32,7 @@ function isassumption(vn::Symbol, expr::Union{Expr,Symbol}) # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) true else $(maybe_view(expr)) === missing @@ -77,6 +77,7 @@ function contextual_isassumption(context::PrefixContext, vn) end # failsafe: a literal is never an assumption +isassumption(expr) = :(false) isassumption(expr, vn) = :(false) # If we're working with, say, a `Symbol`, then we're not going to `view`. @@ -413,7 +414,7 @@ function generate_tilde(left, right) end function generate_tilde_assume(left, right, vn) - new_right = :( + tilde = :( $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., @@ -426,11 +427,11 @@ function generate_tilde_assume(left, right, vn) AbstractPPL.drop_escape( quote $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn)) - $left = $(Setfield.set)($left, $lens, $new_right) - end + $left = $(Setfield.set)($left, $lens, $tilde) + end, ) else - return :($left = $new_right) + return :($left = $tilde) end end From 7f27aedf125ef1ee9c5175ae971b461fffbce13e Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 30 Aug 2021 11:21:50 +0200 Subject: [PATCH 3/4] Apply suggested edits & fix bugs --- src/compiler.jl | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 11cd70a36..90820e809 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2,7 +2,7 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__ const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) """ - isassumption(expr) + isassumption(expr, vn) Return an expression that can be evaluated to check if `expr` is an assumption in the model. @@ -15,7 +15,7 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(vn::Symbol, expr::Union{Expr,Symbol}) +function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr))) return quote if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: @@ -43,6 +43,10 @@ function isassumption(vn::Symbol, expr::Union{Expr,Symbol}) end end +# failsafe: a literal is never an assumption +isassumption(expr, vn) = :(false) +isassumption(expr) = :(false) + """ contextual_isassumption(context, vn) @@ -76,10 +80,6 @@ function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) end -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) -isassumption(expr, vn) = :(false) - # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x))) @@ -423,13 +423,14 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr + # `x[i] = ...` needs to become `x = set(x, @lens(_[i]), ...)` @gensym lens - AbstractPPL.drop_escape( - quote - $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn)) - $left = $(Setfield.set)($left, $lens, $tilde) - end, - ) + # TODO: maybe export this from AbstractPPL again... + vn_name = AbstractPPL.vsym(left) + quote + $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn)) + $vn_name = $(Setfield.set)($vn_name, $lens, $tilde) + end else return :($left = $tilde) end From 0ca51699e5db588f8924aedf27a09c36643a0d76 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 30 Aug 2021 11:25:12 +0200 Subject: [PATCH 4/4] Remove old now misleading comment --- src/compiler.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 90820e809..829b08125 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -425,7 +425,6 @@ function generate_tilde_assume(left, right, vn) return if left isa Expr # `x[i] = ...` needs to become `x = set(x, @lens(_[i]), ...)` @gensym lens - # TODO: maybe export this from AbstractPPL again... vn_name = AbstractPPL.vsym(left) quote $lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn))