Fix dimensionality issues of ADVI #2162
    return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
    end
    """
    wrap_in_vec_reshape(f, in_size)
Do you think this could be useful more generally outside of Turing? In that case, I think we could move it to Bijectors?
Was thinking the same, but IMO we leave it here and move it if we find other use cases for it?
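For readers following the thread, the general idea behind a "vec/reshape" wrapper can be sketched with Base Julia alone. This is only an illustration under assumptions: `vec_wrapped` and `drop_last_row` are hypothetical names, and the sketch is not the `wrap_in_vec_reshape` implementation from this PR, which works with Bijectors.jl transforms.

```julia
# Hypothetical, Base-only illustration; NOT the implementation from this PR.
# Wrap `f`, which expects an array of size `in_size`, so that the wrapped
# function takes a flat vector and returns a flat vector.
function vec_wrapped(f, in_size)
    return x -> vec(f(reshape(x, in_size)))
end

# Example transform that changes dimensionality: drop the last row.
drop_last_row(X) = X[1:(end - 1), :]

g = vec_wrapped(drop_last_row, (2, 3))
y = g(collect(1.0:6.0))  # 6 inputs in, 3 outputs out
```

The point of the sketch is that the input and output lengths of the wrapped function can differ, which is the situation the dimensionality fix has to handle.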
    - num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
    + # Use linked `varinfo` to determine the correct number of parameters.
    + varinfo_linked = DynamicPPL.link(varinfo, model)
    + num_params = length(varinfo_linked[:])
Ooh that reminds me of our getindex mess 😅 It feels a bit wasteful though to collect all parameters, doesn't it? It seems it would be sufficient to count how many there are - if there's no more efficient way maybe we should define length (or some function we own) in DynamicPPL to avoid getindex here?
Most certainly:)
In my work on VarNameVector I've implemented a proper length, but I didn't add one for VarInfo, etc. Might be worth it though. In the process I've encountered a bunch of things that should be simplified in VarInfo, so I'm debating whether to just make a separate PR with these things before making the move with VarNameVector.
But given that this PR holds a fix for a breaking bug, I'll just add a TODO comment here for now.
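The difference between collecting all parameters and merely counting them can be sketched as follows. The named tuple `params` is a hypothetical stand-in for a container of parameter arrays; it is not `VarInfo`, and no DynamicPPL API is used or implied here.

```julia
# Hypothetical stand-in for a container of named parameter arrays.
params = (m = [0.1, 0.2], s = fill(1.0, 2, 2), z = [3.0])

# Collecting: builds one flat vector of all values just to ask for its length.
n_collect = length(reduce(vcat, map(vec, values(params))))

# Counting: sums the per-variable lengths without building the flat vector.
n_count = sum(length, values(params))
```

Both give the same number, but only the first allocates a vector holding every value, which is the waste the review comment points out.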
src/variational/advi.jl
Outdated
    - μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
    + # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
    + # so we need to use the length of the underlying distribution `td.dist` here.
    + μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
A personal preference but IMO it's much clearer to write
    - μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
    + μ, ω = θ[1:length(td.dist)], θ[(length(td.dist) + 1):end]
Two other points:
Are we sure 1-based indexing is correct here? Based on the function signature 1:length(td.dist) seems dangerous, maybe rather use
    - μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
    + μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
Moreover, do we need the copy or would a view be sufficient?
Happy with that!
Would the view possibly interact badly with some AD backend or should it all just work?
Possibly. Let's just add a comment regarding views, maybe?
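The two suggestions in this thread (index relative to `begin` rather than `1`, and use views instead of copies) can be sketched in Base Julia. `split_params` and its argument `d` are hypothetical names for illustration and are not part of the PR; `begin` inside `@view` needs Julia 1.5 or later. Note that `+` binds tighter than `:`, so the extra parentheses in the first suggestion affect readability only, not the result. Whether views behave well under every AD backend is left open in the discussion, and this sketch does not answer that.

```julia
# Hypothetical helper: split θ into its first `d` entries and the rest,
# without assuming that θ starts at index 1 and without copying.
function split_params(θ::AbstractVector, d::Integer)
    μ = @view θ[begin:(begin + d - 1)]
    ω = @view θ[(begin + d):end]
    return μ, ω
end

θ = collect(1.0:6.0)
μ, ω = split_params(θ, 3)

# Views share memory with θ: writing through μ changes θ as well.
μ[1] = 10.0
```

Because the views alias `θ`, any in-place mutation downstream would be visible in the original vector, which is one reason a comment about views is worth leaving in the code.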
Pull Request Test Coverage Report for Build 7668692500 (Coveralls)
Codecov Report
Coverage diff, master vs. #2162:
              master    #2162    +/-
  Coverage    0.00%     0.00%
  Files       22        22
  Lines       1396      1386     -10
  Misses      1396      1386     -10
Co-authored-by: David Widmann <[email protected]>
I'm confused as to why codecov is complaining here 😕
Added the TODO comments 👍
Test failures here are unrelated (it's those MLE tests again), so I'll just merge this.
Fixes #2160 and other related problems + simplifies the meanfield construction by making use of "recently" introduced Bijectors.Reshape.