@@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn)
7272
7373 # Generate main body
7474 modelinfo[:body ] = generate_mainbody (
75- mod, modelinfo[:modeldef ][:body ], modelinfo[ :allargs_syms ], warn
75+ mod, modelinfo[:modeldef ][:body ], warn
7676 )
7777
7878 return build_output (modelinfo, linenumbernode)
@@ -155,92 +155,84 @@ function build_model_info(input_expr)
155155end
156156
157157"""
158- generate_mainbody(mod, expr, args, warn)
158+ generate_mainbody(mod, expr, warn)
159159
160160Generate the body of the main evaluation function from expression `expr` and arguments
161161`args`.
162162
163163If `warn` is true, a warning is displayed if internal variables are used in the model
164164definition.
165165"""
166- generate_mainbody (mod, expr, args, warn) = generate_mainbody! (mod, Symbol[], expr, args , warn)
166+ generate_mainbody (mod, expr, warn) = generate_mainbody! (mod, Symbol[], expr, warn)
167167
168- generate_mainbody! (mod, found, x, args, warn) = x
169- function generate_mainbody! (mod, found, sym:: Symbol , args, warn)
168+ generate_mainbody! (mod, found, x, warn) = x
169+ function generate_mainbody! (mod, found, sym:: Symbol , warn)
170170 if warn && sym in INTERNALNAMES && sym ∉ found
171171 @warn " you are using the internal variable `$(sym) `"
172172 push! (found, sym)
173173 end
174174 return sym
175175end
176- function generate_mainbody! (mod, found, expr:: Expr , args, warn)
176+ function generate_mainbody! (mod, found, expr:: Expr , warn)
177177 # Do not touch interpolated expressions
178178 expr. head === :$ && return expr. args[1 ]
179179
180180 # If it's a macro, we expand it
181181 if Meta. isexpr (expr, :macrocall )
182- return generate_mainbody! (mod, found, macroexpand (mod, expr; recursive= true ), args, warn)
182+ return generate_mainbody! (mod, found, macroexpand (mod, expr; recursive= true ), warn)
183183 end
184184
185185 # Modify dotted tilde operators.
186186 args_dottilde = getargs_dottilde (expr)
187187 if args_dottilde != = nothing
188188 L, R = args_dottilde
189- return generate_dot_tilde (generate_mainbody! (mod, found, L, args, warn),
190- generate_mainbody! (mod, found, R, args, warn),
191- args) |> Base. remove_linenums!
189+ return generate_dot_tilde (
190+ generate_mainbody! (mod, found, L, warn),
191+ generate_mainbody! (mod, found, R, warn),
192+ ) |> Base. remove_linenums!
192193 end
193194
194195 # Modify tilde operators.
195196 args_tilde = getargs_tilde (expr)
196197 if args_tilde != = nothing
197198 L, R = args_tilde
198- return generate_tilde (generate_mainbody! (mod, found, L, args, warn),
199- generate_mainbody! (mod, found, R, args, warn),
200- args) |> Base. remove_linenums!
199+ return generate_tilde (
200+ generate_mainbody! (mod, found, L, warn),
201+ generate_mainbody! (mod, found, R, warn),
202+ ) |> Base. remove_linenums!
201203 end
202204
203- return Expr (expr. head, map (x -> generate_mainbody! (mod, found, x, args, warn), expr. args)... )
205+ return Expr (expr. head, map (x -> generate_mainbody! (mod, found, x, warn), expr. args)... )
204206end
205207
206208
207209
208210"""
209- generate_tilde(left, right, args )
211+ generate_tilde(left, right)
210212
211213Generate an `observe` expression for data variables and `assume` expression for parameter
212214variables.
213215"""
214- function generate_tilde (left, right, args )
216+ function generate_tilde (left, right)
215217 @gensym tmpright
216218 top = [:($ tmpright = $ right),
217219 :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
218220 || throw (ArgumentError ($ DISTMSG)))]
219221
220222 if left isa Symbol || left isa Expr
221- @gensym out vn inds
223+ @gensym out vn inds isassumption
222224 push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
223225
224- # It can only be an observation if the LHS is an argument of the model
225- if vsym (left) in args
226- @gensym isassumption
227- return quote
228- $ (top... )
229- $ isassumption = $ (DynamicPPL. isassumption (left))
230- if $ isassumption
231- $ left = $ (DynamicPPL. tilde_assume)(
232- _rng, _context, _sampler, $ tmpright, $ vn, $ inds, _varinfo)
233- else
234- $ (DynamicPPL. tilde_observe)(
235- _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
236- end
237- end
238- end
239-
240226 return quote
241227 $ (top... )
242- $ left = $ (DynamicPPL. tilde_assume)(_rng, _context, _sampler, $ tmpright, $ vn,
243- $ inds, _varinfo)
228+ $ isassumption = $ (DynamicPPL. isassumption (left))
229+ if $ isassumption
230+ $ left = $ (DynamicPPL. tilde_assume)(
231+ _rng, _context, _sampler, $ tmpright, $ vn, $ inds, _varinfo)
232+ else
233+ $ (DynamicPPL. tilde_observe)(
234+ _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
235+ end
244236 end
245237 end
246238
@@ -252,40 +244,30 @@ function generate_tilde(left, right, args)
252244end
253245
254246"""
255- generate_dot_tilde(left, right, args )
247+ generate_dot_tilde(left, right)
256248
257249Generate the expression that replaces `left .~ right` in the model body.
258250"""
259- function generate_dot_tilde (left, right, args )
251+ function generate_dot_tilde (left, right)
260252 @gensym tmpright
261253 top = [:($ tmpright = $ right),
262254 :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
263255 || throw (ArgumentError ($ DISTMSG)))]
264256
265257 if left isa Symbol || left isa Expr
266- @gensym out vn inds
258+ @gensym out vn inds isassumption
267259 push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
268260
269- # It can only be an observation if the LHS is an argument of the model
270- if vsym (left) in args
271- @gensym isassumption
272- return quote
273- $ (top... )
274- $ isassumption = $ (DynamicPPL. isassumption (left))
275- if $ isassumption
276- $ left .= $ (DynamicPPL. dot_tilde_assume)(
277- _rng, _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
278- else
279- $ (DynamicPPL. dot_tilde_observe)(
280- _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
281- end
282- end
283- end
284-
285261 return quote
286262 $ (top... )
287- $ left .= $ (DynamicPPL. dot_tilde_assume)(
288- _rng, _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
263+ $ isassumption = $ (DynamicPPL. isassumption (left)) || $ left === missing
264+ if $ isassumption
265+ $ left .= $ (DynamicPPL. dot_tilde_assume)(
266+ _rng, _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
267+ else
268+ $ (DynamicPPL. dot_tilde_observe)(
269+ _context, _sampler, $ tmpright, $ left, $ vn, $ inds, _varinfo)
270+ end
289271 end
290272 end
291273
0 commit comments