Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.70"
version = "0.8.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,7 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.44"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.6.8"
Compat = "3.30"
FiniteDifferences = "0.12.8"
Expand Down
17 changes: 0 additions & 17 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,4 @@ include("rulesets/LinearAlgebra/factorization.jl")

include("rulesets/Random/random.jl")

# Note: The following is only required because package authors sometimes do not
# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
# So we define them here for them.
function __init__()
@require NaNMath="77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin
include("rulesets/packages/NaNMath.jl")
end

# Note: drop SpecialFunctions dependency in next breaking release
# https://github.com/JuliaDiff/ChainRules.jl/issues/319
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
if !isdefined(SpecialFunctions, :ChainRulesCore)
include("rulesets/packages/SpecialFunctions.jl")
end
end
end

end # module
16 changes: 8 additions & 8 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
A_dims = size(A)
function reshape_pullback(Ȳ)
return (NO_FIELDS, reshape(Ȳ, A_dims), NoTangent())
return (NoTangent(), reshape(Ȳ, A_dims), NoTangent())
end
return reshape(A, dims), reshape_pullback
end
Expand All @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
function reshape_pullback(Ȳ)
∂A = reshape(Ȳ, A_dims)
∂dims = broadcast(_ -> NoTangent(), dims)
return (NO_FIELDS, ∂A, ∂dims...)
return (NoTangent(), ∂A, ∂dims...)
end
return reshape(A, dims...), reshape_pullback
end
Expand All @@ -28,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
function hcat_pullback(Ȳ)
Xs = (A, Bs...)
ntuple(length(Bs) + 2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 1 && return NoTangent()

i = full_i - 1
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
Expand All @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
pre = post - diff + 1
return ΔY[:, pre:post]
end
return (NO_FIELDS, NoTangent(), ∂As)
return (NoTangent(), NoTangent(), ∂As)
end
return reduce(hcat, As), reduce_hcat_pullback
end
Expand All @@ -68,7 +68,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
u = l + size(Bs[i], 1)
copy(selectdim(Ȳ, 1, l+1:u))
end
return (NO_FIELDS, ∂A, ∂Bs...)
return (NoTangent(), ∂A, ∂Bs...)
end
return vcat(A, Bs...), vcat_pullback
end
Expand All @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe
pre = post - diff + 1
return ΔY[pre:post, :]
end
return (NO_FIELDS, NoTangent(), ∂As)
return (NoTangent(), NoTangent(), ∂As)
end
return reduce(vcat, As), reduce_vcat_pullback
end
Expand All @@ -92,14 +92,14 @@ end

function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
function fill_pullback(Ȳ)
return (NO_FIELDS, sum(Ȳ), NoTangent())
return (NoTangent(), sum(Ȳ), NoTangent())
end
return fill(value, dims), fill_pullback
end

function rrule(::typeof(fill), value::Any, dims::Int...)
function fill_pullback(Ȳ)
return (NO_FIELDS, sum(Ȳ), ntuple(_->NoTangent(), length(dims))...)
return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...)
end
return fill(value, dims), fill_pullback
end
26 changes: 13 additions & 13 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NO_FIELDS, -Ω' * ΔΩ * Ω'
return NoTangent(), -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end
Expand All @@ -26,7 +26,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(Ȳ * B'),
X̄ -> mul!(X̄, Ȳ, B', true, true)
Expand All @@ -48,7 +48,7 @@ function rrule(
function times_pullback(Ȳ)
@assert size(B, 1) === 1 # otherwise primal would have failed.
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(Ȳ * vec(B')),
X̄ -> mul!(X̄, Ȳ, vec(B'), true, true)
Expand All @@ -67,7 +67,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
@thunk(dot(Ȳ, B)'),
InplaceableThunk(
@thunk(A' * Ȳ),
Expand All @@ -83,7 +83,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(A' * Ȳ),
X̄ -> mul!(X̄, conj(A), Ȳ, true, true)
Expand Down Expand Up @@ -127,7 +127,7 @@ function rrule(
dz -> sum!(dz, Ȳ; init=false)
)
end
(NO_FIELDS, matmul..., addon)
(NoTangent(), matmul..., addon)
end
return muladd(A, B, z), muladd_pullback_1
end
Expand All @@ -148,7 +148,7 @@ function rrule(
@thunk(ut' .* dy),
dv -> dv .+= ut' .* dy
)
(NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy)
end
return muladd(ut, v, z), muladd_pullback_2
end
Expand All @@ -175,7 +175,7 @@ function rrule(
dz -> sum!(dz, Ȳ; init=false)
)
end
(NO_FIELDS, proj..., addon)
(NoTangent(), proj..., addon)
end
return muladd(u, vt, z), muladd_pullback_3
end
Expand All @@ -197,7 +197,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
(NoTangent(), ∂A, ∂B)
end
return C, slash_pullback
end
Expand All @@ -217,7 +217,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
end
∂B = @thunk A' \ Ȳ
return NO_FIELDS, ∂A, ∂B
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback

Expand All @@ -230,15 +230,15 @@ end
function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
Y = A/b
function slash_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
return (NoTangent(), @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
end
return Y, slash_pullback
end

function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
Y = b\A
function backslash_pullback(Ȳ)
return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
return (NoTangent(), @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
end
return Y, backslash_pullback
end
Expand All @@ -249,7 +249,7 @@ end

function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ)
return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ)
end
return -x, negation_pullback
end
16 changes: 8 additions & 8 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

function rrule(::typeof(adjoint), z::Number)
adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ')
adjoint_pullback(ΔΩ) = (NoTangent(), ΔΩ')
return (z', adjoint_pullback)
end

Expand All @@ -22,7 +22,7 @@ frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz))

function rrule(::typeof(real), z::Number)
# add zero(z) to embed the real number in the same number type as z
real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z))
real_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) + zero(z))
return (real(z), real_pullback)
end

Expand All @@ -33,7 +33,7 @@ end
frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz))

function rrule(::typeof(imag), z::Complex)
imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im)
imag_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) * im)
return (imag(z), imag_pullback)
end

Expand All @@ -45,15 +45,15 @@ function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex
end

function rrule(::Type{T}, z::Complex) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), Complex(ΔΩ))
return (T(z), Complex_pullback)
end
function rrule(::Type{T}, x::Real) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ))
return (T(x), Complex_pullback)
end
function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ), imag(ΔΩ))
return (T(x, y), Complex_pullback)
end

