Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.22"
version = "0.6.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,7 +10,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -30,7 +29,6 @@ Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3, 0.24"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10.6"
NaNMath = "0.3"
PDMats = "0.9, 0.10, 0.11"
Requires = "1"
Expand Down
4 changes: 0 additions & 4 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import StatsFuns: logsumexp,
nbetalogpdf
import Distributions: MvNormal,
MvLogNormal,
poissonbinomial_pdf_fft,
logpdf,
quantile,
PoissonBinomial,
Expand Down Expand Up @@ -65,9 +64,6 @@ include("zygote.jl")
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
include("forwarddiff.jl")

# loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft`
include("zygote_forwarddiff.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
Expand Down
108 changes: 82 additions & 26 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,75 +11,131 @@
(c, -c, z),
)

# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106

## Beta ##

@scalar_rule(
betalogpdf::Real, β::Real, x::Number),
@setup(di = digamma+ β)),
@setup(z = digamma+ β)),
(
@thunk(log(x) - digamma) + di),
@thunk(log(1 - x) - digamma) + di),
@thunk((α - 1)/x + (1 - β)/(1 - x)),
log(x) + z - digamma(α),
log1p(-x) + z - digamma(β),
(α - 1) / x + (1 - β) / (1 - x),
),
)

## Gamma ##

@scalar_rule(
gammalogpdf(k::Real, θ::Real, x::Number),
@setup(
invθ = inv(θ),
xoθ = invθ * x,
z = xoθ - k,
),
(
@thunk(-digamma(k) - log(θ) + log(x)),
@thunk(-k/θ + x/θ^2),
@thunk((k - 1)/x - 1/θ),
log(xoθ) - digamma(k),
invθ * z,
- (1 + z) / x,
),
)

## Chisq ##

@scalar_rule(
chisqlogpdf(k::Real, x::Number),
@setup(ko2 = k / 2),
(@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
@setup(hk = k / 2),
(
(log(x) - logtwo - digamma(hk)) / 2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is logtwo from? I see it's from StatsFuns, but I can't find the definition in StatsFuns 😕

I ask because I worry it might lead to undeseriable type-promotion, e.g. Float32 to Float64.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's defined in LogExpFunctions and just reexported from StatsFuns (the PR was finally merged some days ago 🎉). It is defined as an Irrational and hence should avoid type promotions if possible, e.g:

julia> logtwo - 3.4
-2.7068528194400545

julia> logtwo - 3.4f0
-2.706853f0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, good 👍 Yeah in this case it should be fine, but there are several cases in DiffRules.jl (and I'm pretty certain the same ones exist in ChainRules.jl?) where the operations are written in a way such that Irrational is first converted into Float64, and thus the dual is promoted to Float64, despite the primal being Float32 (JuliaDiff/DiffRules.jl#55).

And the annoying bit is that some of these constants are defined in StatsFuns.jl, some in LogExpFunctions.jl, etc., so it's difficult to re-use them in DiffRules.jl 😕
Though maybe I should actually just add all those constants to that PR instead of doing oftype everywhere, hmm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only a problem if you perform operations with the constants (such as sqrt) before they are promoted to the correct type, which is not the case here.

(hk - 1) / x - one(hk) / 2,
),
)

## FDist ##

@scalar_rule(
fdistlogpdf(v1::Real, v2::Real, x::Number),
fdistlogpdf(ν1::Real, ν2::Real, x::Number),
@setup(
temp1 = v1 * x + v2,
temp2 = log(temp1),
vsum = v1 + v2,
temp3 = vsum / temp1,
temp4 = digamma(vsum / 2),
xν1 = x * ν1,
temp1 = xν1 + ν2,
a = (x - 1) / temp1,
νsum = ν1 + ν2,
di = digamma(νsum / 2),
),
(
@thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2),
@thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2),
@thunk(v1 / 2 * (1 / x - temp3) - 1 / x),
(-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
(-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
),
)

## TDist ##

@scalar_rule(
tdistlogpdf(v::Real, x::Number),
tdistlogpdf::Real, x::Number),
@setup(
νp1 = ν + 1,
xsq = x^2,
invν = inv(ν),
a = xsq * invν,
b = νp1 /+ xsq),
),
(
@thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2),
@thunk(-x * (v + 1) / (v + x^2)),
)
(digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2,
- x * b,
),
)

## Binomial ##

