@@ -21,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
2121 return MatrixOfUnivariate (dists)
2222end
2323function Distributions. _logpdf (dist:: MatrixOfUnivariate , x:: AbstractMatrix{<:Real} )
24- # return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x))
25- # Broadcasting here breaks Tracker for some reason
26- return sum (map (logpdf, dist. dists, x))
24+ # Lazy broadcast to avoid allocations and use pairwise summation
25+ return sum (Broadcast. instantiate (Broadcast. broadcasted (logpdf, dist. dists, x)))
2726end
2827function Distributions. logpdf (dist:: MatrixOfUnivariate , x:: AbstractArray{<:AbstractMatrix{<:Real}} )
29- return map (x -> logpdf (dist, x ), x)
28+ return map (Base . Fix1 (logpdf, dist ), x)
3029end
3130function Distributions. logpdf (dist:: MatrixOfUnivariate , x:: AbstractArray{<:Matrix{<:Real}} )
32- return map (x -> logpdf (dist, x ), x)
31+ return map (Base . Fix1 (logpdf, dist ), x)
3332end
3433
3534function Distributions. rand (rng:: Random.AbstractRNG , dist:: MatrixOfUnivariate )
@@ -52,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
5251end
5352
5453function Distributions. _logpdf (dist:: VectorOfMultivariate , x:: AbstractMatrix{<:Real} )
55- return sum (((di, xi),) -> logpdf (di, xi), zip ( dist. dists, eachcol (x)))
54+ return sum (Broadcast . instantiate (Broadcast . broadcasted ( logpdf, dist. dists, eachcol (x) )))
5655end
5756function Distributions. logpdf (dist:: VectorOfMultivariate , x:: AbstractArray{<:AbstractMatrix{<:Real}} )
58- return map (x -> logpdf (dist, x ), x)
57+ return map (Base . Fix1 (logpdf, dist ), x)
5958end
6059function Distributions. logpdf (dist:: VectorOfMultivariate , x:: AbstractArray{<:Matrix{<:Real}} )
61- return map (x -> logpdf (dist, x ), x)
60+ return map (Base . Fix1 (logpdf, dist ), x)
6261end
6362
6463function Distributions. rand (rng:: Random.AbstractRNG , dist:: VectorOfMultivariate )
6564 init = reshape (rand (rng, dist. dists[1 ]), :, 1 )
66- return mapreduce (i -> rand (rng, dist . dists[i] ), hcat, 2 : length (dist); init = init)
65+ return mapreduce (Base . Fix1 (rand, rng ), hcat, view (dist . dists, 2 : length (dist) ); init = init)
6766end
0 commit comments