diff --git a/Project.toml b/Project.toml index 76dede4..12ee3f8 100644 --- a/Project.toml +++ b/Project.toml @@ -20,13 +20,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] DistributionsADForwardDiffExt = "ForwardDiff" -DistributionsADLazyArraysExt = "LazyArrays" DistributionsADReverseDiffExt = "ReverseDiff" DistributionsADTrackerExt = "Tracker" @@ -38,7 +36,6 @@ Compat = "3.6, 4" Distributions = "0.25.41" FillArrays = "1.4.1" ForwardDiff = "0.10.12, 1" -LazyArrays = "1, 2" LinearAlgebra = "<0.0.1, 1" PDMats = "0.9, 0.10, 0.11" Random = "<0.0.1, 1" @@ -53,6 +50,5 @@ julia = "1.6.5" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/docs/src/api.md b/docs/src/api.md index 0f80f43..9c30cd8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,8 +1,3 @@ # API -## Functions - -```@docs -filldist -arraydist -``` +This package provides automatic differentiation support for distributions in Distributions.jl. diff --git a/ext/DistributionsADLazyArraysExt.jl b/ext/DistributionsADLazyArraysExt.jl deleted file mode 100644 index a030ca6..0000000 --- a/ext/DistributionsADLazyArraysExt.jl +++ /dev/null @@ -1,52 +0,0 @@ -module DistributionsADLazyArraysExt - -if isdefined(Base, :get_extension) - using DistributionsAD - using LazyArrays - using DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate - using LazyArrays: BroadcastArray, BroadcastVector, LazyArray -else - using ..DistributionsAD - using ..LazyArrays - using ..DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate - using ..LazyArrays: BroadcastArray, BroadcastVector, LazyArray -end - -const LazyVectorOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastVector{T}, -} = VectorOfUnivariate{S,T,Tdists} - -function Distributions._logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return sum(copy(Distributions.logpdf.(dist.v, x))) -end - -function Distributions.logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractMatrix{<:Real}, -) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return vec(sum(copy(Distributions.logpdf.(dists, x)), dims = 1)) -end - -const LazyMatrixOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastArray{T,2}, -} = MatrixOfUnivariate{S,T,Tdists} - -function Distributions._logpdf( - dist::LazyMatrixOfUnivariate, - x::AbstractMatrix{<:Real}, -) - return sum(copy(Distributions.logpdf.(dist.dists, x))) -end - -DistributionsAD.lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...)) - -end # module diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c6203f6..335303f 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -19,13 +19,9 @@ export TuringScalMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringWishart, - TuringInverseWishart, - arraydist, - filldist + TuringInverseWishart include("common.jl") -include("arraydist.jl") -include("filldist.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") @@ -33,25 +29,14 @@ include("flatten.jl") include("zygote.jl") -# Empty definition, function requires the LazyArrays extension -function lazyarray end -export lazyarray - if !isdefined(Base, :get_extension) using Requires end function __init__() - # Better error message if users forget to load LazyArrays - Base.Experimental.register_error_hint(MethodError) do io, exc, arg_types, kwargs - if exc.f === lazyarray - print(io, "\\nDid you forget to load LazyArrays?") - end - end @static if !isdefined(Base, :get_extension) @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/DistributionsADForwardDiffExt.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/DistributionsADReverseDiffExt.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/DistributionsADTrackerExt.jl") - @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" include("../ext/DistributionsADLazyArraysExt.jl") end end diff --git a/src/arraydist.jl b/src/arraydist.jl deleted file mode 100644 index 062bab0..0000000 --- a/src/arraydist.jl +++ /dev/null @@ -1,108 +0,0 @@ -""" - arraydist(dists::AbstractArray{<:Distribution}) - -Create a product distribution from an array of sub-distributions. Each element -of `dists` should have the same size. If the size of each element is `(d1, d2, -...)`, and `size(dists)` is `(n1, n2, ...)`, then the resulting distribution -will have size `(d1, d2, ..., n1, n2, ...)`. - -The default behaviour is to directly use -[`Distributions.product_distribution`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.product_distribution), -although this can sometimes be specialised. - -# Examples - -```jldoctest; setup=:(using Distributions, Random) -julia> d1 = arraydist([Normal(0, 1), Normal(10, 1)]) -Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=10.0, σ=1.0)]) - -julia> size(d1) -(2,) - -julia> Random.seed!(42); rand(d1) -2-element Vector{Float64}: - 0.7883556016042917 - 9.1201414040456 - -julia> d2 = arraydist([Normal(0, 1) Normal(5, 1); Normal(10, 1) Normal(15, 1)]) -DistributionsAD.MatrixOfUnivariate{Continuous, Normal{Float64}, Matrix{Normal{Float64}}}( -dists: Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0) Normal{Float64}(μ=5.0, σ=1.0); Normal{Float64}(μ=10.0, σ=1.0) Normal{Float64}(μ=15.0, σ=1.0)] -) - -julia> size(d2) -(2, 2) - -julia> Random.seed!(42); rand(d2) -2×2 Matrix{Float64}: - 0.788356 4.12621 - 9.12014 14.2667 -``` -""" -arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists) - -# Univariate - -const VectorOfUnivariate = Distributions.Product - -function arraydist(dists::AbstractVector{<:UnivariateDistribution}) - V = typeof(dists) - T = eltype(dists) - S = Distributions.value_support(T) - return Product{S,T,V}(dists) -end - -struct MatrixOfUnivariate{ - S <: ValueSupport, - Tdist <: UnivariateDistribution{S}, - Tdists <: AbstractMatrix{Tdist}, -} <: MatrixDistribution{S} - dists::Tdists -end -Base.size(dist::MatrixOfUnivariate) = size(dist.dists) -function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) - return MatrixOfUnivariate(dists) -end -function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # Lazy broadcast to avoid allocations and use pairwise summation - return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x))) -end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) - return rand.(Ref(rng), dist.dists) -end - -# Multivariate - -struct VectorOfMultivariate{ - S <: ValueSupport, - Tdist <: MultivariateDistribution{S}, - Tdists <: AbstractVector{Tdist}, -} <: MatrixDistribution{S} - dists::Tdists -end -Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) -Base.length(dist::VectorOfMultivariate) = length(dist.dists) -function arraydist(dists::AbstractVector{<:MultivariateDistribution}) - return VectorOfMultivariate(dists) -end - -function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x)))) -end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) - init = reshape(rand(rng, dist.dists[1]), :, 1) - return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init) -end diff --git a/src/filldist.jl b/src/filldist.jl deleted file mode 100644 index d958361..0000000 --- a/src/filldist.jl +++ /dev/null @@ -1,123 +0,0 @@ -""" - filldist(d::Distribution, ns...) - -Create a product distribution from a single distribution and a list of -dimension sizes. If `size(d)` is `(d1, d2, ...)` and `ns` is `(n1, n2, ...)`, -then the resulting distribution will have size `(d1, d2, ..., n1, n2, ...)`. - -The default behaviour is to use -[`Distributions.product_distribution`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.product_distribution), -with `FillArrays.Fill` supplied as the array argument. However, this behaviour -is specialised in some instances, such as the one shown below. - -When sampling from the resulting distribution, the output will be an array where -each element is sampled from the original distribution `d`. - -# Examples - -```jldoctest; setup=:(using Distributions, Random) -julia> d = filldist(Normal(0, 1), 4, 5); - -julia> size(d) -(4, 5) - -julia> rand(d) isa Matrix{Float64} -true -``` -""" -filldist(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...)) - -# Univariate - -# TODO: Do we even need these? Probably should benchmark to be sure. -const FillVectorOfUnivariate{ - S <: ValueSupport, - T <: UnivariateDistribution{S}, - Tdists <: Fill{T, 1}, -} = VectorOfUnivariate{S, T, Tdists} - -function filldist(dist::UnivariateDistribution, N::Int) - return product_distribution(Fill(dist, N)) -end -filldist(d::Normal, N::Int) = TuringMvNormal(fill(d.μ, N), d.σ) - -function Distributions._logpdf( - dist::FillVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return _flat_logpdf(dist.v.value, x) -end - -function Distributions.logpdf( - dist::FillVectorOfUnivariate, - x::AbstractMatrix{<:Real}, -) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return _flat_logpdf_mat(dist.v.value, x) -end - -function _flat_logpdf(dist, x) - if toflatten(dist) - f, args = flatten(dist) - # Lazy broadcast to avoid allocations and use pairwise summation - return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x))) - else - return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x))) - end -end - -function _flat_logpdf_mat(dist, x) - if toflatten(dist) - f, args = flatten(dist) - return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1)) - else - return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1)) - end -end - -function Distributions.rand(rng::Random.AbstractRNG, d::FillVectorOfUnivariate) - return rand(rng, d.v.value, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::FillVectorOfUnivariate, n::Int) - return rand(rng, d.v.value, length(d), n) -end - -const FillMatrixOfUnivariate{ - S <: ValueSupport, - T <: UnivariateDistribution{S}, - Tdists <: Fill{T, 2}, -} = MatrixOfUnivariate{S, T, Tdists} - -function filldist(dist::UnivariateDistribution, N1::Int, N2::Int) - return MatrixOfUnivariate(Fill(dist, N1, N2)) -end -function Distributions._logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # return loglikelihood(dist.dists.value, x) - return _flat_logpdf(dist.dists.value, x) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate) - return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) -end - -# Multivariate - -const FillVectorOfMultivariate{ - S <: ValueSupport, - T <: MultivariateDistribution{S}, - Tdists <: Fill{T, 1}, -} = VectorOfMultivariate{S, T, Tdists} - -function filldist(dist::MultivariateDistribution, N::Int) - return VectorOfMultivariate(Fill(dist, N)) -end -function Distributions._logpdf( - dist::FillVectorOfMultivariate, - x::AbstractMatrix{<:Real}, -) - return loglikelihood(dist.dists.value, x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) - return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) -end diff --git a/src/zygote.jl b/src/zygote.jl index 86598b6..ff6e3a6 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -6,14 +6,6 @@ ZygoteRules.@adjoint function Distributions._logpdf(d::Product, x::AbstractVecto sum(map(logpdf, d.v, x)) end end -ZygoteRules.@adjoint function Distributions._logpdf( - d::FillVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return ZygoteRules.pullback(d, x) do d, x - _flat_logpdf(d.v.value, x) - end -end # Loglikelihood of multi- and matrixvariate distributions: multiple samples # workaround for Zygote issues discussed in diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 8baa50c..150bf59 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -394,78 +394,6 @@ @info "Testing: $(nameof(dist_type(d)))" test_ad(d) end - - # Test `filldist` and `arraydist` distributions of univariate distributions - n = 2 # always use two distributions - for d in univariate_distributions - d.x isa Number || continue - - # Broken distributions - D = dist_type(d) - D <: Union{VonMises,TriangularDist} && continue - - # Skellam only fails in these tests with ReverseDiff - # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 - # PoissonBinomial fails with Zygote - # Matrix case does not work with Skellam: - # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 - filldist_broken = if D <: PoissonBinomial - ((d.broken..., :Zygote), (d.broken..., :Zygote)) - elseif D <: Chernoff - # Zygote is not broken with `filldist` - ((), ()) - else - (d.broken, d.broken) - end - arraydist_broken = if D <: PoissonBinomial - ((d.broken..., :Zygote), (d.broken..., :Zygote)) - else - (d.broken, d.broken) - end - - # Create `filldist` distribution - f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n) - d_filldist = f_filldist(d.θ...) - - # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist([f(θ...) for _ in 1:n]) - d_arraydist = f_arraydist(d.θ...) - - for (i, sz) in enumerate(((n,), (n, 2))) - # Matrix case doesn't work for continuous distributions for some reason - # now but not too important (?!) - if length(sz) == 2 && D <: ContinuousDistribution - continue - end - - # Compute compatible sample - x = fill(d.x, sz) - - # Test AD - @info "Testing: filldist($(nameof(D)), $sz)" - test_ad( - DistSpec( - f_filldist, - d.θ, - x, - d.xtrans; - broken=filldist_broken[i], - ) - ) - - @info "Testing: arraydist($(nameof(D)), $sz)" - test_ad( - DistSpec( - f_arraydist, - d.θ, - x, - d.xtrans; - broken=arraydist_broken[i], - ) - ) - end - end end @testset "Matrixvariate distributions" begin @@ -475,153 +403,5 @@ @info "Testing: $(nameof(dist_type(d)))" test_ad(d) end - - # Test `filldist` and `arraydist` distributions of univariate distributions - n = (2, 2) # always use 2 x 2 distributions - for d in univariate_distributions - d.x isa Number || continue - D = dist_type(d) - D <: DiscreteDistribution && continue - - # Broken distributions - D <: Union{VonMises,TriangularDist} && continue - - # Create `filldist` distribution - f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n...) - - # Create `arraydist` distribution - # Zygote's fill definition does not like non-numbers, so we use a workaround - f_arraydist = (θ...,) -> arraydist(reshape([f(θ...) for _ in 1:prod(n)], n)) - - # Matrix `x` - x_mat = fill(d.x, n) - - # Zygote is not broken with `filldist` + Chernoff - filldist_broken = D <: Chernoff ? () : d.broken - - # Test AD - @info "Testing: filldist($(nameof(D)), $n)" - test_ad( - DistSpec( - f_filldist, - d.θ, - x_mat, - d.xtrans; - broken=filldist_broken, - ) - ) - @info "Testing: arraydist($(nameof(D)), $n)" - test_ad( - DistSpec( - f_arraydist, - d.θ, - x_mat, - d.xtrans; - broken=d.broken, - ) - ) - - # Vector of matrices `x` - x_vec_of_mat = [fill(d.x, n) for _ in 1:2] - - # Test AD - @info "Testing: filldist($(nameof(D)), $n, 2)" - test_ad( - DistSpec( - f_filldist, - d.θ, - x_vec_of_mat, - d.xtrans; - broken=filldist_broken, - ) - ) - @info "Testing: arraydist($(nameof(D)), $n, 2)" - test_ad( - DistSpec( - f_arraydist, - d.θ, - x_vec_of_mat, - d.xtrans; - broken=d.broken, - ) - ) - end - - # test `filldist` and `arraydist` distributions of multivariate distributions - n = 2 # always use two distributions - for d in multivariate_distributions - d.x isa AbstractVector || continue - D = dist_type(d) - D <: DiscreteDistribution && continue - - # Tests are failing for matrix covariance vectorized MvNormal - if D <: Union{ - MvNormal,MvLogNormal, - DistributionsAD.TuringDenseMvNormal, - DistributionsAD.TuringDiagMvNormal, - DistributionsAD.TuringScalMvNormal, - TuringMvLogNormal - } - any(x isa Matrix for x in d.θ) && continue - end - - # Create `filldist` distribution - f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n) - - # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist([f(θ...) for _ in 1:n]) - - # Matrix `x` - x_mat = repeat(d.x, 1, n) - - # Test AD - @info "Testing: filldist($(nameof(D)), $n)" - test_ad( - DistSpec( - f_filldist, - d.θ, - x_mat, - d.xtrans; - broken=d.broken, - ) - ) - @info "Testing: arraydist($(nameof(D)), $n)" - test_ad( - DistSpec( - f_arraydist, - d.θ, - x_mat, - d.xtrans; - broken=d.broken, - ) - ) - - # Vector of matrices `x` - x_vec_of_mat = [repeat(d.x, 1, n) for _ in 1:2] - - # Test AD - @info "Testing: filldist($(nameof(D)), $n, 2)" - test_ad( - DistSpec( - f_filldist, - d.θ, - x_vec_of_mat, - d.xtrans; - broken=d.broken, - ) - ) - @info "Testing: arraydist($(nameof(D)), $n, 2)" - test_ad( - DistSpec( - f_arraydist, - d.θ, - x_vec_of_mat, - d.xtrans; - broken=d.broken, - ) - ) - end end end diff --git a/test/runtests.jl b/test/runtests.jl index c25d19e..0310ffa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using Combinatorics using Distributions using Documenter using PDMats -import LazyArrays using Random, LinearAlgebra, Test