-
Notifications
You must be signed in to change notification settings - Fork 30
Update ChainRules definitions and add differential for PoissonBinomial pdf #162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5392ef9
9389fbf
9b0e29e
691ab5f
5ec49b0
173d76f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,75 +11,131 @@ | |
| (c, -c, z), | ||
| ) | ||
|
|
||
| # StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106 | ||
|
|
||
| ## Beta ## | ||
|
|
||
| @scalar_rule( | ||
| betalogpdf(α::Real, β::Real, x::Number), | ||
| @setup(di = digamma(α + β)), | ||
| @setup(z = digamma(α + β)), | ||
| ( | ||
| @thunk(log(x) - digamma(α) + di), | ||
| @thunk(log(1 - x) - digamma(β) + di), | ||
| @thunk((α - 1)/x + (1 - β)/(1 - x)), | ||
| log(x) + z - digamma(α), | ||
| log1p(-x) + z - digamma(β), | ||
| (α - 1) / x + (1 - β) / (1 - x), | ||
| ), | ||
| ) | ||
|
|
||
| ## Gamma ## | ||
|
|
||
| @scalar_rule( | ||
| gammalogpdf(k::Real, θ::Real, x::Number), | ||
| @setup( | ||
| invθ = inv(θ), | ||
| xoθ = invθ * x, | ||
| z = xoθ - k, | ||
| ), | ||
| ( | ||
| @thunk(-digamma(k) - log(θ) + log(x)), | ||
| @thunk(-k/θ + x/θ^2), | ||
| @thunk((k - 1)/x - 1/θ), | ||
| log(xoθ) - digamma(k), | ||
| invθ * z, | ||
| - (1 + z) / x, | ||
| ), | ||
| ) | ||
|
|
||
| ## Chisq ## | ||
|
|
||
| @scalar_rule( | ||
| chisqlogpdf(k::Real, x::Number), | ||
| @setup(ko2 = k / 2), | ||
| (@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)), | ||
| @setup(hk = k / 2), | ||
| ( | ||
| (log(x) - logtwo - digamma(hk)) / 2, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is I ask because I worry it might lead to undeseriable type-promotion, e.g.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's defined in LogExpFunctions and just reexported from StatsFuns (the PR was finally merged some days ago 🎉). It is defined as an julia> logtwo - 3.4
-2.7068528194400545
julia> logtwo - 3.4f0
-2.706853f0
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good, good 👍 Yeah in this case it should be fine, but there are several cases in DiffRules.jl (and I'm pretty certain the same ones exist in ChainRules.jl?) where the operations are written in a way such that And the annoying bit is that some of these constants are defined in StatsFuns.jl, some in LogExpFunctions.jl, etc., so it's difficult to re-use them in DiffRules.jl 😕
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's only a problem if you perform operations with the constants (such as |
||
| (hk - 1) / x - one(hk) / 2, | ||
| ), | ||
| ) | ||
|
|
||
| ## FDist ## | ||
|
|
||
| @scalar_rule( | ||
| fdistlogpdf(v1::Real, v2::Real, x::Number), | ||
| fdistlogpdf(ν1::Real, ν2::Real, x::Number), | ||
| @setup( | ||
| temp1 = v1 * x + v2, | ||
| temp2 = log(temp1), | ||
| vsum = v1 + v2, | ||
| temp3 = vsum / temp1, | ||
| temp4 = digamma(vsum / 2), | ||
| xν1 = x * ν1, | ||
| temp1 = xν1 + ν2, | ||
| a = (x - 1) / temp1, | ||
| νsum = ν1 + ν2, | ||
| di = digamma(νsum / 2), | ||
| ), | ||
| ( | ||
| @thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2), | ||
| @thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2), | ||
| @thunk(v1 / 2 * (1 / x - temp3) - 1 / x), | ||
| (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, | ||
| (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, | ||
| ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, | ||
| ), | ||
| ) | ||
|
|
||
| ## TDist ## | ||
|
|
||
| @scalar_rule( | ||
| tdistlogpdf(v::Real, x::Number), | ||
| tdistlogpdf(ν::Real, x::Number), | ||
| @setup( | ||
| νp1 = ν + 1, | ||
| xsq = x^2, | ||
| invν = inv(ν), | ||
| a = xsq * invν, | ||
| b = νp1 / (ν + xsq), | ||
| ), | ||
| ( | ||
| @thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2), | ||
| @thunk(-x * (v + 1) / (v + x^2)), | ||
| ) | ||
| (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, | ||
| - x * b, | ||
| ), | ||
| ) | ||
|
|
||
| ## Binomial ## | ||
|
|
||
| @scalar_rule( | ||
| binomlogpdf(n::Int, p::Real, x::Int), | ||
| (DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()), | ||
| binomlogpdf(n::Real, p::Real, k::Real), | ||
| @setup(z = digamma(n - k + 1)), | ||
| ( | ||
| digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), | ||
| (k / p - n) / (1 - p), | ||
| z - digamma(k + 1) + logit(p), | ||
| ), | ||
| ) | ||
|
|
||
| ## Poisson ## | ||
|
|
||
| @scalar_rule( | ||
| poislogpdf(v::Real, x::Int), | ||
| (x / v - 1, DoesNotExist()), | ||
| poislogpdf(λ::Number, x::Number), | ||
| ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)), | ||
| ) | ||
|
|
||
| ## PoissonBinomial | ||
|
|
||
| function ChainRulesCore.rrule( | ||
| ::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real} | ||
| ) | ||
| y = Distributions.poissonbinomial_pdf_fft(p) | ||
| A = poissonbinomial_partialderivatives(p) | ||
| function poissonbinomial_pdf_fft_pullback(Δy) | ||
| p̄ = InplaceableThunk( | ||
| @thunk(A * Δy), | ||
| Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), | ||
| ) | ||
| return (NO_FIELDS, p̄) | ||
| end | ||
| return y, poissonbinomial_pdf_fft_pullback | ||
| end | ||
|
|
||
| if isdefined(Distributions, :poissonbinomial_pdf) | ||
| function ChainRulesCore.rrule( | ||
| ::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real} | ||
| ) | ||
| y = Distributions.poissonbinomial_pdf(p) | ||
| A = poissonbinomial_partialderivatives(p) | ||
| function poissonbinomial_pdf_pullback(Δy) | ||
| p̄ = InplaceableThunk( | ||
| @thunk(A * Δy), | ||
| Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), | ||
| ) | ||
| return (NO_FIELDS, p̄) | ||
| end | ||
| return y, poissonbinomial_pdf_pullback | ||
| end | ||
| end | ||
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.