From 3b5ed358cb76e01d7965fb2b2f454f770751e1d2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 11:29:17 +0000 Subject: [PATCH 1/6] replace `Vec` bijector with a more general composition with `Bijectors.Reshape` --- src/variational/advi.jl | 55 ++++++++++------------------------------- 1 file changed, 13 insertions(+), 42 deletions(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index 608590f5dc..c12ec00403 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -1,45 +1,16 @@ -# TODO(torfjelde): Find a better solution. -struct Vec{N,B} <: Bijectors.Bijector - b::B - size::NTuple{N, Int} -end - -Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size) - -Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz) -Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n) - -function Bijectors.with_logabsdet_jacobian(f::Vec, x) - return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x) -end - -function Bijectors.transform(f::Vec, x::AbstractVector) - # Reshape into shape compatible with wrapped bijector and then `vec` again. - return vec(f.b(reshape(x, f.size))) -end - -function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N - # Reshape into shape compatible with original (forward) bijector and then `vec` again. - return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size))))) -end - -function Bijectors.transform(f::Vec, x::AbstractMatrix) - # At the moment we do batching for higher-than-1-dim spaces by simply using - # lists of inputs rather than `AbstractArray` with `N + 1` dimension. - cols = Iterators.Stateful(eachcol(x)) - # Make `init` a matrix to ensure type-stability - init = reshape(f(first(cols)), :, 1) - return mapreduce(f, hcat, cols; init = init) -end - -function Bijectors.logabsdetjac(f::Vec, x::AbstractVector) - return Bijectors.logabsdetjac(f.b, reshape(x, f.size)) -end +""" + wrap_in_vec_reshape(f, in_size) -function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix) - return map(eachcol(x)) do x_ - Bijectors.logabsdetjac(f, x_) - end +Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces +a vector of length `prod(Bijectors.output(f, in_size))`. +""" +function wrap_in_vec_reshape(f, in_size) + vec_in_length = prod(in_size) + reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) + out_size = Bijectors.output_size(f, in_size) + vec_out_length = prod(out_size) + reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) + return reshape_outer ∘ f ∘ reshape_inner end @@ -83,7 +54,7 @@ function Bijectors.bijector( if d isa Distributions.UnivariateDistribution b else - Vec(b, size(d)) + wrap_in_vec_reshape(b, size(d)) end end From 9200b68553fb640a101234145a07a68881d9e414 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 11:29:53 +0000 Subject: [PATCH 2/6] correctly use linked varinfo to determine number of parameters --- src/variational/advi.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index c12ec00403..c13a88793a 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -77,7 +77,9 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model) function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) # Setup. varinfo = DynamicPPL.VarInfo(model) - num_params = length(varinfo[DynamicPPL.SampleFromPrior()]) + # Use linked `varinfo` to determine the correct number of parameters. + varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) # initial params μ = randn(rng, num_params) @@ -105,7 +107,9 @@ function AdvancedVI.update( td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}, θ::AbstractArray, ) - μ, ω = θ[1:length(td)], θ[length(td) + 1:end] + # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, + # so we need to use the length of the underlying distribution `td.dist` here. + μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end] return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) end From 56777c94c938fb51852ee085c6886ce047fd2660 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 11:32:24 +0000 Subject: [PATCH 3/6] added regression test --- test/variational/advi.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/variational/advi.jl b/test/variational/advi.jl index ecca265453..17cbc7ea22 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -64,5 +64,10 @@ x0_inv = inverse(b)(z0) @test size(x0_inv) == size(x0) @test all(x0 .≈ x0_inv) + + # And regression for https://github.com/TuringLang/Turing.jl/issues/2160. + q = vi(m, ADVI(10, 1000)) + x = rand(q, 1000) + @test mean(eachcol(x)) ≈ [0.5, 0.5] atol=0.1 end end From c60e45cceed523e2d66af463886b8a56dd5c2626 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 13:28:10 +0000 Subject: [PATCH 4/6] Update src/variational/advi.jl Co-authored-by: David Widmann --- src/variational/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index c13a88793a..508b439e66 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -109,7 +109,7 @@ function AdvancedVI.update( ) # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, # so we need to use the length of the underlying distribution `td.dist` here. - μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end] + μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) end From 0b3af41f717c6012f380194ef034b87a33cc793b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 13:46:22 +0000 Subject: [PATCH 5/6] added comment on desired usage of `view` --- src/variational/advi.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index 508b439e66..71422a18b0 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -109,6 +109,7 @@ function AdvancedVI.update( ) # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, # so we need to use the length of the underlying distribution `td.dist` here. + # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends. μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) end From e8328622355fcc9ce33da8649dc34c6bd1b70b6b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Jan 2024 13:47:25 +0000 Subject: [PATCH 6/6] more TODO comments for future ref --- src/variational/advi.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index 71422a18b0..cf2d4034a0 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -1,3 +1,4 @@ +# TODO: Move to Bijectors.jl if we find further use for this. """ wrap_in_vec_reshape(f, in_size) @@ -78,6 +79,7 @@ function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) # Setup. varinfo = DynamicPPL.VarInfo(model) # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. varinfo_linked = DynamicPPL.link(varinfo, model) num_params = length(varinfo_linked[:])