-
Notifications
You must be signed in to change notification settings - Fork 230
Fix dimensionality issues of ADVI #2162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3b5ed35
9200b68
56777c9
c60e45c
0b3af41
e832862
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,45 +1,17 @@ | ||
| # TODO(torfjelde): Find a better solution. | ||
| struct Vec{N,B} <: Bijectors.Bijector | ||
| b::B | ||
| size::NTuple{N, Int} | ||
| end | ||
|
|
||
| Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size) | ||
|
|
||
| Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz) | ||
| Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n) | ||
|
|
||
| function Bijectors.with_logabsdet_jacobian(f::Vec, x) | ||
| return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x) | ||
| end | ||
|
|
||
| function Bijectors.transform(f::Vec, x::AbstractVector) | ||
| # Reshape into shape compatible with wrapped bijector and then `vec` again. | ||
| return vec(f.b(reshape(x, f.size))) | ||
| end | ||
|
|
||
| function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N | ||
| # Reshape into shape compatible with original (forward) bijector and then `vec` again. | ||
| return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size))))) | ||
| end | ||
|
|
||
| function Bijectors.transform(f::Vec, x::AbstractMatrix) | ||
| # At the moment we do batching for higher-than-1-dim spaces by simply using | ||
| # lists of inputs rather than `AbstractArray` with `N + 1` dimension. | ||
| cols = Iterators.Stateful(eachcol(x)) | ||
| # Make `init` a matrix to ensure type-stability | ||
| init = reshape(f(first(cols)), :, 1) | ||
| return mapreduce(f, hcat, cols; init = init) | ||
| end | ||
|
|
||
| function Bijectors.logabsdetjac(f::Vec, x::AbstractVector) | ||
| return Bijectors.logabsdetjac(f.b, reshape(x, f.size)) | ||
| end | ||
| # TODO: Move to Bijectors.jl if we find further use for this. | ||
| """ | ||
| wrap_in_vec_reshape(f, in_size) | ||
|
|
||
| function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix) | ||
| return map(eachcol(x)) do x_ | ||
| Bijectors.logabsdetjac(f, x_) | ||
| end | ||
| Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces | ||
| a vector of length `prod(Bijectors.output(f, in_size))`. | ||
| """ | ||
| function wrap_in_vec_reshape(f, in_size) | ||
| vec_in_length = prod(in_size) | ||
| reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) | ||
| out_size = Bijectors.output_size(f, in_size) | ||
| vec_out_length = prod(out_size) | ||
| reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) | ||
| return reshape_outer ∘ f ∘ reshape_inner | ||
| end | ||
|
|
||
|
|
||
|
|
@@ -83,7 +55,7 @@ function Bijectors.bijector( | |
| if d isa Distributions.UnivariateDistribution | ||
| b | ||
| else | ||
| Vec(b, size(d)) | ||
| wrap_in_vec_reshape(b, size(d)) | ||
| end | ||
| end | ||
|
|
||
|
|
@@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model) | |
| function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) | ||
| # Setup. | ||
| varinfo = DynamicPPL.VarInfo(model) | ||
| num_params = length(varinfo[DynamicPPL.SampleFromPrior()]) | ||
| # Use linked `varinfo` to determine the correct number of parameters. | ||
| # TODO: Replace with `length` once this is implemented for `VarInfo`. | ||
| varinfo_linked = DynamicPPL.link(varinfo, model) | ||
| num_params = length(varinfo_linked[:]) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ooh that reminds me of our
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most certainly:) In my work on But for now, given that this PR holds a fix for a breaking bug, I'll just add a TODO comment here for now. |
||
|
|
||
| # initial params | ||
| μ = randn(rng, num_params) | ||
|
|
@@ -134,7 +109,10 @@ function AdvancedVI.update( | |
| td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}, | ||
| θ::AbstractArray, | ||
| ) | ||
| μ, ω = θ[1:length(td)], θ[length(td) + 1:end] | ||
| # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, | ||
| # so we need to use the length of the underlying distribution `td.dist` here. | ||
| # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends. | ||
| μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] | ||
| return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) | ||
| end | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think this could be useful more generally outside of Turing? In that case, I think we could move it to Bijectors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was thinking the same, but IMO we leave it here and move it we find other usecases for it?