Skip to content

Commit bfec9c6

Browse files
authored
Try #221:
2 parents 068e5d3 + f2c180d commit bfec9c6

File tree

3 files changed

+43
-57
lines changed

3 files changed

+43
-57
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- uses: actions/checkout@v2
2525
- uses: julia-actions/setup-julia@v1
2626
with:
27-
version: 1
27+
version: 1.5
2828
arch: x64
2929
- uses: julia-actions/julia-buildpkg@latest
3030
- name: Clone Downstream

src/compiler.jl

Lines changed: 38 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn)
7272

7373
# Generate main body
7474
modelinfo[:body] = generate_mainbody(
75-
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
75+
mod, modelinfo[:modeldef][:body], warn
7676
)
7777

7878
return build_output(modelinfo, linenumbernode)
@@ -155,92 +155,84 @@ function build_model_info(input_expr)
155155
end
156156

157157
"""
158-
generate_mainbody(mod, expr, args, warn)
158+
generate_mainbody(mod, expr, warn)
159159
160160
Generate the body of the main evaluation function from expression `expr` and arguments
161161
`args`.
162162
163163
If `warn` is true, a warning is displayed if internal variables are used in the model
164164
definition.
165165
"""
166-
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)
166+
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
167167

168-
generate_mainbody!(mod, found, x, args, warn) = x
169-
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
168+
generate_mainbody!(mod, found, x, warn) = x
169+
function generate_mainbody!(mod, found, sym::Symbol, warn)
170170
if warn && sym in INTERNALNAMES && sym found
171171
@warn "you are using the internal variable `$(sym)`"
172172
push!(found, sym)
173173
end
174174
return sym
175175
end
176-
function generate_mainbody!(mod, found, expr::Expr, args, warn)
176+
function generate_mainbody!(mod, found, expr::Expr, warn)
177177
# Do not touch interpolated expressions
178178
expr.head === :$ && return expr.args[1]
179179

180180
# If it's a macro, we expand it
181181
if Meta.isexpr(expr, :macrocall)
182-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
182+
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
183183
end
184184

185185
# Modify dotted tilde operators.
186186
args_dottilde = getargs_dottilde(expr)
187187
if args_dottilde !== nothing
188188
L, R = args_dottilde
189-
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
190-
generate_mainbody!(mod, found, R, args, warn),
191-
args) |> Base.remove_linenums!
189+
return generate_dot_tilde(
190+
generate_mainbody!(mod, found, L, warn),
191+
generate_mainbody!(mod, found, R, warn),
192+
) |> Base.remove_linenums!
192193
end
193194

194195
# Modify tilde operators.
195196
args_tilde = getargs_tilde(expr)
196197
if args_tilde !== nothing
197198
L, R = args_tilde
198-
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
199-
generate_mainbody!(mod, found, R, args, warn),
200-
args) |> Base.remove_linenums!
199+
return generate_tilde(
200+
generate_mainbody!(mod, found, L, warn),
201+
generate_mainbody!(mod, found, R, warn),
202+
) |> Base.remove_linenums!
201203
end
202204

203-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
205+
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
204206
end
205207

206208

207209

208210
"""
209-
generate_tilde(left, right, args)
211+
generate_tilde(left, right)
210212
211213
Generate an `observe` expression for data variables and `assume` expression for parameter
212214
variables.
213215
"""
214-
function generate_tilde(left, right, args)
216+
function generate_tilde(left, right)
215217
@gensym tmpright
216218
top = [:($tmpright = $right),
217219
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
218220
|| throw(ArgumentError($DISTMSG)))]
219221

220222
if left isa Symbol || left isa Expr
221-
@gensym out vn inds
223+
@gensym out vn inds isassumption
222224
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
223225

224-
# It can only be an observation if the LHS is an argument of the model
225-
if vsym(left) in args
226-
@gensym isassumption
227-
return quote
228-
$(top...)
229-
$isassumption = $(DynamicPPL.isassumption(left))
230-
if $isassumption
231-
$left = $(DynamicPPL.tilde_assume)(
232-
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
233-
else
234-
$(DynamicPPL.tilde_observe)(
235-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
236-
end
237-
end
238-
end
239-
240226
return quote
241227
$(top...)
242-
$left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
243-
$inds, _varinfo)
228+
$isassumption = $(DynamicPPL.isassumption(left))
229+
if $isassumption
230+
$left = $(DynamicPPL.tilde_assume)(
231+
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
232+
else
233+
$(DynamicPPL.tilde_observe)(
234+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
235+
end
244236
end
245237
end
246238

@@ -252,40 +244,30 @@ function generate_tilde(left, right, args)
252244
end
253245

254246
"""
255-
generate_dot_tilde(left, right, args)
247+
generate_dot_tilde(left, right)
256248
257249
Generate the expression that replaces `left .~ right` in the model body.
258250
"""
259-
function generate_dot_tilde(left, right, args)
251+
function generate_dot_tilde(left, right)
260252
@gensym tmpright
261253
top = [:($tmpright = $right),
262254
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
263255
|| throw(ArgumentError($DISTMSG)))]
264256

265257
if left isa Symbol || left isa Expr
266-
@gensym out vn inds
258+
@gensym out vn inds isassumption
267259
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
268260

269-
# It can only be an observation if the LHS is an argument of the model
270-
if vsym(left) in args
271-
@gensym isassumption
272-
return quote
273-
$(top...)
274-
$isassumption = $(DynamicPPL.isassumption(left))
275-
if $isassumption
276-
$left .= $(DynamicPPL.dot_tilde_assume)(
277-
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
278-
else
279-
$(DynamicPPL.dot_tilde_observe)(
280-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
281-
end
282-
end
283-
end
284-
285261
return quote
286262
$(top...)
287-
$left .= $(DynamicPPL.dot_tilde_assume)(
288-
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
263+
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
264+
if $isassumption
265+
$left .= $(DynamicPPL.dot_tilde_assume)(
266+
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
267+
else
268+
$(DynamicPPL.dot_tilde_observe)(
269+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
270+
end
289271
end
290272
end
291273

src/varname.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,7 @@ Possibly existing indices of `varname` are neglected.
3838
@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
3939
return s in missings
4040
end
41+
42+
@generated function inmissings(::Val{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
43+
return s in missings
44+
end

0 commit comments

Comments
 (0)