@scalar_rule(
binomlogpdf(n::Int, p::Real, x::Int),
(DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()),
binomlogpdf(n::Real, p::Real, k::Real),
@setup(z = digamma(n - k + 1)),
(
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
(k / p - n) / (1 - p),
z - digamma(k + 1) + logit(p),
),
)

## Poisson ##

@scalar_rule(
poislogpdf(v::Real, x::Int),
(x / v - 1, DoesNotExist()),
poislogpdf::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
)

## PoissonBinomial

function ChainRulesCore.rrule(
::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real}
)
y = Distributions.poissonbinomial_pdf_fft(p)
A = poissonbinomial_partialderivatives(p)
function poissonbinomial_pdf_fft_pullback(Δy)
= InplaceableThunk(
@thunk(A * Δy),
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
)
return (NO_FIELDS, p̄)
end
return y, poissonbinomial_pdf_fft_pullback
end

if isdefined(Distributions, :poissonbinomial_pdf)
function ChainRulesCore.rrule(
::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real}
)
y = Distributions.poissonbinomial_pdf(p)
A = poissonbinomial_partialderivatives(p)
function poissonbinomial_pdf_pullback(Δy)
= InplaceableThunk(
@thunk(A * Δy),
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
)
return (NO_FIELDS, p̄)
end
return y, poissonbinomial_pdf_pullback
end
end
42 changes: 42 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,45 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)

# PoissonBinomial

# compute matrix of partial derivatives [∂P(X=j-1)/∂pᵢ]_{i=1,…,n; j=1,…,n+1}
#
# This uses the same dynamic programming "trick" as for the computation of the primals
# in Distributions
#
# Reference (for the primal):
#
# Marlin A. Thomas & Audrey E. Taub (1982)
# Calculating binomial probabilities when the trial probabilities are unequal,
# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534
function poissonbinomial_partialderivatives(p)
n = length(p)
A = zeros(eltype(p), n, n + 1)
@inbounds for j in 1:n
A[j, end] = 1
end
@inbounds for (i, pi) in enumerate(p)
qi = 1 - pi
for k in (n - i + 1):n
kp1 = k + 1
for j in 1:(i - 1)
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
end
for j in (i+1):n
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
end
end
for j in 1:(i-1)
A[j, end] *= pi
end
for j in (i+1):n
A[j, end] *= pi
end
end
@inbounds for j in 1:n, i in 1:n
A[i, j] -= A[i, j+1]
end
return A
end
16 changes: 6 additions & 10 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,26 +261,22 @@ end
PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) =
TuringPoissonBinomial(p; check_args = check_args)

# TODO: add adjoints without ForwardDiff
poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x)
@grad function poissonbinomial_pdf_fft(x::TrackedArray)
x_data = data(x)
T = eltype(x_data)
fft = poissonbinomial_pdf_fft(x_data)
return fft, Δ -> begin
((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,)
end
value = poissonbinomial_pdf_fft(x_data)
A = poissonbinomial_partialderivatives(x_data)
poissonbinomial_pdf_fft_pullback(Δ) = (A * Δ,)
return value, poissonbinomial_pdf_fft_pullback
end

if isdefined(Distributions, :poissonbinomial_pdf)
Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x)
@grad function Distributions.poissonbinomial_pdf(x::TrackedArray)
x_data = data(x)
T = eltype(x_data)
value = Distributions.poissonbinomial_pdf(x_data)
function poissonbinomial_pdf_pullback(Δ)
return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,)
end
A = poissonbinomial_partialderivatives(x_data)
poissonbinomial_pdf_pullback(Δ) = (A * Δ,)
return value, poissonbinomial_pdf_pullback
end
end
Expand Down
8 changes: 0 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ ZygoteRules.@adjoint function Distributions.Uniform(args...)
return ZygoteRules.pullback(TuringUniform, args...)
end

## PoissonBinomial ##

# Zygote loads ForwardDiff, so this dummy adjoint should never be needed.
# The adjoint that is used for `poissonbinomial_pdf_fft` is defined in `src/zygote_forwarddiff.jl`
# ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real
# error("This needs ForwardDiff. `using ForwardDiff` should fix this error.")
# end

## Product

# Tests with `Kolmogorov` seem to fail otherwise?!
Expand Down
20 changes: 0 additions & 20 deletions src/zygote_forwarddiff.jl

This file was deleted.

4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -16,7 +17,8 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesTestUtils = "0.5.3, 0.6"
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.6.3"
Combinatorics = "1.0.2"
Distributions = "0.24.3"
FiniteDifferences = "0.11.3, 0.12"
Expand Down
Loading