diff --git a/src/arraydist.jl b/src/arraydist.jl index c358834..28e9e2b 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -3,7 +3,10 @@ const VectorOfUnivariate = Distributions.Product function arraydist(dists::AbstractVector{<:UnivariateDistribution}) - return Product(dists) + V = typeof(dists) + T = eltype(dists) + S = Distributions.value_support(T) + return Product{S,T,V}(dists) end struct MatrixOfUnivariate{ @@ -18,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) return MatrixOfUnivariate(dists) end function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x)) - # Broadcasting here breaks Tracker for some reason - return sum(map(logpdf, dist.dists, x)) + # Lazy broadcast to avoid allocations and use pairwise summation + return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x))) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) @@ -49,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution}) end function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x))) + return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x)))) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) + return map(Base.Fix1(logpdf, dist), x) end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) + return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init) end diff --git a/src/filldist.jl b/src/filldist.jl index e4c6308..67b758a 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -30,21 +30,19 @@ end function _flat_logpdf(dist, x) if toflatten(dist) f, args = flatten(dist) - return sum(f.(args..., x)) + # Lazy broadcast to avoid allocations and use pairwise summation + return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x))) else - return sum(map(x) do x - logpdf(dist, x) - end) + return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x))) end end function _flat_logpdf_mat(dist, x) if toflatten(dist) f, args = flatten(dist) - return vec(sum(f.(args..., x), dims = 1)) + return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1)) else - temp = map(x -> logpdf(dist, x), x) - return vec(sum(temp, dims = 1)) + return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1)) end end diff --git a/test/Project.toml b/test/Project.toml index 33d7bb7..129e08b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesCore = "1" -ChainRulesTestUtils = "1" +ChainRulesTestUtils = "1.9.2" Combinatorics = "1.0.2" Distributions = "0.25.15" FiniteDifferences = "0.11.3, 0.12" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index ef5ee18..d3ce2eb 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -408,9 +408,7 @@ # PoissonBinomial fails with Zygote # Matrix case does not work with Skellam: # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 - filldist_broken = if D <: Skellam - ((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) - elseif D <: PoissonBinomial + filldist_broken = if D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) elseif D <: Chernoff # Zygote is not broken with `filldist` diff --git a/test/ad/utils.jl b/test/ad/utils.jl index ecf4a61..a4dcd6e 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -396,12 +396,16 @@ function testset_zygote(distspec, unpack_x_θ, args...; kwargs...) end end -function testset_zygote_broken(args...; kwargs...) +function testset_zygote_broken(distspec, args...; kwargs...) # don't show test errors - tests are known to be broken :) testset = suppress_stdout() do - testset_zygote(args...; kwargs...) + testset_zygote(distspec, args...; kwargs...) end + f = distspec.f + θ = distspec.θ + x = distspec.x + # change errors and fails to broken results, and count number of errors and fails efs = errors_to_broken!(testset)