Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.21.5"
version = "0.21.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
10 changes: 8 additions & 2 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ $(FIELDS)
```jldoctest
julia> using Distributions

julia> using DynamicPPL: LogDensityFunction
julia> using DynamicPPL: LogDensityFunction, contextualize

julia> @model function demo(x)
m ~ Normal()
Expand All @@ -36,6 +36,12 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary.

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # This also respects the context in `model`.
f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
```
"""
struct LogDensityFunction{V,M,C}
Expand All @@ -60,7 +66,7 @@ end
function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=DefaultContext(),
context::AbstractContext=model.context,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a test for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

)
return LogDensityFunction(varinfo, model, context)
end
Expand Down
22 changes: 2 additions & 20 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,8 @@ using Random: Random
using Bijectors: Bijectors
using Setfield: Setfield

"""
varname_leaves(vn::VarName, val)

Return iterator over all varnames that are represented by `vn` on `val`,
e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`.
"""
varname_leaves(vn::VarName, val::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(
VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]
) for I in CartesianIndices(val)
)
end
# For backwards compat.
using DynamicPPL: varname_leaves

"""
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
Expand Down
46 changes: 46 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,49 @@ infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_

# No need + causes issues for some AD backends, e.g. Zygote.
ChainRulesCore.@non_differentiable infer_nested_eltype(x)

"""
varname_leaves(vn::VarName, val)

Return an iterator over all varnames that are represented by `vn` on `val`.

# Examples
```jldoctest
julia> using DynamicPPL: varname_leaves

julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]

julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]

julia> x = (y = 1, z = [[2.0], [3.0]]);

julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
```
"""
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
lens = Setfield.PropertyLens{sym}()
varname_leaves(vn ∘ lens, get(val, lens))
end
return Iterators.flatten(iter)
end