From 5aad94a57206fa86f6228d11d3675d2232f4d4d8 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 14:30:54 +0100 Subject: [PATCH 1/9] replace NO_FIELDS by NoTangent() --- src/rulesets/Base/array.jl | 16 ++--- src/rulesets/Base/arraymath.jl | 20 +++--- src/rulesets/Base/base.jl | 16 ++--- src/rulesets/Base/evalpoly.jl | 2 +- src/rulesets/Base/fastmath_able.jl | 22 +++---- src/rulesets/Base/indexing.jl | 2 +- src/rulesets/Base/mapreduce.jl | 4 +- src/rulesets/Base/sort.jl | 4 +- src/rulesets/LinearAlgebra/blas.jl | 22 +++---- src/rulesets/LinearAlgebra/dense.jl | 30 ++++----- src/rulesets/LinearAlgebra/factorization.jl | 32 +++++----- src/rulesets/LinearAlgebra/matfun.jl | 4 +- src/rulesets/LinearAlgebra/norm.jl | 40 ++++++------ src/rulesets/LinearAlgebra/structured.jl | 66 ++++++++++---------- src/rulesets/LinearAlgebra/symmetric.jl | 24 +++---- src/rulesets/Random/random.jl | 2 +- src/rulesets/Statistics/statistics.jl | 2 +- test/rulesets/Base/base.jl | 4 +- test/rulesets/LinearAlgebra/factorization.jl | 16 ++--- test/rulesets/LinearAlgebra/norm.jl | 10 +-- test/rulesets/LinearAlgebra/structured.jl | 8 +-- test/rulesets/LinearAlgebra/symmetric.jl | 26 ++++---- 22 files changed, 186 insertions(+), 186 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 3b3996373..514090314 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -5,7 +5,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) A_dims = size(A) function reshape_pullback(Ȳ) - return (NO_FIELDS, reshape(Ȳ, A_dims), DoesNotExist()) + return (NoTangent(), reshape(Ȳ, A_dims), DoesNotExist()) end return reshape(A, dims), reshape_pullback end @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) ∂dims = broadcast(_ -> DoesNotExist(), dims) - return (NO_FIELDS, ∂A, ∂dims...) + return (NoTangent(), ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end @@ -28,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) function hcat_pullback(Ȳ) Xs = (A, Bs...) ntuple(length(Bs) + 2) do full_i - full_i == 1 && return NO_FIELDS + full_i == 1 && return NoTangent() i = full_i - 1 l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[:, pre:post] end - return (NO_FIELDS, DoesNotExist(), ∂As) + return (NoTangent(), DoesNotExist(), ∂As) end return reduce(hcat, As), reduce_hcat_pullback end @@ -68,7 +68,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) u = l + size(Bs[i], 1) copy(selectdim(Ȳ, 1, l+1:u)) end - return (NO_FIELDS, ∂A, ∂Bs...) + return (NoTangent(), ∂A, ∂Bs...) end return vcat(A, Bs...), vcat_pullback end @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[pre:post, :] end - return (NO_FIELDS, DoesNotExist(), ∂As) + return (NoTangent(), DoesNotExist(), ∂As) end return reduce(vcat, As), reduce_vcat_pullback end @@ -92,14 +92,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), DoesNotExist()) + return (NoTangent(), sum(Ȳ), DoesNotExist()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...) + return (NoTangent(), sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 023ccf7f4..3ce2e851c 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -10,7 +10,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) function inv_pullback(ΔΩ) - return NO_FIELDS, -Ω' * ΔΩ * Ω' + return NoTangent(), -Ω' * ΔΩ * Ω' end return Ω, inv_pullback end @@ -26,7 +26,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * B'), X̄ -> mul!(X̄, Ȳ, B', true, true) @@ -48,7 +48,7 @@ function rrule( function times_pullback(Ȳ) @assert size(B, 1) === 1 # otherwise primal would have failed. return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * vec(B')), X̄ -> mul!(X̄, Ȳ, vec(B'), true, true) @@ -67,7 +67,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), @thunk(dot(Ȳ, B)'), InplaceableThunk( @thunk(A' * Ȳ), @@ -83,7 +83,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(A' * Ȳ), X̄ -> mul!(X̄, conj(A), Ȳ, true, true) @@ -113,7 +113,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R ∂A = last(dA_pb(unthunk(dAᵀ))) ∂B = last(dA_pb(unthunk(dBᵀ))) - (NO_FIELDS, ∂A, ∂B) + (NoTangent(), ∂A, ∂B) end return C, slash_pullback end @@ -133,7 +133,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā end ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback @@ -146,7 +146,7 @@ end function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real) Y = A/b function slash_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) + return (NoTangent(), @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) end return Y, slash_pullback end @@ -154,7 +154,7 @@ end function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real}) Y = b\A function backslash_pullback(Ȳ) - return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) + return (NoTangent(), @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) end return Y, backslash_pullback end @@ -165,7 +165,7 @@ end function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) - return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) + return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) end return -x, negation_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 51da82a4d..0dac73572 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -10,7 +10,7 @@ frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') function rrule(::typeof(adjoint), z::Number) - adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ') + adjoint_pullback(ΔΩ) = (NoTangent(), ΔΩ') return (z', adjoint_pullback) end @@ -22,7 +22,7 @@ frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz)) function rrule(::typeof(real), z::Number) # add zero(z) to embed the real number in the same number type as z - real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z)) + real_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) + zero(z)) return (real(z), real_pullback) end @@ -33,7 +33,7 @@ end frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz)) function rrule(::typeof(imag), z::Complex) - imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) + imag_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) * im) return (imag(z), imag_pullback) end @@ -45,15 +45,15 @@ function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex end function rrule(::Type{T}, z::Complex) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), Complex(ΔΩ)) return (T(z), Complex_pullback) end function rrule(::Type{T}, x::Real) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ)) return (T(x), Complex_pullback) end function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ), imag(ΔΩ)) return (T(x, y), Complex_pullback) end @@ -70,7 +70,7 @@ end function rrule(::typeof(hypot), z::Complex) Ω = hypot(z) function hypot_pullback(ΔΩ) - return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) + return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) end return (Ω, hypot_pullback) end @@ -148,7 +148,7 @@ end function rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return (x, identity_pullback) end diff --git a/src/rulesets/Base/evalpoly.jl b/src/rulesets/Base/evalpoly.jl index d2e7d3073..60d5f6bae 100644 --- a/src/rulesets/Base/evalpoly.jl +++ b/src/rulesets/Base/evalpoly.jl @@ -15,7 +15,7 @@ if VERSION ≥ v"1.4" y, ys = _evalpoly_intermediates(x, p) function evalpoly_pullback(Δy) ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) - return NO_FIELDS, ∂x, ∂p + return NoTangent(), ∂x, ∂p end return y, evalpoly_pullback end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 3bb0a79f4..0a31e7b9f 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -11,7 +11,7 @@ let ## sin function rrule(::typeof(sin), x::Number) sinx, cosx = sincos(x) - sin_pullback(Δy) = (NO_FIELDS, cosx' * Δy) + sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end @@ -23,7 +23,7 @@ let ## cos function rrule(::typeof(cos), x::Number) sinx, cosx = sincos(x) - cos_pullback(Δy) = (NO_FIELDS, -sinx' * Δy) + cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end @@ -75,7 +75,7 @@ let Ω = abs(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return (NO_FIELDS, signx * real(ΔΩ)) + return (NoTangent(), signx * real(ΔΩ)) end return Ω, abs_pullback end @@ -88,7 +88,7 @@ let function rrule(::typeof(abs2), z::Union{Real, Complex}) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NO_FIELDS, 2real(ΔΩ)*z) + return (NoTangent(), 2real(ΔΩ)*z) end return abs2(z), abs2_pullback end @@ -99,7 +99,7 @@ let end function rrule(::typeof(conj), z::Union{Real, Complex}) function conj_pullback(ΔΩ) - return (NO_FIELDS, conj(ΔΩ)) + return (NoTangent(), conj(ΔΩ)) end return conj(z), conj_pullback end @@ -115,11 +115,11 @@ let function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) - return (NO_FIELDS, Zero()) + return (NoTangent(), Zero()) end function angle_pullback(ΔΩ) Δu, Δv = reim(ΔΩ) - return (NO_FIELDS, im*Δu/ifelse(iszero(x), one(x), x)) + return (NoTangent(), im*Δu/ifelse(iszero(x), one(x), x)) # `ifelse` is applied only to denominator to ensure type-stability. end return angle(x), angle_pullback @@ -130,7 +130,7 @@ let Δu, Δv = reim(ΔΩ) # `ifelse` is applied only to denominator to ensure type-stability. n = ifelse(iszero(z), one(real(z)), abs2(z)) - return (NO_FIELDS, (-y + im*x)*Δu/n) + return (NoTangent(), (-y + im*x)*Δu/n) end return angle(z), angle_pullback end @@ -155,7 +155,7 @@ let Ω = hypot(x, y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) - return (NO_FIELDS, c * x, c * y) + return (NoTangent(), c * x, c * y) end return (Ω, hypot_pullback) end @@ -198,7 +198,7 @@ let Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, sign_pullback end @@ -214,7 +214,7 @@ let function rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) - return (NO_FIELDS, ΔΩ * y', x' * ΔΩ) + return (NoTangent(), ΔΩ * y', x' * ΔΩ) end return x * y, times_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 78208434c..e99b30036 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -21,7 +21,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) getindex_add! ) īnds = broadcast(_ -> DoesNotExist(), inds) - return (NO_FIELDS, x̄, īnds...) + return (NoTangent(), x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 2f5b5003a..f4dc2f8a7 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -14,7 +14,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} @thunk(broadcast(last∘tuple, x, ȳ)), x -> x .+= ȳ ) - return (NO_FIELDS, x̄) + return (NoTangent(), x̄) end return y, sum_pullback end @@ -51,7 +51,7 @@ function rrule( @thunk(2 .* real.(ȳ) .* x), dx -> dx .+= 2 .* real.(ȳ) .* x ) - return (NO_FIELDS, DoesNotExist(), x_thunk) + return (NoTangent(), DoesNotExist(), x_thunk) end return y, sum_abs2_pullback end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 1ebf2a0a5..38917459b 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -10,7 +10,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!) - return NO_FIELDS, Δxs, DoesNotExist() + return NoTangent(), Δxs, DoesNotExist() end return ys, partialsort_pullback @@ -28,7 +28,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!) - return NO_FIELDS, Δxs + return NoTangent(), Δxs end return ys, sort_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 7fcfba18a..445e660a7 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -25,7 +25,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) end return Ω, blas_dot_pullback end @@ -48,19 +48,19 @@ end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) function nrm2_pullback(ΔΩ) - return NO_FIELDS, x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) + return NoTangent(), x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) end return Ω, nrm2_pullback end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - nrm2_pullback(::Zero) = (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist()) + nrm2_pullback(::Zero) = (NoTangent(), DoesNotExist(), Zero(), DoesNotExist()) function nrm2_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) ∂X = scal!(n, s, blascopy!(n, X, incx, _zeros(X), incx), incx) - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist()) end return Ω, nrm2_pullback end @@ -78,19 +78,19 @@ end function rrule(::typeof(BLAS.asum), x) function asum_pullback(ΔΩ) - return (NO_FIELDS, _signcomp.(x) .* real(ΔΩ)) + return (NoTangent(), _signcomp.(x) .* real(ΔΩ)) end return BLAS.asum(x), asum_pullback end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - asum_pullback(::Zero) = (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist()) + asum_pullback(::Zero) = (NoTangent(), DoesNotExist(), Zero(), DoesNotExist()) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist()) end return Ω, asum_pullback end @@ -135,7 +135,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄) ) end - return (NO_FIELDS, DoesNotExist(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) + return (NoTangent(), DoesNotExist(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) end return y, gemv_pullback end @@ -146,7 +146,7 @@ function rrule( y, inner_pullback = rrule(gemv, tA, one(T), A, x) function gemv_pullback(Ȳ) (_, dtA, _, dA, dx) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dA, dx) + return (NoTangent(), dtA, dA, dx) end return y, gemv_pullback end @@ -249,7 +249,7 @@ function rrule( ) end end - return (NO_FIELDS, DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) + return (NoTangent(), DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end return C, gemm_pullback end @@ -260,7 +260,7 @@ function rrule( C, inner_pullback = rrule(gemm, tA, tB, one(T), A, B) function gemm_pullback(Ȳ) (_, dtA, dtB, _, dA, dB) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dtB, dA, dB) + return (NoTangent(), dtA, dtB, dA, dB) end return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 5069f4c61..a8b1e3051 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -8,7 +8,7 @@ end function rrule(::typeof(dot), x, y) function dot_pullback(ΔΩ) - return (NO_FIELDS, @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) + return (NoTangent(), @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) end return dot(x, y), dot_pullback end @@ -24,9 +24,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N dx = @thunk conj(ΔΩ) .* Ay dA = @thunk ΔΩ .* x .* adjoint(y) dy = @thunk ΔΩ .* (adjoint(A) * x) - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + dot_pullback(::Zero) = (NoTangent(), Zero(), Zero(), Zero()) return z, dot_pullback end @@ -36,9 +36,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements dy = @thunk ΔΩ .* conj.(A.diag) .* x - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + dot_pullback(::Zero) = (NoTangent(), Zero(), Zero(), Zero()) return z, dot_pullback end @@ -54,7 +54,7 @@ end function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) Ω = cross(a, b) function cross_pullback(ΔΩ) - return (NO_FIELDS, @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) + return (NoTangent(), @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) end return Ω, cross_pullback end @@ -76,7 +76,7 @@ function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ : Ω * ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, det_pullback end @@ -94,7 +94,7 @@ function rrule(::typeof(logdet), x::Union{Number, StridedMatrix{<:Number}}) Ω = logdet(x) function logdet_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ / x' : ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logdet_pullback end @@ -122,7 +122,7 @@ function rrule(::typeof(logabsdet), x::AbstractMatrix) imagf = f - real(f) g = real(Δy) + imagf ∂x = g * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logabsdet_pullback end @@ -139,7 +139,7 @@ function rrule(::typeof(tr), x) # This should really be a FillArray # see https://github.com/JuliaDiff/ChainRules.jl/issues/46 function tr_pullback(ΔΩ) - return (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1)))) + return (NoTangent(), Diagonal(fill(ΔΩ, size(x, 1)))) end return tr(x), tr_pullback end @@ -202,7 +202,7 @@ function rrule( y = pinv(x, tol) function pinv_pullback(Δy) ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end return y, pinv_pullback end @@ -216,7 +216,7 @@ function rrule( function pinv_pullback(Δy) ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end return y, pinv_pullback end @@ -235,7 +235,7 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} ∂A = add!!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A) ∂A = add!!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y' end - return (NO_FIELDS, ∂A) + return (NoTangent(), ∂A) end return Y, pinv_pullback end @@ -280,7 +280,7 @@ function rrule( trans = T <: Complex ? 'C' : 'T' ∂D, scale2 = LAPACK.trsyl!(trans, trans, RA, RB, ∂Y) ∂Z = rmul!(QA * (∂D * QB'), -inv(scale2)) - return NO_FIELDS, @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) end return Ω, sylvester_pullback end @@ -318,7 +318,7 @@ function rrule( ∂Y = Q' * (∂Ω * Q) ∂D, scale2 = LAPACK.trsyl!(T <: Complex ? 'C' : 'T', 'N', R, R, ∂Y) ∂Z = rmul!(Q * (∂D * Q'), -inv(scale2)) - return NO_FIELDS, @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) end return Ω, lyap_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 5f00320a6..6553f1f30 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -77,7 +77,7 @@ function rrule( F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Composite) Δfactors = ΔF.factors - Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, DoesNotExist()) + Δfactors isa AbstractZero && return (NoTangent(), Δfactors, DoesNotExist()) factors = F.factors ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors ∂A = similar(factors) @@ -127,7 +127,7 @@ function rrule( if pivot === Val(true) ∂A = ∂A[invperm(F.p), :] end - return NO_FIELDS, ∂A, DoesNotExist() + return NoTangent(), ∂A, DoesNotExist() end return F, lu_pullback end @@ -151,10 +151,10 @@ function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:Stri elseif x === :factors Matrix(ΔY) else - return (NO_FIELDS, DoesNotExist(), DoesNotExist()) + return (NoTangent(), DoesNotExist(), DoesNotExist()) end ∂F = Composite{TF}(; factors=∂factors) - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_LU_pullback end @@ -194,7 +194,7 @@ function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix}) triu!(rdiv!(∂factors, U')) ∂factors .+= ∂L ∂F = Composite{typeof(F)}(; factors=∂factors) - return NO_FIELDS, ∂F + return NoTangent(), ∂F end return inv(F), inv_LU_pullback end @@ -208,7 +208,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) function svd_pullback(Ȳ::Composite) # `getproperty` on `Composite`s ensures we have no thunks. ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') - return (NO_FIELDS, ∂X) + return (NoTangent(), ∂X) end return F, svd_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :Vt C(Vt=Ȳ,) end - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_svd_pullback end @@ -301,7 +301,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ function eigen_pullback(ΔF::Composite) λ, V = F.values, F.vectors Δλ, ΔV = ΔF.values, ΔF.vectors - ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + ΔV isa AbstractZero && Δλ isa AbstractZero && return (NoTangent(), Δλ + ΔV) if eltype(λ) <: Real && ishermitian(A) hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) @@ -319,9 +319,9 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ ∂K[diagind(∂K)] .= Δλ ∂A = mul!(∂K, V' \ ∂K, V') end - return NO_FIELDS, T <: Real ? real(∂A) : ∂A + return NoTangent(), T <: Real ? real(∂A) : ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -409,7 +409,7 @@ function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Unio function eigvals_pullback(Δλ) ∂F = Composite{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -432,7 +432,7 @@ end function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Composite) - return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() + return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() end return C, cholesky_pullback end @@ -441,7 +441,7 @@ function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Boo C = cholesky(A, Val(false); check=check) function cholesky_pullback(ΔC::Composite) Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NO_FIELDS, Ā, DoesNotExist() + return NoTangent(), Ā, DoesNotExist() end return C, cholesky_pullback end @@ -459,7 +459,7 @@ function rrule( function cholesky_pullback(ΔC::Composite) Ā, U = _cholesky_pullback_shared_code(C, ΔC) Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) - return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist() + return NoTangent(), _symhermtype(A)(Ā), DoesNotExist() end return C, cholesky_pullback end @@ -476,7 +476,7 @@ function rrule( Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist()) + return (NoTangent(), UpperTriangular(Ā), DoesNotExist()) end return C, cholesky_pullback end @@ -507,7 +507,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C(U=UpperTriangular(Ȳ'),) end end - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 28402e5b5..f28c0a527 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -122,7 +122,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates) - return NO_FIELDS, Matrix(∂hermA) + return NoTangent(), Matrix(∂hermA) end return Matrix(hermX), exp_pullback_hermitian else @@ -133,7 +133,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) # the default _matfun_frechet_adjoint! ∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')' ∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return X, exp_pullback end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 41792ad4d..e6c85bf77 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -52,9 +52,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) end ) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) + norm_pullback_p(::Zero) = (NoTangent(), Zero(), Zero()) return y, norm_pullback_p end @@ -74,9 +74,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) _norm2_back!(dx, x, y, Δy) end ) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) + norm_pullback_2(::Zero) = (NoTangent(), Zero()) return y, norm_pullback_2 end @@ -100,9 +100,9 @@ function rrule(::typeof(norm), x::Number, p::Real) signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) end - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end - norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + norm_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, norm_pullback end @@ -115,9 +115,9 @@ function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normp_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + normp_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, normp_pullback end @@ -158,15 +158,15 @@ end function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) - normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) + normMinusInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normMinusInf_pullback(::Zero) = (NoTangent(), Zero()) return y, normMinusInf_pullback end function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) - normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normInf_pullback(::Zero) = (NO_FIELDS, Zero()) + normInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normInf_pullback(::Zero) = (NoTangent(), Zero()) return y, normInf_pullback end @@ -189,11 +189,11 @@ end function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) - norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm1_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm1_back(x, y, Δy)), dx -> _norm1_back!(dx, x, y, Δy), )) - norm1_pullback(::Zero) = (NO_FIELDS, Zero()) + norm1_pullback(::Zero) = (NoTangent(), Zero()) return y, norm1_pullback end @@ -221,11 +221,11 @@ end function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) - norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm2_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm2_back(x, y, Δy)), dx -> _norm2_back!(dx, x, y, Δy), )) - norm2_pullback(::Zero) = (NO_FIELDS, Zero()) + norm2_pullback(::Zero) = (NoTangent(), Zero()) return y, norm2_pullback end @@ -261,9 +261,9 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) ∂nrm = -dot(y, Δy) * invnrm (_, ∂xnorm, ∂p) = inner_pullback(∂nrm) ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + normalize_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, normalize_pullback end @@ -274,8 +274,8 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) LinearAlgebra.__normalize!(y, nrm) function normalize_pullback(Δy) ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - normalize_pullback(::Zero) = (NO_FIELDS, Zero()) + normalize_pullback(::Zero) = (NoTangent(), Zero()) return y, normalize_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 04f359acb..89096bece 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -10,7 +10,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr function slash_pullback(Ȳ) ∂A = @thunk Ȳ / B' ∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B')) - return (NO_FIELDS, ∂A, ∂B) + return (NoTangent(), ∂A, ∂B) end return Y, slash_pullback end @@ -20,7 +20,7 @@ function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMa function backslash_pullback(Ȳ) ∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y') ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback end @@ -31,42 +31,42 @@ end function rrule(::Type{<:Diagonal}, d::AbstractVector) function Diagonal_pullback(ȳ::AbstractMatrix) - return (NO_FIELDS, diag(ȳ)) + return (NoTangent(), diag(ȳ)) end function Diagonal_pullback(ȳ::Composite) # TODO: Assert about the primal type in the Composite, It should be Diagonal # infact it should be exactly the type of `Diagonal(d)` # but right now Zygote loses primal type information so we can't use it. # See https://github.com/FluxML/Zygote.jl/issues/603 - return (NO_FIELDS, ȳ.diag) + return (NoTangent(), ȳ.diag) end return Diagonal(d), Diagonal_pullback end function rrule(::typeof(diag), A::AbstractMatrix) function diag_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ)) + return (NoTangent(), Diagonal(ȳ)) end return diag(A), diag_pullback end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, diagm(size(A)..., k => ȳ), DoesNotExist()) + return (NoTangent(), diagm(size(A)..., k => ȳ), DoesNotExist()) end return diag(A, k), diag_pullback end function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, DoesNotExist(), DoesNotExist(), _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), DoesNotExist(), DoesNotExist(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(m, n, kv...), diagm_pullback end end function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(kv...), diagm_pullback end @@ -79,7 +79,7 @@ end function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) + return (NoTangent(), @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) end return D * V, times_pullback end @@ -89,26 +89,26 @@ end ##### function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) - Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + Adjoint_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) - Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + Adjoint_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + adjoint_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) - adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + adjoint_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -117,26 +117,26 @@ end ##### function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) + Transpose_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) - Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) + Transpose_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) + transpose_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector{<:Number}) - transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) + transpose_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) return transpose(A), transpose_pullback end @@ -146,40 +146,40 @@ end function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) function UpperTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return UpperTriangular(A), UpperTriangular_pullback end function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) function LowerTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return LowerTriangular(A), LowerTriangular_pullback end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ, k), DoesNotExist()) + return (NoTangent(), triu(ȳ, k), DoesNotExist()) end return triu(A, k), triu_pullback end function rrule(::typeof(triu), A::AbstractMatrix) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ)) + return (NoTangent(), triu(ȳ)) end return triu(A), triu_pullback end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ, k), DoesNotExist()) + return (NoTangent(), tril(ȳ, k), DoesNotExist()) end return tril(A, k), tril_pullback end function rrule(::typeof(tril), A::AbstractMatrix) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ)) + return (NoTangent(), tril(ȳ)) end return tril(A), tril_pullback end @@ -191,7 +191,7 @@ function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) y = det(X) s = conj!(y ./ _diag_view(X)) function det_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, det_pullback end @@ -200,7 +200,7 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) y = logdet(X) s = conj!(one(eltype(X)) ./ _diag_view(X)) function logdet_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, logdet_pullback end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index d4e6d7d68..7178c337b 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) @inline function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist()) + return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist()) end return Ω, HermOrSym_pullback end @@ -27,7 +27,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} uplo = A.uplo ∂A = T∂A(_symherm_back(typeof(A), ΔΩ, Symbol(uplo)), uplo) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return TM(A), Matrix_pullback end @@ -133,9 +133,9 @@ function rrule( Δλ, ΔU = ΔF.values, ΔF.vectors ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU) ∂A = eigen_rev!(A, λ, U, Δλ, ΔU) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -226,7 +226,7 @@ function rrule( function eigvals_pullback(Δλ) ∂F = Composite{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -251,9 +251,9 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla ∂U = ΔF.U .+ (ΔF.V .+ ΔF.Vt') .* c' end ∂A = eigen_rev!(A, λ, U, ∂λ, ∂U) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + svd_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, svd_pullback end @@ -286,9 +286,9 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS invpermute!(∂λ, p) ∂λ .*= sign.(λ) _, ∂A = back(∂λ) - return NO_FIELDS, unthunk(∂A) + return NoTangent(), unthunk(∂A) end - svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS) + svdvals_pullback(ΔS::AbstractZero) = (NoTangent(), ΔS) return S, svdvals_pullback end @@ -324,7 +324,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, $(Symbol(func, :_pullback)) end @@ -351,7 +351,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo) Y = (sinA, cosA) function sincos_pullback((ΔsinA, ΔcosA)::Composite) - ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA + ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NoTangent(), ΔsinA + ΔcosA if eltype(A) <: Real ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end @@ -369,7 +369,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U')) end ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, sincos_pullback end diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index 8da0b1f22..29997c2a1 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -2,7 +2,7 @@ frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), Zero function rrule(::Type{MersenneTwister}, args...) function MersenneTwister_pullback(ΔΩ) - return (NO_FIELDS, map(_ -> Zero(), args)...) + return (NoTangent(), map(_ -> Zero(), args)...) end return MersenneTwister(args...), MersenneTwister_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 6b15b10fa..6f3f72037 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -14,7 +14,7 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) function mean_pullback(ȳ) _, ∂sum_x = sum_pullback(ȳ) ∂x = extern(∂sum_x) / n - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return y_sum / n, mean_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a3ab3ca0d..6abd7dfc2 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -165,12 +165,12 @@ end @testset "type" begin - @test frule((NO_FIELDS, DoesNotExist(), DoesNotExist()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing + @test frule((NoTangent(), DoesNotExist(), DoesNotExist()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing @test rrule(typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing end @testset "Logging" begin - @test frule((NO_FIELDS, DoesNotExist(), DoesNotExist()), Base.depwarn, "message", :f) !== nothing + @test frule((NoTangent(), DoesNotExist(), DoesNotExist()), Base.depwarn, "message", :f) !== nothing @test rrule(Base.depwarn, "message", :f) !== nothing end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index f4b88fead..f398de7ea 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -187,12 +187,12 @@ end @test X̄_vectors_ad ≈ j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1] rtol=1e-4 F̄ = CT(values = λ̄, vectors = V̄) s̄elf, X̄_ad = @inferred back(F̄) - @test s̄elf === NO_FIELDS + @test s̄elf === NoTangent() X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd rtol=1e-4 - @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) === (NoTangent(), Zero()) F̄zero = CT(values = Zero(), vectors = Zero()) - @test @inferred(back(F̄zero)) === (NO_FIELDS, Zero()) + @test @inferred(back(F̄zero)) === (NoTangent(), Zero()) T <: Real && @testset "cotangent is real when input is" begin X = randn(T, n, n) @@ -286,7 +286,7 @@ end end ∂self, ∂A = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] @test ∂A ≈ ∂A_fd @@ -318,7 +318,7 @@ end λ, back = rrule(eigvals, randn(T, n, n)) _, X̄ = @inferred back(rand_tangent(λ)) - @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) === (NoTangent(), Zero()) T <: Real && @testset "cotangent is real when input is" begin @test eltype(X̄) <: Real @@ -343,10 +343,10 @@ end λ_ad, back = rrule(eigvals, A) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂A = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end end @@ -375,7 +375,7 @@ end Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS + @test dself === NoTangent() @test dp === DoesNotExist() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index bb2d20bfd..88c2a80a8 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -159,7 +159,7 @@ test_rrule(norm, randn(T), p) _, back = rrule(norm, randn(T), p) - @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) + @test back(Zero()) == (NoTangent(), Zero(), Zero()) end @testset "p = 0" begin p = 0.0 @@ -172,8 +172,8 @@ @test iszero(ẏ) y_rev, back = rrule(norm, x, p) @test y_rev == y - @test back(ȳ) == (NO_FIELDS, zero(x), Zero()) - @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) + @test back(ȳ) == (NoTangent(), zero(x), Zero()) + @test back(Zero()) == (NoTangent(), Zero(), Zero()) end end end @@ -185,12 +185,12 @@ end @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) - @test rrule(normalize, x)[2](Zero()) === (NO_FIELDS, Zero()) + @test rrule(normalize, x)[2](Zero()) === (NoTangent(), Zero()) end @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) - @test rrule(normalize, x, p)[2](Zero()) === (NO_FIELDS, Zero(), Zero()) + @test rrule(normalize, x, p)[2](Zero()) === (NoTangent(), Zero(), Zero()) end end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 4bf3fbdda..fee31007b 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -23,9 +23,9 @@ # TODO: replace this with a `rrule_test` once we have that working # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 res, pb = rrule(Diagonal, [1, 4]) - @test pb(10*res) == (NO_FIELDS, [10, 40]) + @test pb(10*res) == (NoTangent(), [10, 40]) comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal - @test pb(comp) == (NO_FIELDS, [10, 40]) + @test pb(comp) == (NoTangent(), [10, 40]) end @testset "dot(x, ::Diagonal, y)" begin N = 4 @@ -57,7 +57,7 @@ y, back = rrule(diagm, ps...) @test y == diagm(ps...) ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) ∂px = unthunk(∂px) @@ -76,7 +76,7 @@ y, back = rrule(diagm, M, N, ps...) @test y == diagm(M, N, ps...) ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂M === DoesNotExist() @test ∂N === DoesNotExist() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 30943bb96..a873fe611 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -133,7 +133,7 @@ end ∂self, ∂symA = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -141,8 +141,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) - @test @inferred(back(CT())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) + @test @inferred(back(CT())) == (NoTangent(), Zero()) end # when value used to determine phase convention is low, the usual derivatives @@ -207,7 +207,7 @@ λ_ad, back = @inferred rrule(eigvals, symA) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂symA = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -215,7 +215,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end end @@ -264,7 +264,7 @@ VERSION ≥ v"1.6.0-DEV.1686" && @inferred back(∂F) ∂self, ∂symA = back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -272,8 +272,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) - @test @inferred(back(CT())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) + @test @inferred(back(CT())) == (NoTangent(), Zero()) end @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), @@ -287,7 +287,7 @@ S_ad, back = @inferred rrule(svdvals, symA) @test S_ad ≈ S # inexact because rrule uses svd not svdvals ∂self, ∂symA = @inferred back(ΔS) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -295,7 +295,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end @@ -412,7 +412,7 @@ @test typeof(Y_ad) === typeof(Y) hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A.uplo == A.uplo # check pullback composes correctly @@ -496,11 +496,11 @@ ΔY = Composite{typeof(Y)}(ΔsinA, ΔcosA) ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2] ΔY2 = Composite{typeof(Y)}(Zero(), Zero()) - @test @inferred(back(ΔY2)) === (NO_FIELDS, Zero()) + @test @inferred(back(ΔY2)) === (NoTangent(), Zero()) ΔY3 = Composite{typeof(Y)}(ΔsinA, Zero()) @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA) From adfb5affd318df136e3fbb8ef861c6858ab3d2e7 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 14:31:58 +0100 Subject: [PATCH 2/9] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c12e8c0eb..1409d13a4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.65" +version = "0.7.68" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From f878883c47ba3dffd10626904f472e25a9ec8ef1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:39:37 +0100 Subject: [PATCH 3/9] replace NO_FIELDS by NoTangent() --- src/rulesets/Base/array.jl | 12 +++--- src/rulesets/Base/arraymath.jl | 6 +-- src/rulesets/Base/fastmath_able.jl | 2 +- src/rulesets/Base/indexing.jl | 2 +- src/rulesets/Base/mapreduce.jl | 4 +- src/rulesets/Base/sort.jl | 2 +- src/rulesets/LinearAlgebra/blas.jl | 14 +++---- src/rulesets/LinearAlgebra/dense.jl | 8 ++-- src/rulesets/LinearAlgebra/factorization.jl | 22 +++++------ src/rulesets/LinearAlgebra/norm.jl | 26 ++++++------- src/rulesets/LinearAlgebra/structured.jl | 40 ++++++++++---------- src/rulesets/LinearAlgebra/symmetric.jl | 8 ++-- src/rulesets/Random/random.jl | 2 +- test/rulesets/Base/base.jl | 4 +- test/rulesets/LinearAlgebra/factorization.jl | 10 ++--- test/rulesets/LinearAlgebra/norm.jl | 10 ++--- test/rulesets/LinearAlgebra/structured.jl | 6 +-- test/rulesets/LinearAlgebra/symmetric.jl | 14 +++---- test/rulesets/Random/random.jl | 2 +- 19 files changed, 97 insertions(+), 97 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 8a1a88f63..b1e1207e9 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -5,7 +5,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) A_dims = size(A) function reshape_pullback(Ȳ) - return (NO_FIELDS, reshape(Ȳ, A_dims), NoTangent()) + return (NoTangent(), reshape(Ȳ, A_dims), NoTangent()) end return reshape(A, dims), reshape_pullback end @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) ∂dims = broadcast(_ -> NoTangent(), dims) - return (NO_FIELDS, ∂A, ∂dims...) + return (NoTangent(), ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[:, pre:post] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(hcat, As), reduce_hcat_pullback end @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[pre:post, :] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(vcat, As), reduce_vcat_pullback end @@ -92,14 +92,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), NoTangent()) + return (NoTangent(), sum(Ȳ), NoTangent()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) + return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index d5e13d8eb..5b9baf716 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -127,7 +127,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, matmul..., addon) + (NoTangent(), matmul..., addon) end return muladd(A, B, z), muladd_pullback_1 end @@ -148,7 +148,7 @@ function rrule( @thunk(ut' .* dy), dv -> dv .+= ut' .* dy ) - (NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) + (NoTangent(), ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) end return muladd(ut, v, z), muladd_pullback_2 end @@ -175,7 +175,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, proj..., addon) + (NoTangent(), proj..., addon) end return muladd(u, vt, z), muladd_pullback_3 end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 2c5fa7d4c..2ddf1b4b1 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -115,7 +115,7 @@ let function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) - return (NO_FIELDS, ZeroTangent()) + return (NoTangent(), ZeroTangent()) end function angle_pullback(ΔΩ) Δu, Δv = reim(ΔΩ) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index bc945f8f6..5f8610ff5 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -21,7 +21,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) getindex_add! ) īnds = broadcast(_ -> NoTangent(), inds) - return (NO_FIELDS, x̄, īnds...) + return (NoTangent(), x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index d44e98608..33d9fa563 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -51,7 +51,7 @@ function rrule( @thunk(2 .* real.(ȳ) .* x), dx -> dx .+= 2 .* real.(ȳ) .* x ) - return (NO_FIELDS, NoTangent(), x_thunk) + return (NoTangent(), NoTangent(), x_thunk) end return y, sum_abs2_pullback end @@ -85,7 +85,7 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ dx .+= conj.(y ./ x) .* dy end ) - return (NO_FIELDS, x_thunk) + return (NoTangent(), x_thunk) end return y, prod_pullback end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 8372995bd..c803dedb0 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -10,7 +10,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!) - return NO_FIELDS, Δxs, NoTangent() + return NoTangent(), Δxs, NoTangent() end return ys, partialsort_pullback diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index d3eda6738..d40e39701 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -25,7 +25,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) end return Ω, blas_dot_pullback end @@ -55,12 +55,12 @@ end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - nrm2_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + nrm2_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function nrm2_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) ∂X = scal!(n, s, blascopy!(n, X, incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, nrm2_pullback end @@ -85,12 +85,12 @@ end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - asum_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + asum_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, asum_pullback end @@ -135,7 +135,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄) ) end - return (NO_FIELDS, NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) + return (NoTangent(), NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) end return y, gemv_pullback end @@ -249,7 +249,7 @@ function rrule( ) end end - return (NO_FIELDS, NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) + return (NoTangent(), NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 4bb4b73bb..30769eca3 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -26,7 +26,7 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N dy = @thunk ΔΩ .* (adjoint(A) * x) return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -38,7 +38,7 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} dy = @thunk ΔΩ .* conj.(A.diag) .* x return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -202,7 +202,7 @@ function rrule( y = pinv(x, tol) function pinv_pullback(Δy) ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end @@ -216,7 +216,7 @@ function rrule( function pinv_pullback(Δy) ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index abbd975ee..109b16f80 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -77,7 +77,7 @@ function rrule( F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Tangent) Δfactors = ΔF.factors - Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, NoTangent()) + Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent()) factors = F.factors ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors ∂A = similar(factors) @@ -127,7 +127,7 @@ function rrule( if pivot === Val(true) ∂A = ∂A[invperm(F.p), :] end - return NO_FIELDS, ∂A, NoTangent() + return NoTangent(), ∂A, NoTangent() end return F, lu_pullback end @@ -151,10 +151,10 @@ function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:Stri elseif x === :factors Matrix(ΔY) else - return (NO_FIELDS, NoTangent(), NoTangent()) + return (NoTangent(), NoTangent(), NoTangent()) end ∂F = Tangent{TF}(; factors=∂factors) - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_LU_pullback end @@ -194,7 +194,7 @@ function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix}) triu!(rdiv!(∂factors, U')) ∂factors .+= ∂L ∂F = Tangent{typeof(F)}(; factors=∂factors) - return NO_FIELDS, ∂F + return NoTangent(), ∂F end return inv(F), inv_LU_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :Vt C(Vt=Ȳ,) end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_svd_pullback end @@ -432,7 +432,7 @@ end function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Tangent) - return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() + return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() end return C, cholesky_pullback end @@ -441,7 +441,7 @@ function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Boo C = cholesky(A, Val(false); check=check) function cholesky_pullback(ΔC::Tangent) Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NO_FIELDS, Ā, NoTangent() + return NoTangent(), Ā, NoTangent() end return C, cholesky_pullback end @@ -459,7 +459,7 @@ function rrule( function cholesky_pullback(ΔC::Tangent) Ā, U = _cholesky_pullback_shared_code(C, ΔC) Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) - return NO_FIELDS, _symhermtype(A)(Ā), NoTangent() + return NoTangent(), _symhermtype(A)(Ā), NoTangent() end return C, cholesky_pullback end @@ -476,7 +476,7 @@ function rrule( Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - return (NO_FIELDS, UpperTriangular(Ā), NoTangent()) + return (NoTangent(), UpperTriangular(Ā), NoTangent()) end return C, cholesky_pullback end @@ -507,7 +507,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C(U=UpperTriangular(Ȳ'),) end end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 8a120c28f..4c90cbfe1 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -54,7 +54,7 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NoTangent(), ∂x, ∂p) end - norm_pullback_p(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback_p(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback_p end @@ -76,7 +76,7 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) ) return (NoTangent(), ∂x) end - norm_pullback_2(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm_pullback_2(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm_pullback_2 end @@ -100,9 +100,9 @@ function rrule(::typeof(norm), x::Number, p::Real) signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) end - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end - norm_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback end @@ -117,7 +117,7 @@ function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NoTangent(), ∂x, ∂p) end - normp_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normp_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normp_pullback end @@ -158,15 +158,15 @@ end function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) - normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normMinusInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normMinusInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normMinusInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normMinusInf_pullback end function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) - normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normInf_pullback end @@ -193,7 +193,7 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) @thunk(_norm1_back(x, y, Δy)), dx -> _norm1_back!(dx, x, y, Δy), )) - norm1_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm1_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm1_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) @thunk(_norm2_back(x, y, Δy)), dx -> _norm2_back!(dx, x, y, Δy), )) - norm2_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm2_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm2_pullback end @@ -263,7 +263,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm return (NoTangent(), ∂x, ∂p) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normalize_pullback end @@ -276,6 +276,6 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) return (NoTangent(), ∂x) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normalize_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 7687baaa5..5203f46c1 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -52,14 +52,14 @@ end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, diagm(size(A)..., k => ȳ), NoTangent()) + return (NoTangent(), diagm(size(A)..., k => ȳ), NoTangent()) end return diag(A, k), diag_pullback end function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(m, n, kv...), diagm_pullback end @@ -89,26 +89,26 @@ end ##### function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -117,26 +117,26 @@ end ##### function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) + transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) + transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) return transpose(A), transpose_pullback end @@ -160,7 +160,7 @@ end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ, k), NoTangent()) + return (NoTangent(), triu(ȳ, k), NoTangent()) end return triu(A, k), triu_pullback end @@ -173,7 +173,7 @@ end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ, k), NoTangent()) + return (NoTangent(), tril(ȳ, k), NoTangent()) end return tril(A, k), tril_pullback end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 2ecd73a79..6576a1820 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) @inline function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) + return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) end return Ω, HermOrSym_pullback end @@ -324,7 +324,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, $(Symbol(func, :_pullback)) end @@ -351,7 +351,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo) Y = (sinA, cosA) function sincos_pullback((ΔsinA, ΔcosA)::Tangent) - ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA + ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NoTangent(), ΔsinA + ΔcosA if eltype(A) <: Real ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end @@ -369,7 +369,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U')) end ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, sincos_pullback end diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index c68741987..20fa10971 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -2,7 +2,7 @@ frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), Zero function rrule(::Type{MersenneTwister}, args...) function MersenneTwister_pullback(ΔΩ) - return (NO_FIELDS, map(_ -> ZeroTangent(), args)...) + return (NoTangent(), map(_ -> ZeroTangent(), args)...) end return MersenneTwister(args...), MersenneTwister_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 74a634238..91203fb86 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -165,12 +165,12 @@ end @testset "type" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing @test rrule(typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing end @testset "Logging" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing @test rrule(Base.depwarn, "message", :f) !== nothing end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 72a47cb77..397f1622f 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -190,9 +190,9 @@ end @test s̄elf === NoTangent() X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd rtol=1e-4 - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) F̄zero = CT(values = ZeroTangent(), vectors = ZeroTangent()) - @test @inferred(back(F̄zero)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(F̄zero)) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin X = randn(T, n, n) @@ -318,7 +318,7 @@ end λ, back = rrule(eigvals, randn(T, n, n)) _, X̄ = @inferred back(rand_tangent(λ)) - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin @test eltype(X̄) <: Real @@ -346,7 +346,7 @@ end @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -375,7 +375,7 @@ end Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS + @test dself === NoTangent() @test dp === NoTangent() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index ebd7b0908..4420b34c7 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -159,7 +159,7 @@ test_rrule(norm, randn(T), p) _, back = rrule(norm, randn(T), p) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end @testset "p = 0" begin p = 0.0 @@ -172,8 +172,8 @@ @test iszero(ẏ) y_rev, back = rrule(norm, x, p) @test y_rev == y - @test back(ȳ) == (NO_FIELDS, zero(x), ZeroTangent()) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ȳ) == (NoTangent(), zero(x), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end end end @@ -185,12 +185,12 @@ end @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) - @test rrule(normalize, x)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent()) + @test rrule(normalize, x)[2](ZeroTangent()) === (NoTangent(), ZeroTangent()) end @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) - @test rrule(normalize, x, p)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test rrule(normalize, x, p)[2](ZeroTangent()) === (NoTangent(), ZeroTangent(), ZeroTangent()) end end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 7d574db09..fb4335975 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -23,9 +23,9 @@ # TODO: replace this with a `rrule_test` once we have that working # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 res, pb = rrule(Diagonal, [1, 4]) - @test pb(10*res) == (NO_FIELDS, [10, 40]) + @test pb(10*res) == (NoTangent(), [10, 40]) comp = Tangent{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal - @test pb(comp) == (NO_FIELDS, [10, 40]) + @test pb(comp) == (NoTangent(), [10, 40]) end @testset "dot(x, ::Diagonal, y)" begin N = 4 @@ -76,7 +76,7 @@ y, back = rrule(diagm, M, N, ps...) @test y == diagm(M, N, ps...) ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂M === NoTangent() @test ∂N === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index b748eed02..27e13c3ee 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -141,8 +141,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end # when value used to determine phase convention is low, the usual derivatives @@ -215,7 +215,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -272,8 +272,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), @@ -295,7 +295,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end @@ -500,7 +500,7 @@ @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2] ΔY2 = Tangent{typeof(Y)}(ZeroTangent(), ZeroTangent()) - @test @inferred(back(ΔY2)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ΔY2)) === (NoTangent(), ZeroTangent()) ΔY3 = Tangent{typeof(Y)}(ΔsinA, ZeroTangent()) @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA) diff --git a/test/rulesets/Random/random.jl b/test/rulesets/Random/random.jl index 66aee3631..bde25c7b3 100644 --- a/test/rulesets/Random/random.jl +++ b/test/rulesets/Random/random.jl @@ -14,7 +14,7 @@ Random.rand(d::NormalDistribution) = d.μ + d.σ*randn() rng, pb = rrule(MersenneTwister) @test rng isa MersenneTwister - @test first(pb(10)) isa typeof(NO_FIELDS) + @test first(pb(10)) isa typeof(NoTangent()) end @testset "unary" begin rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123) From 695dc939d818c2c1ee8f915ba439229d6208b9cb Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:40:51 +0100 Subject: [PATCH 4/9] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 85bf06c2a..2f6e5b554 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.70" +version = "0.7.71" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 23a8ee9144f922d9799ff07d3c6b549e8352106c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:44:03 +0100 Subject: [PATCH 5/9] change compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f6e5b554..b62d7981d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.10" ChainRulesTestUtils = "0.6.8" Compat = "3.30" FiniteDifferences = "0.12.8" From 40ed6d2e24eac19ac4ecc9f9f8107840b0a18b58 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 13:41:02 +0100 Subject: [PATCH 6/9] breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b62d7981d..9471c5559 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.71" +version = "0.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 0926b3c474204a78c6a9ea69d80fa50dcde21847 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 16:00:53 +0100 Subject: [PATCH 7/9] bump ChainRulesTestUtils compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9471c5559..b5200218b 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.10" -ChainRulesTestUtils = "0.6.8" +ChainRulesTestUtils = "0.7" Compat = "3.30" FiniteDifferences = "0.12.8" Reexport = "0.2, 1" From 145e0d23cdd471c77a76db745e9a616c5dec8542 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 2 Jun 2021 12:22:41 +0100 Subject: [PATCH 8/9] remove all old differential names --- Project.toml | 1 + src/ChainRules.jl | 1 - src/rulesets/Base/arraymath.jl | 6 +++--- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index b5200218b..fe8528dae 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 94289d502..203031ce1 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -2,7 +2,6 @@ module ChainRules using Reexport @reexport using ChainRulesCore -export Zero, DoesNotExist, Composite, AbstractDifferential using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using Compat diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 5b9baf716..a0c14a4c0 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -118,7 +118,7 @@ function rrule( ) ) addon = if z isa Bool - DoesNotExist() + NoTangent() elseif z isa Number @thunk(sum(Ȳ)) else @@ -148,7 +148,7 @@ function rrule( @thunk(ut' .* dy), dv -> dv .+= ut' .* dy ) - (NoTangent(), ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) + (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy) end return muladd(ut, v, z), muladd_pullback_2 end @@ -166,7 +166,7 @@ function rrule( @thunk(vec(sum(u .* conj.(Ȳ), dims=1))'), ) addon = if z isa Bool - DoesNotExist() + NoTangent() elseif z isa Number @thunk(sum(Ȳ)) else From e2711f227f8339837cba9b1530c71187dccf17eb Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 2 Jun 2021 12:25:30 +0100 Subject: [PATCH 9/9] Update Project.toml Co-authored-by: Lyndon White --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index fe8528dae..b5200218b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "0.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"