Skip to content

Commit 66fa9e2

Browse files
torfjeldedevmotion
andauthored
Fix dimensionality issues of ADVI (#2162)
* replace `Vec` bijector with a more general composition with `Bijectors.Reshape` * correctly use linked varinfo to determine number of parameters * added regression test * Update src/variational/advi.jl Co-authored-by: David Widmann <[email protected]> * added comment on desired usage of `view` * more TODO comments for future ref --------- Co-authored-by: David Widmann <[email protected]>
1 parent bb45a1f commit 66fa9e2

File tree

2 files changed

+27
-44
lines changed

2 files changed

+27
-44
lines changed

src/variational/advi.jl

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,17 @@
1-
# TODO(torfjelde): Find a better solution.
2-
struct Vec{N,B} <: Bijectors.Bijector
3-
b::B
4-
size::NTuple{N, Int}
5-
end
6-
7-
Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)
8-
9-
Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz)
10-
Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n)
11-
12-
function Bijectors.with_logabsdet_jacobian(f::Vec, x)
13-
return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
14-
end
15-
16-
function Bijectors.transform(f::Vec, x::AbstractVector)
17-
# Reshape into shape compatible with wrapped bijector and then `vec` again.
18-
return vec(f.b(reshape(x, f.size)))
19-
end
20-
21-
function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N
22-
# Reshape into shape compatible with original (forward) bijector and then `vec` again.
23-
return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
24-
end
25-
26-
function Bijectors.transform(f::Vec, x::AbstractMatrix)
27-
# At the moment we do batching for higher-than-1-dim spaces by simply using
28-
# lists of inputs rather than `AbstractArray` with `N + 1` dimension.
29-
cols = Iterators.Stateful(eachcol(x))
30-
# Make `init` a matrix to ensure type-stability
31-
init = reshape(f(first(cols)), :, 1)
32-
return mapreduce(f, hcat, cols; init = init)
33-
end
34-
35-
function Bijectors.logabsdetjac(f::Vec, x::AbstractVector)
36-
return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
37-
end
1+
# TODO: Move to Bijectors.jl if we find further use for this.
2+
"""
3+
wrap_in_vec_reshape(f, in_size)
384
39-
function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix)
40-
return map(eachcol(x)) do x_
41-
Bijectors.logabsdetjac(f, x_)
42-
end
5+
Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
6+
a vector of length `prod(Bijectors.output(f, in_size))`.
7+
"""
8+
function wrap_in_vec_reshape(f, in_size)
9+
vec_in_length = prod(in_size)
10+
reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
11+
out_size = Bijectors.output_size(f, in_size)
12+
vec_out_length = prod(out_size)
13+
reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
14+
return reshape_outer f reshape_inner
4315
end
4416

4517

@@ -83,7 +55,7 @@ function Bijectors.bijector(
8355
if d isa Distributions.UnivariateDistribution
8456
b
8557
else
86-
Vec(b, size(d))
58+
wrap_in_vec_reshape(b, size(d))
8759
end
8860
end
8961

@@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
10678
function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
10779
# Setup.
10880
varinfo = DynamicPPL.VarInfo(model)
109-
num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
81+
# Use linked `varinfo` to determine the correct number of parameters.
82+
# TODO: Replace with `length` once this is implemented for `VarInfo`.
83+
varinfo_linked = DynamicPPL.link(varinfo, model)
84+
num_params = length(varinfo_linked[:])
11085

11186
# initial params
11287
μ = randn(rng, num_params)
@@ -134,7 +109,10 @@ function AdvancedVI.update(
134109
td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
135110
θ::AbstractArray,
136111
)
137-
μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
112+
# `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
113+
# so we need to use the length of the underlying distribution `td.dist` here.
114+
# TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
115+
μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
138116
return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
139117
end
140118

test/variational/advi.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,10 @@
6464
x0_inv = inverse(b)(z0)
6565
@test size(x0_inv) == size(x0)
6666
@test all(x0 .≈ x0_inv)
67+
68+
# And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
69+
q = vi(m, ADVI(10, 1000))
70+
x = rand(q, 1000)
71+
@test mean(eachcol(x)) [0.5, 0.5] atol=0.1
6772
end
6873
end

0 commit comments

Comments
 (0)