# TODO (torfjelde): Find a better solution.
"""
    Vec{N,B}

Bijector wrapper that makes a bijector `b`, which operates on arrays of size
`size`, act on flattened vectors instead: inputs of length `prod(size)` are
`reshape`d to `size`, transformed by `b`, and `vec`ed back.
"""
struct Vec{N,B} <: Bijectors.Bijector
    # The wrapped bijector.
    b::B
    # Array size expected by `b`; vector inputs are reshaped to this.
    size::NTuple{N,Int}
end
6-
# The inverse of a vectorized bijector is the vectorized inverse, for the same size.
Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)
8-
# Forward output-length queries straight to the wrapped bijector.
Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz)
Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n)
11-
"""
    with_logabsdet_jacobian(f::Vec, x)

Return the tuple `(f(x), logabsdetjac(f, x))`.

NOTE(review): the transform and the log-abs-det Jacobian are computed in two
separate passes here rather than in one fused pass.
"""
function Bijectors.with_logabsdet_jacobian(f::Vec, x)
    return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
end
15-
"""
    transform(f::Vec, x::AbstractVector)

Apply the wrapped bijector to `x` by reshaping `x` into the shape the wrapped
bijector expects, transforming, and flattening the result back into a vector.
"""
function Bijectors.transform(f::Vec, x::AbstractVector)
    # Reshape into shape compatible with wrapped bijector and then `vec` again.
    return vec(f.b(reshape(x, f.size)))
end
20-
function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where {N}
    # Reshape into shape compatible with the original (forward) bijector and then `vec` again.
    # NOTE(review): `output_length(f.b.orig, prod(f.size))` is an `Int`, so `x` is reshaped
    # to a vector of that length — confirm this matches the inverse's expected input shape.
    return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
end
25-
"""
    transform(f::Vec, x::AbstractMatrix)

Apply `f` column-wise to `x`, returning a matrix whose columns are the
transformed columns of `x`.
"""
function Bijectors.transform(f::Vec, x::AbstractMatrix)
    # At the moment we do batching for higher-than-1-dim spaces by simply using
    # lists of inputs rather than `AbstractArray` with `N + 1` dimension.
    cols = Iterators.Stateful(eachcol(x))
    # Make `init` a matrix to ensure type-stability of the `mapreduce` below.
    # (`Stateful` means `first(cols)` consumes the first column, so it is not
    # transformed twice.)
    init = reshape(f(first(cols)), :, 1)
    return mapreduce(f, hcat, cols; init=init)
end
34-
# Log-abs-det Jacobian of the vectorized bijector: reshaping is volume-preserving,
# so this is just the wrapped bijector's logabsdetjac on the reshaped input.
function Bijectors.logabsdetjac(f::Vec, x::AbstractVector)
    return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
end
# TODO: Move to Bijectors.jl if we find further use for this.
"""
    wrap_in_vec_reshape(f, in_size)

Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)`
and produces a vector of length `prod(Bijectors.output_size(f, in_size))`.
"""
function wrap_in_vec_reshape(f, in_size)
    vec_in_length = prod(in_size)
    # Un-flatten: vector of length `prod(in_size)` → array of size `in_size`.
    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
    out_size = Bijectors.output_size(f, in_size)
    vec_out_length = prod(out_size)
    # Re-flatten: array of size `out_size` → vector of length `prod(out_size)`.
    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
    return reshape_outer ∘ f ∘ reshape_inner
end

# Batched `logabsdetjac`: one log-abs-det Jacobian per column of `x`.
function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix)
    return map(eachcol(x)) do x_
        Bijectors.logabsdetjac(f, x_)
    end
end
4416
4517
@@ -83,7 +55,7 @@ function Bijectors.bijector(
8355 if d isa Distributions. UnivariateDistribution
8456 b
8557 else
86- Vec (b, size (d))
58+ wrap_in_vec_reshape (b, size (d))
8759 end
8860 end
8961
@@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
10678function meanfield (rng:: Random.AbstractRNG , model:: DynamicPPL.Model )
10779 # Setup.
10880 varinfo = DynamicPPL. VarInfo (model)
109- num_params = length (varinfo[DynamicPPL. SampleFromPrior ()])
81+ # Use linked `varinfo` to determine the correct number of parameters.
82+ # TODO : Replace with `length` once this is implemented for `VarInfo`.
83+ varinfo_linked = DynamicPPL. link (varinfo, model)
84+ num_params = length (varinfo_linked[:])
11085
11186 # initial params
11287 μ = randn (rng, num_params)
"""
    AdvancedVI.update(td, θ)

Update a transformed `TuringDiagMvNormal` from a flat parameter vector `θ`,
whose first `length(td.dist)` entries are the mean and whose remaining entries
are the unconstrained (softplus-parameterized) standard deviations.
"""
function AdvancedVI.update(
    td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
    θ::AbstractArray,
)
    # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
    # so we need to use the length of the underlying distribution `td.dist` here.
    # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
    μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
    return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
end
140118
0 commit comments