Expand All @@ -70,7 +70,7 @@ end
function rrule(::typeof(hypot), z::Complex)
Ω = hypot(z)
function hypot_pullback(ΔΩ)
return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
end
return (Ω, hypot_pullback)
end
Expand Down Expand Up @@ -148,7 +148,7 @@ end

function rrule(::typeof(identity), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
return (NoTangent(), ȳ)
end
return (x, identity_pullback)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if VERSION ≥ v"1.4"
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
return NoTangent(), ∂x, ∂p
end
return y, evalpoly_pullback
end
Expand Down
22 changes: 11 additions & 11 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ let
## sin
function rrule(::typeof(sin), x::Number)
sinx, cosx = sincos(x)
sin_pullback(Δy) = (NO_FIELDS, cosx' * Δy)
sin_pullback(Δy) = (NoTangent(), cosx' * Δy)
return (sinx, sin_pullback)
end

Expand All @@ -23,7 +23,7 @@ let
## cos
function rrule(::typeof(cos), x::Number)
sinx, cosx = sincos(x)
cos_pullback(Δy) = (NO_FIELDS, -sinx' * Δy)
cos_pullback(Δy) = (NoTangent(), -sinx' * Δy)
return (cosx, cos_pullback)
end

Expand Down Expand Up @@ -75,7 +75,7 @@ let
Ω = abs(x)
function abs_pullback(ΔΩ)
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
return (NO_FIELDS, signx * real(ΔΩ))
return (NoTangent(), signx * real(ΔΩ))
end
return Ω, abs_pullback
end
Expand All @@ -88,7 +88,7 @@ let
function rrule(::typeof(abs2), z::Union{Real, Complex})
function abs2_pullback(ΔΩ)
Δu = real(ΔΩ)
return (NO_FIELDS, 2real(ΔΩ)*z)
return (NoTangent(), 2real(ΔΩ)*z)
end
return abs2(z), abs2_pullback
end
Expand All @@ -99,7 +99,7 @@ let
end
function rrule(::typeof(conj), z::Union{Real, Complex})
function conj_pullback(ΔΩ)
return (NO_FIELDS, conj(ΔΩ))
return (NoTangent(), conj(ΔΩ))
end
return conj(z), conj_pullback
end
Expand All @@ -115,11 +115,11 @@ let

function rrule(::typeof(angle), x::Real)
function angle_pullback(ΔΩ::Real)
return (NO_FIELDS, ZeroTangent())
return (NoTangent(), ZeroTangent())
end
function angle_pullback(ΔΩ)
Δu, Δv = reim(ΔΩ)
return (NO_FIELDS, im*Δu/ifelse(iszero(x), one(x), x))
return (NoTangent(), im*Δu/ifelse(iszero(x), one(x), x))
# `ifelse` is applied only to denominator to ensure type-stability.
end
return angle(x), angle_pullback
Expand All @@ -130,7 +130,7 @@ let
Δu, Δv = reim(ΔΩ)
# `ifelse` is applied only to denominator to ensure type-stability.
n = ifelse(iszero(z), one(real(z)), abs2(z))
return (NO_FIELDS, (-y + im*x)*Δu/n)
return (NoTangent(), (-y + im*x)*Δu/n)
end
return angle(z), angle_pullback
end
Expand All @@ -155,7 +155,7 @@ let
Ω = hypot(x, y)
function hypot_pullback(ΔΩ)
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
return (NO_FIELDS, c * x, c * y)
return (NoTangent(), c * x, c * y)
end
return (Ω, hypot_pullback)
end
Expand Down Expand Up @@ -198,7 +198,7 @@ let
Ω = x isa Real ? sign(x) : x / n
function sign_pullback(ΔΩ)
∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im
return (NO_FIELDS, ∂x)
return (NoTangent(), ∂x)
end
return Ω, sign_pullback
end
Expand All @@ -214,7 +214,7 @@ let

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, ΔΩ * y', x' * ΔΩ)
return (NoTangent(), ΔΩ * y', x' * ΔΩ)
end
return x * y, times_pullback
end
Expand Down
Loading