Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 22 additions & 44 deletions src/variational/advi.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,17 @@
# TODO(torfjelde): Find a better solution.
# Wraps a bijector `b` so it can act on flattened (`vec`ed) inputs: incoming vectors
# are reshaped to `size` before `b` is applied, and results are flattened again
# (see the `Bijectors.transform` methods below).
struct Vec{N,B} <: Bijectors.Bijector
# Wrapped bijector applied to the un-flattened (reshaped) input.
b::B
# Original (un-flattened) input size, e.g. `(2, 3)` for a 2x3 matrix.
size::NTuple{N, Int}
end

# The inverse of a `Vec`-wrapped bijector is the wrap of the inverse bijector,
# operating on vectors of the same flattened size.
function Bijectors.inverse(f::Vec)
    return Vec(Bijectors.inverse(f.b), f.size)
end

# Output length is delegated to the wrapped bijector.
# NOTE(review): the `::Int` method mirrors the generic one, presumably to avoid
# method ambiguity with methods defined in Bijectors itself — TODO confirm.
function Bijectors.output_length(f::Vec, sz)
    return Bijectors.output_length(f.b, sz)
end
function Bijectors.output_length(f::Vec, n::Int)
    return Bijectors.output_length(f.b, n)
end

# Forward transform together with its log-abs-det Jacobian, computed via two
# separate calls on the wrapper.
function Bijectors.with_logabsdet_jacobian(f::Vec, x)
    y = Bijectors.transform(f, x)
    ladj = Bijectors.logabsdetjac(f, x)
    return y, ladj
end

# Apply the wrapped bijector to a flattened vector input: un-flatten to the
# shape the wrapped bijector expects, transform, then flatten the result again.
function Bijectors.transform(f::Vec, x::AbstractVector)
    x_unflattened = reshape(x, f.size)
    y = f.b(x_unflattened)
    return vec(y)
end

# Inverse-direction transform on a flattened input. The input is reshaped to a
# 1-d array of length `Bijectors.output_length(f.b.orig, prod(f.size))`, i.e.
# the flattened output length of the original (forward) bijector, presumably
# because the forward transform may change the dimensionality.
# NOTE(review): assumes the inverse accepts a 1-d input of that length and that
# `length(x)` matches it (otherwise `reshape` throws) — TODO confirm.
function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N
# Reshape into shape compatible with original (forward) bijector and then `vec` again.
return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
end

# Batched transform: each column of `x` is treated as one flattened input.
# (Batching for higher-than-1-dim spaces is done via columns of a matrix rather
# than an `AbstractArray` with `N + 1` dimensions.)
function Bijectors.transform(f::Vec, x::AbstractMatrix)
    # Reshape each transformed column into a 1-column matrix so the `hcat`
    # reduction always accumulates matrices.
    transformed_cols = [reshape(f(col), :, 1) for col in eachcol(x)]
    return reduce(hcat, transformed_cols)
end

# Log-abs-det Jacobian of the flattened transform: un-flatten and delegate to
# the wrapped bijector (flattening itself contributes no Jacobian term).
function Bijectors.logabsdetjac(f::Vec, x::AbstractVector)
    x_unflattened = reshape(x, f.size)
    return Bijectors.logabsdetjac(f.b, x_unflattened)
end
# TODO: Move to Bijectors.jl if we find further use for this.
"""
wrap_in_vec_reshape(f, in_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this could be useful more generally outside of Turing? In that case, I think we could move it to Bijectors?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was thinking the same, but IMO we leave it here for now and move it if we find other use cases for it?


function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix)
return map(eachcol(x)) do x_
Bijectors.logabsdetjac(f, x_)
end
Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
a vector of length `prod(Bijectors.output_size(f, in_size))`.
"""
function wrap_in_vec_reshape(f, in_size)
    # Un-flatten: vector of length `prod(in_size)` -> array of shape `in_size`.
    n_in = prod(in_size)
    unflatten = Bijectors.Reshape((n_in,), in_size)
    # Flatten: output of `f` (shape `out_size`) -> vector of length `prod(out_size)`.
    out_size = Bijectors.output_size(f, in_size)
    n_out = prod(out_size)
    flatten = Bijectors.Reshape(out_size, (n_out,))
    # Compose so the resulting bijector maps vectors to vectors.
    return flatten ∘ f ∘ unflatten
end


Expand Down Expand Up @@ -83,7 +55,7 @@ function Bijectors.bijector(
if d isa Distributions.UnivariateDistribution
b
else
Vec(b, size(d))
wrap_in_vec_reshape(b, size(d))
end
end

Expand All @@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
# Setup.
varinfo = DynamicPPL.VarInfo(model)
num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
# Use linked `varinfo` to determine the correct number of parameters.
# TODO: Replace with `length` once this is implemented for `VarInfo`.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh that reminds me of our getindex mess 😅 It feels a bit wasteful though to collect all parameters, doesn't it? It seems it would be sufficient to count how many there are - if there's no more efficient way maybe we should define length (or some function we own) in DynamicPPL to avoid getindex here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most certainly:)

In my work on VarNameVector I've implemented a proper length, but I didn't add one for VarInfo, etc. Might be worth it though. In the process I've encountered a bunch of things that should be simplified in VarInfo, so I'm debating whether to just make a separate PR with these things before making the move with VarNameVector.

But for now, given that this PR holds a fix for a breaking bug, I'll just add a TODO comment here for now.


# initial params
μ = randn(rng, num_params)
Expand Down Expand Up @@ -134,7 +109,10 @@ function AdvancedVI.update(
td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
θ::AbstractArray,
)
μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
# `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
# so we need to use the length of the underlying distribution `td.dist` here.
# TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
end

Expand Down
5 changes: 5 additions & 0 deletions test/variational/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,10 @@
x0_inv = inverse(b)(z0)
@test size(x0_inv) == size(x0)
@test all(x0 .≈ x0_inv)

# And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
q = vi(m, ADVI(10, 1000))
x = rand(q, 1000)
@test mean(eachcol(x)) ≈ [0.5, 0.5] atol=0.1
end
end