Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return Product(dists)
V = typeof(dists)
T = eltype(dists)
S = Distributions.value_support(T)
return Product{S,T,V}(dists)
end

struct MatrixOfUnivariate{
Expand All @@ -18,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x))
# Broadcasting here breaks Tracker for some reason
return sum(map(logpdf, dist.dists, x))
# Lazy broadcast to avoid allocations and use pairwise summation
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x)))
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
Expand All @@ -49,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
end

function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x)))
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x))))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
return map(Base.Fix1(logpdf, dist), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init)
end
12 changes: 5 additions & 7 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,19 @@ end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
# Lazy broadcast to avoid allocations and use pairwise summation
return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x)))
else
return sum(map(x) do x
logpdf(dist, x)
end)
return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x)))
end
end

function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
return vec(sum(xi -> f(args..., xi), x; dims=1))

else
temp = map(x -> logpdf(dist, x), x)
return vec(sum(temp, dims = 1))
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
return vec(sum(Base.Fix1(logpdf, dist), x; dims = 1))

end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
ChainRulesTestUtils = "1.9.2"
Combinatorics = "1.0.2"
Distributions = "0.25.15"
FiniteDifferences = "0.11.3, 0.12"
Expand Down
4 changes: 1 addition & 3 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,7 @@
# PoissonBinomial fails with Zygote
# Matrix case does not work with Skellam:
# https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493
filldist_broken = if D <: Skellam
((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff))
elseif D <: PoissonBinomial
filldist_broken = if D <: PoissonBinomial
((d.broken..., :Zygote), (d.broken..., :Zygote))
elseif D <: Chernoff
# Zygote is not broken with `filldist`
Expand Down
8 changes: 6 additions & 2 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,16 @@ function testset_zygote(distspec, unpack_x_θ, args...; kwargs...)
end
end

function testset_zygote_broken(args...; kwargs...)
function testset_zygote_broken(distspec, args...; kwargs...)
# don't show test errors - tests are known to be broken :)
testset = suppress_stdout() do
testset_zygote(args...; kwargs...)
testset_zygote(distspec, args...; kwargs...)
end

f = distspec.f
θ = distspec.θ
x = distspec.x

# change errors and fails to broken results, and count number of errors and fails
efs = errors_to_broken!(testset)

Expand Down