Skip to content

Conversation

@torfjelde
Copy link
Member

Fixes #2160 and other related problems + simplifies the meanfield construction by making use of "recently" introduced Bijectors.Reshape.

@torfjelde torfjelde requested review from devmotion and yebai January 26, 2024 11:34
return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
end
"""
wrap_in_vec_reshape(f, in_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this could be useful more generally outside of Turing? In that case, I think we could move it to Bijectors?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was thinking the same, but IMO we leave it here and move it we find other usecases for it?

num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
# Use linked `varinfo` to determine the correct number of parameters.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh that reminds me of our getindex mess 😅 It feels a bit wasteful though to collect all parameters, doesn't it? It seems it would be sufficient to count how many there are - if there's no more efficient way maybe we should define length (or some function we own) in DynamicPPL to avoid getindex here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most certainly:)

In my work on VarNameVector I've implemented a proper length, but I didn't add one for VarInfo, etc. Might be worth it though. In the process I've encountered a bunch of things that should be simplified in VarInfo, so I'm debating whether to just make a separate PR with these things before making the move with VarNameVector.

But for now, given that this PR holds a fix for a breaking bug, I'll just add a TODO comment here for now.

μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
# `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
# so we need to use the length of the underlying distribution `td.dist` here.
μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A personal preference but IMO it's much clearer to write

Suggested change
μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
μ, ω = θ[1:length(td.dist)], θ[(length(td.dist) + 1):end]

Two other points:

Are we sure 1-based indexing is correct here? Based on the function signature 1:length(td.dist) seems dangerous, maybe rather use

Suggested change
μ, ω = θ[1:length(td.dist)], θ[length(td.dist) + 1:end]
μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]

Moreover, do we need the copy or would a view be sufficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy with that!

Would the view possibly interact badly with some AD backend or should it all just work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly. Maybe let's just add a comment regarding views maybe?

@github-actions
Copy link
Contributor

github-actions bot commented Jan 26, 2024

Pull Request Test Coverage Report for Build 7668692500

  • -11 of 11 (0.0%) changed or added relevant lines in 1 file are covered.
  • 27 unchanged lines in 1 file lost coverage.
  • Overall coverage remained the same at 0.0%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/variational/advi.jl 0 11 0.0%
Files with Coverage Reduction New Missed Lines %
src/variational/advi.jl 27 0.0%
Totals Coverage Status
Change from base Build 7591630272: 0.0%
Covered Lines: 0
Relevant Lines: 1386

💛 - Coveralls

@codecov
Copy link

codecov bot commented Jan 26, 2024

Codecov Report

Attention: 11 lines in your changes are missing coverage. Please review.

Comparison is base (bb45a1f) 0.00% compared to head (e832862) 0.00%.

Files Patch % Lines
src/variational/advi.jl 0.00% 11 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##           master   #2162   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files          22      22           
  Lines        1396    1386   -10     
======================================
+ Misses       1396    1386   -10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Co-authored-by: David Widmann <[email protected]>
@torfjelde
Copy link
Member Author

I'm confused as to why codecov is complaining here 😕

@torfjelde
Copy link
Member Author

Added the TODO comments 👍

@torfjelde
Copy link
Member Author

Test failures here are unrelated (it's those MLE tests again), so I'll just merge this.

@torfjelde torfjelde merged commit 66fa9e2 into master Jan 28, 2024
@torfjelde torfjelde deleted the torfjelde/advi-fix branch January 28, 2024 19:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

input length mismatch error using ADVI

3 participants