Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.70"
version = "0.8.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,8 +12,8 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.44"
ChainRulesTestUtils = "0.6.8"
ChainRulesCore = "0.10"
ChainRulesTestUtils = "0.7"
Compat = "3.30"
FiniteDifferences = "0.12.8"
Reexport = "0.2, 1"
Expand Down
1 change: 0 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module ChainRules

using Reexport
@reexport using ChainRulesCore
export Zero, DoesNotExist, Composite, AbstractDifferential

using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using Compat
Expand Down
16 changes: 8 additions & 8 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
A_dims = size(A)
function reshape_pullback(Ȳ)
return (NO_FIELDS, reshape(Ȳ, A_dims), NoTangent())
return (NoTangent(), reshape(Ȳ, A_dims), NoTangent())
end
return reshape(A, dims), reshape_pullback
end
Expand All @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
function reshape_pullback(Ȳ)
∂A = reshape(Ȳ, A_dims)
∂dims = broadcast(_ -> NoTangent(), dims)
return (NO_FIELDS, ∂A, ∂dims...)
return (NoTangent(), ∂A, ∂dims...)
end
return reshape(A, dims...), reshape_pullback
end
Expand All @@ -28,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
function hcat_pullback(Ȳ)
Xs = (A, Bs...)
ntuple(length(Bs) + 2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 1 && return NoTangent()

i = full_i - 1
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
Expand All @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
pre = post - diff + 1
return ΔY[:, pre:post]
end
return (NO_FIELDS, NoTangent(), ∂As)
return (NoTangent(), NoTangent(), ∂As)
end
return reduce(hcat, As), reduce_hcat_pullback
end
Expand All @@ -68,7 +68,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
u = l + size(Bs[i], 1)
copy(selectdim(Ȳ, 1, l+1:u))
end
return (NO_FIELDS, ∂A, ∂Bs...)
return (NoTangent(), ∂A, ∂Bs...)
end
return vcat(A, Bs...), vcat_pullback
end
Expand All @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe
pre = post - diff + 1
return ΔY[pre:post, :]
end
return (NO_FIELDS, NoTangent(), ∂As)
return (NoTangent(), NoTangent(), ∂As)
end
return reduce(vcat, As), reduce_vcat_pullback
end
Expand All @@ -92,14 +92,14 @@ end

function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
function fill_pullback(Ȳ)
return (NO_FIELDS, sum(Ȳ), NoTangent())
return (NoTangent(), sum(Ȳ), NoTangent())
end
return fill(value, dims), fill_pullback
end

function rrule(::typeof(fill), value::Any, dims::Int...)
function fill_pullback(Ȳ)
return (NO_FIELDS, sum(Ȳ), ntuple(_->NoTangent(), length(dims))...)
return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...)
end
return fill(value, dims), fill_pullback
end
30 changes: 15 additions & 15 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NO_FIELDS, -Ω' * ΔΩ * Ω'
return NoTangent(), -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end
Expand All @@ -26,7 +26,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(Ȳ * B'),
X̄ -> mul!(X̄, Ȳ, B', true, true)
Expand All @@ -48,7 +48,7 @@ function rrule(
function times_pullback(Ȳ)
@assert size(B, 1) === 1 # otherwise primal would have failed.
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(Ȳ * vec(B')),
X̄ -> mul!(X̄, Ȳ, vec(B'), true, true)
Expand All @@ -67,7 +67,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
@thunk(dot(Ȳ, B)'),
InplaceableThunk(
@thunk(A' * Ȳ),
Expand All @@ -83,7 +83,7 @@ function rrule(
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
NoTangent(),
InplaceableThunk(
@thunk(A' * Ȳ),
X̄ -> mul!(X̄, conj(A), Ȳ, true, true)
Expand Down Expand Up @@ -118,7 +118,7 @@ function rrule(
)
)
addon = if z isa Bool
DoesNotExist()
NoTangent()
elseif z isa Number
@thunk(sum(Ȳ))
else
Expand All @@ -127,7 +127,7 @@ function rrule(
dz -> sum!(dz, Ȳ; init=false)
)
end
(NO_FIELDS, matmul..., addon)
(NoTangent(), matmul..., addon)
end
return muladd(A, B, z), muladd_pullback_1
end
Expand All @@ -148,7 +148,7 @@ function rrule(
@thunk(ut' .* dy),
dv -> dv .+= ut' .* dy
)
(NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy)
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy)
end
return muladd(ut, v, z), muladd_pullback_2
end
Expand All @@ -166,7 +166,7 @@ function rrule(
@thunk(vec(sum(u .* conj.(Ȳ), dims=1))'),
)
addon = if z isa Bool
DoesNotExist()
NoTangent()
elseif z isa Number
@thunk(sum(Ȳ))
else
Expand All @@ -175,7 +175,7 @@ function rrule(
dz -> sum!(dz, Ȳ; init=false)
)
end
(NO_FIELDS, proj..., addon)
(NoTangent(), proj..., addon)
end
return muladd(u, vt, z), muladd_pullback_3
end
Expand All @@ -197,7 +197,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
(NoTangent(), ∂A, ∂B)
end
return C, slash_pullback
end
Expand All @@ -217,7 +217,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
end
∂B = @thunk A' \ Ȳ
return NO_FIELDS, ∂A, ∂B
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback

Expand All @@ -230,15 +230,15 @@ end
function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
Y = A/b
function slash_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
return (NoTangent(), @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
end
return Y, slash_pullback
end

function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
Y = b\A
function backslash_pullback(Ȳ)
return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
return (NoTangent(), @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
end
return Y, backslash_pullback
end
Expand All @@ -249,7 +249,7 @@ end

function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ)
return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ)
end
return -x, negation_pullback
end
16 changes: 8 additions & 8 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

function rrule(::typeof(adjoint), z::Number)
adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ')
adjoint_pullback(ΔΩ) = (NoTangent(), ΔΩ')
return (z', adjoint_pullback)
end

Expand All @@ -22,7 +22,7 @@ frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz))

function rrule(::typeof(real), z::Number)
# add zero(z) to embed the real number in the same number type as z
real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z))
real_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) + zero(z))
return (real(z), real_pullback)
end

Expand All @@ -33,7 +33,7 @@ end
frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz))

function rrule(::typeof(imag), z::Complex)
imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im)
imag_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) * im)
return (imag(z), imag_pullback)
end

Expand All @@ -45,15 +45,15 @@ function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex
end

function rrule(::Type{T}, z::Complex) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), Complex(ΔΩ))
return (T(z), Complex_pullback)
end
function rrule(::Type{T}, x::Real) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ))
return (T(x), Complex_pullback)
end
function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ))
Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ), imag(ΔΩ))
return (T(x, y), Complex_pullback)
end

Expand All @@ -70,7 +70,7 @@ end
function rrule(::typeof(hypot), z::Complex)
Ω = hypot(z)
function hypot_pullback(ΔΩ)
return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
end
return (Ω, hypot_pullback)
end
Expand Down Expand Up @@ -148,7 +148,7 @@ end

function rrule(::typeof(identity), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
return (NoTangent(), ȳ)
end
return (x, identity_pullback)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if VERSION ≥ v"1.4"
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
return NoTangent(), ∂x, ∂p
end
return y, evalpoly_pullback
end
Expand Down
Loading