Skip to content

Commit 02d5e64

Browse files
committed
Use irrational constants
1 parent 5bf9bf3 commit 02d5e64

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

src/chainrules.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@ ChainRulesCore.@scalar_rule(
3333
)
3434
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
3535
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
36-
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x^2))
37-
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x^2))
38-
ChainRulesCore.@scalar_rule(logerfc(x), -(2 / sqrt(π)) * exp(-x^2 - Ω))
39-
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp^2))
40-
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
41-
ChainRulesCore.@scalar_rule(logerfcx(x), 2 * x - (2 / sqrt(π)) * exp(-Ω))
42-
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x^2))
43-
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp^2))
36+
37+
# TODO: use `invsqrtπ` if it is added to IrrationalConstants
38+
ChainRulesCore.@scalar_rule(erf(x), (2 * exp(-x^2)) / sqrtπ)
39+
ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^2)) / sqrtπ))
40+
ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ)
41+
ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ)
42+
ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp^2) / 2)))
43+
ChainRulesCore.@scalar_rule(erfcx(x), 2 * x * Ω - 2 / sqrtπ)
44+
ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ))
45+
ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ)
46+
ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp^2) / 2))
4447

4548
ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
4649
ChainRulesCore.@scalar_rule(
@@ -70,8 +73,7 @@ ChainRulesCore.@scalar_rule(
7073
)
7174
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))
7275

73-
# binary
74-
ChainRulesCore.@scalar_rule(erf(x, y), (-(2 / sqrt(π)) * exp(-x^2), (2 / sqrt(π)) * exp(-y^2)))
76+
# Bessel functions
7577
ChainRulesCore.@scalar_rule(
7678
besselj(ν, x),
7779
(
@@ -135,6 +137,7 @@ ChainRulesCore.@scalar_rule(
135137
(hankelh2x- 1, x) - hankelh2x+ 1, x)) / 2 + im * Ω,
136138
),
137139
)
140+
138141
ChainRulesCore.@scalar_rule(
139142
polygamma(m, x),
140143
(
@@ -188,5 +191,5 @@ ChainRulesCore.@scalar_rule(
188191
)
189192
)
190193
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x)
191-
ChainRulesCore.@scalar_rule(sinint(x), sinc(x / π))
194+
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x))
192195
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x)

test/chainrules.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,31 @@
143143
test_scalar(cosint, x)
144144
end
145145
end
146+
147+
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/307
148+
@testset "promotions" begin
149+
# one argument
150+
for f in (erf, erfc, logerfc, erfcinv, logerfcx, erfi, erfinv, sinint)
151+
_, ẏ = frule((NoTangent(), 1f0), f, 1f0)
152+
@testisa Float32
153+
_, back = rrule(f, 1f0)
154+
_, x̄ = back(1f0)
155+
@testisa Float32
156+
end
157+
158+
# two arguments
159+
_, ẏ = frule((NoTangent(), 1f0, 1f0), erf, 1f0, 1f0)
160+
@testisa Float32
161+
_, back = rrule(erf, 1f0, 1f0)
162+
_, x̄ = back(1f0)
163+
@testisa Float32
164+
165+
# currently broken, can be fixed if `invsqrtπ` is available:
166+
# https://github.com/JuliaMath/IrrationalConstants.jl/pull/8#issuecomment-925828753
167+
_, ẏ = frule((NoTangent(), 1f0), erfcx, 1f0)
168+
@test_brokenisa Float32
169+
_, back = rrule(erfcx, 1f0)
170+
_, x̄ = back(1f0)
171+
@testisa Float32
172+
end
146173
end

0 commit comments

Comments
 (0)