diff --git a/Project.toml b/Project.toml index ede0b23f5..7d025634a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.9" +version = "0.17.10" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/compiler.jl b/src/compiler.jl index e851dd310..ee763a3c6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -245,6 +245,9 @@ function build_model_info(input_expr) return modelinfo end + # Ensure that all arguments have a name, i.e., are of the form `name` or `name::T` + addargnames!(modeldef[:args]) + # Extract the positional and keyword arguments from the model definition. allargs = vcat(modeldef[:args], modeldef[:kwargs]) @@ -262,8 +265,7 @@ function build_model_info(input_expr) # Extract the names of the arguments. allargs_syms = map(allargs_exprs) do arg MacroTools.@match arg begin - (::Type{T_}) | (name_::Type{T_}) => T - name_::T_ => name + (name_::_) => name x_ => x end end diff --git a/src/utils.jl b/src/utils.jl index d8f9090d1..821eba38e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -74,6 +74,48 @@ macro addlogprob!(ex) end end +""" + addargnames!(args) + +Adds names to unnamed arguments in `args`. + +The names are generated with `gensym(:arg)` to avoid conflicts with other variable names. + +# Examples + +```jldoctest; filter = r"var\\"##arg#[0-9]+\\"" +julia> args = :(f(x::Int, y, ::Type{T}=Float64)).args[2:end] +3-element Vector{Any}: + :(x::Int) + :y + :($(Expr(:kw, :(::Type{T}), :Float64))) + +julia> DynamicPPL.addargnames!(args) + +julia> args +3-element Vector{Any}: + :(x::Int) + :y + :($(Expr(:kw, :(var"##arg#301"::Type{T}), :Float64))) +``` +""" +function addargnames!(args) + if isempty(args) + return nothing + end + + @inbounds for i in eachindex(args) + arg = args[i] + if MacroTools.@capture(arg, ::T_) + args[i] = Expr(:(::), gensym(:arg), T) + elseif MacroTools.@capture(arg, ::T_ = val_) + args[i] = Expr(:kw, Expr(:(::), gensym(:arg), T), val) + end + end + + return nothing +end + """ getargs_dottilde(x) diff --git a/test/compiler.jl b/test/compiler.jl index 8babe32ac..f59f013ac 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -602,4 +602,10 @@ end @test !DynamicPPL.hasmissing(Matrix{Real}) @test !DynamicPPL.hasmissing(Vector{Matrix{Float32}}) end + + @testset "issue #393: anonymous argument with type parameter" begin + @model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1 + @test f_393()() == 1 + @test f_393(Val(true))() == 0 + end end