Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.1"
version = "0.10.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
41 changes: 36 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,47 @@ Convert the `value` to the correct type for the `sampler` and the `vi` object.
function matchingvalue(sampler, vi, value)
T = typeof(value)
if hasmissing(T)
return convert(get_matching_type(sampler, vi, T), value)
_value = convert(get_matching_type(sampler, vi, T), value)
if _value === value
return deepcopy(_value)
else
return _value
end
else
return value
end
end
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)

"""
get_matching_type(spl, vi, ::Type{T}) where {T}
Get the specialized version of type `T` for sampler `spl`. For example,
if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`.
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}

Get the specialized version of type `T` for sampler `spl`.

For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is
`eltype(vi[spl])`.
"""
function get_matching_type end
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T
function get_matching_type(
spl::AbstractSampler,
vi,
::Type{<:Union{Missing, AbstractFloat}},
)
return Union{Missing, floatof(eltype(vi, spl))}
end
function get_matching_type(
spl::AbstractSampler,
vi,
::Type{<:AbstractFloat},
)
return floatof(eltype(vi, spl))
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(spl, vi, T), N}
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T
return Array{get_matching_type(spl, vi, T)}
end

floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T))
floatof(::Type) = Real # fallback if type inference failed
30 changes: 0 additions & 30 deletions test/Turing/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,36 +424,6 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space
end

floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T))
floatof(::Type) = Real # fallback if type inference failed

function get_matching_type(
spl::AbstractSampler,
vi,
::Type{T},
) where {T}
return T
end
function get_matching_type(
spl::AbstractSampler,
vi,
::Type{<:Union{Missing, AbstractFloat}},
)
return Union{Missing, floatof(eltype(vi, spl))}
end
function get_matching_type(
spl::AbstractSampler,
vi,
::Type{<:AbstractFloat},
)
return floatof(eltype(vi, spl))
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(spl, vi, T), N}
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T
return Array{get_matching_type(spl, vi, T)}
end
function get_matching_type(
spl::Sampler{<:Union{PG, SMC}},
vi,
Expand Down
18 changes: 18 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,24 @@ end
end
model = testmodel(rand(10))
@test all(z -> isapprox(z, 0; atol = 0.2), mean(model() for _ in 1:1000))

# test Turing#1464
@model function gdemo(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end
x = [1.0, missing]
VarInfo(gdemo(x))
@test ismissing(x[2])

# https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615
vi = VarInfo(gdemo(x))
@test haskey(vi.metadata, :x)
vi = VarInfo(gdemo(x))
@test haskey(vi.metadata, :x)
end
@testset "nested model" begin
function makemodel(p)
Expand Down