From 5aad94a57206fa86f6228d11d3675d2232f4d4d8 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 14:30:54 +0100 Subject: [PATCH 1/7] replace NO_FIELDS by NoTangent() --- src/rulesets/Base/array.jl | 16 ++--- src/rulesets/Base/arraymath.jl | 20 +++--- src/rulesets/Base/base.jl | 16 ++--- src/rulesets/Base/evalpoly.jl | 2 +- src/rulesets/Base/fastmath_able.jl | 22 +++---- src/rulesets/Base/indexing.jl | 2 +- src/rulesets/Base/mapreduce.jl | 4 +- src/rulesets/Base/sort.jl | 4 +- src/rulesets/LinearAlgebra/blas.jl | 22 +++---- src/rulesets/LinearAlgebra/dense.jl | 30 ++++----- src/rulesets/LinearAlgebra/factorization.jl | 32 +++++----- src/rulesets/LinearAlgebra/matfun.jl | 4 +- src/rulesets/LinearAlgebra/norm.jl | 40 ++++++------ src/rulesets/LinearAlgebra/structured.jl | 66 ++++++++++---------- src/rulesets/LinearAlgebra/symmetric.jl | 24 +++---- src/rulesets/Random/random.jl | 2 +- src/rulesets/Statistics/statistics.jl | 2 +- test/rulesets/Base/base.jl | 4 +- test/rulesets/LinearAlgebra/factorization.jl | 16 ++--- test/rulesets/LinearAlgebra/norm.jl | 10 +-- test/rulesets/LinearAlgebra/structured.jl | 8 +-- test/rulesets/LinearAlgebra/symmetric.jl | 26 ++++---- 22 files changed, 186 insertions(+), 186 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 3b3996373..514090314 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -5,7 +5,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) A_dims = size(A) function reshape_pullback(Ȳ) - return (NO_FIELDS, reshape(Ȳ, A_dims), DoesNotExist()) + return (NoTangent(), reshape(Ȳ, A_dims), DoesNotExist()) end return reshape(A, dims), reshape_pullback end @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) ∂dims = broadcast(_ -> DoesNotExist(), dims) - return (NO_FIELDS, ∂A, ∂dims...) + return (NoTangent(), ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end @@ -28,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) function hcat_pullback(Ȳ) Xs = (A, Bs...) ntuple(length(Bs) + 2) do full_i - full_i == 1 && return NO_FIELDS + full_i == 1 && return NoTangent() i = full_i - 1 l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[:, pre:post] end - return (NO_FIELDS, DoesNotExist(), ∂As) + return (NoTangent(), DoesNotExist(), ∂As) end return reduce(hcat, As), reduce_hcat_pullback end @@ -68,7 +68,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) u = l + size(Bs[i], 1) copy(selectdim(Ȳ, 1, l+1:u)) end - return (NO_FIELDS, ∂A, ∂Bs...) + return (NoTangent(), ∂A, ∂Bs...) end return vcat(A, Bs...), vcat_pullback end @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[pre:post, :] end - return (NO_FIELDS, DoesNotExist(), ∂As) + return (NoTangent(), DoesNotExist(), ∂As) end return reduce(vcat, As), reduce_vcat_pullback end @@ -92,14 +92,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), DoesNotExist()) + return (NoTangent(), sum(Ȳ), DoesNotExist()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...) + return (NoTangent(), sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 023ccf7f4..3ce2e851c 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -10,7 +10,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) function inv_pullback(ΔΩ) - return NO_FIELDS, -Ω' * ΔΩ * Ω' + return NoTangent(), -Ω' * ΔΩ * Ω' end return Ω, inv_pullback end @@ -26,7 +26,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * B'), X̄ -> mul!(X̄, Ȳ, B', true, true) @@ -48,7 +48,7 @@ function rrule( function times_pullback(Ȳ) @assert size(B, 1) === 1 # otherwise primal would have failed. return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * vec(B')), X̄ -> mul!(X̄, Ȳ, vec(B'), true, true) @@ -67,7 +67,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), @thunk(dot(Ȳ, B)'), InplaceableThunk( @thunk(A' * Ȳ), @@ -83,7 +83,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(A' * Ȳ), X̄ -> mul!(X̄, conj(A), Ȳ, true, true) @@ -113,7 +113,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R ∂A = last(dA_pb(unthunk(dAᵀ))) ∂B = last(dA_pb(unthunk(dBᵀ))) - (NO_FIELDS, ∂A, ∂B) + (NoTangent(), ∂A, ∂B) end return C, slash_pullback end @@ -133,7 +133,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā end ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback @@ -146,7 +146,7 @@ end function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real) Y = A/b function slash_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) + return (NoTangent(), @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) end return Y, slash_pullback end @@ -154,7 +154,7 @@ end function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real}) Y = b\A function backslash_pullback(Ȳ) - return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) + return (NoTangent(), @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) end return Y, backslash_pullback end @@ -165,7 +165,7 @@ end function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) - return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) + return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) end return -x, negation_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 51da82a4d..0dac73572 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -10,7 +10,7 @@ frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') function rrule(::typeof(adjoint), z::Number) - adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ') + adjoint_pullback(ΔΩ) = (NoTangent(), ΔΩ') return (z', adjoint_pullback) end @@ -22,7 +22,7 @@ frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz)) function rrule(::typeof(real), z::Number) # add zero(z) to embed the real number in the same number type as z - real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z)) + real_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) + zero(z)) return (real(z), real_pullback) end @@ -33,7 +33,7 @@ end frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz)) function rrule(::typeof(imag), z::Complex) - imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) + imag_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) * im) return (imag(z), imag_pullback) end @@ -45,15 +45,15 @@ function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex end function rrule(::Type{T}, z::Complex) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), Complex(ΔΩ)) return (T(z), Complex_pullback) end function rrule(::Type{T}, x::Real) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ)) return (T(x), Complex_pullback) end function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ), imag(ΔΩ)) return (T(x, y), Complex_pullback) end @@ -70,7 +70,7 @@ end function rrule(::typeof(hypot), z::Complex) Ω = hypot(z) function hypot_pullback(ΔΩ) - return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) + return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) end return (Ω, hypot_pullback) end @@ -148,7 +148,7 @@ end function rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return (x, identity_pullback) end diff --git a/src/rulesets/Base/evalpoly.jl b/src/rulesets/Base/evalpoly.jl index d2e7d3073..60d5f6bae 100644 --- a/src/rulesets/Base/evalpoly.jl +++ b/src/rulesets/Base/evalpoly.jl @@ -15,7 +15,7 @@ if VERSION ≥ v"1.4" y, ys = _evalpoly_intermediates(x, p) function evalpoly_pullback(Δy) ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) - return NO_FIELDS, ∂x, ∂p + return NoTangent(), ∂x, ∂p end return y, evalpoly_pullback end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 3bb0a79f4..0a31e7b9f 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -11,7 +11,7 @@ let ## sin function rrule(::typeof(sin), x::Number) sinx, cosx = sincos(x) - sin_pullback(Δy) = (NO_FIELDS, cosx' * Δy) + sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end @@ -23,7 +23,7 @@ let ## cos function rrule(::typeof(cos), x::Number) sinx, cosx = sincos(x) - cos_pullback(Δy) = (NO_FIELDS, -sinx' * Δy) + cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end @@ -75,7 +75,7 @@ let Ω = abs(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return (NO_FIELDS, signx * real(ΔΩ)) + return (NoTangent(), signx * real(ΔΩ)) end return Ω, abs_pullback end @@ -88,7 +88,7 @@ let function rrule(::typeof(abs2), z::Union{Real, Complex}) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NO_FIELDS, 2real(ΔΩ)*z) + return (NoTangent(), 2real(ΔΩ)*z) end return abs2(z), abs2_pullback end @@ -99,7 +99,7 @@ let end function rrule(::typeof(conj), z::Union{Real, Complex}) function conj_pullback(ΔΩ) - return (NO_FIELDS, conj(ΔΩ)) + return (NoTangent(), conj(ΔΩ)) end return conj(z), conj_pullback end @@ -115,11 +115,11 @@ let function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) - return (NO_FIELDS, Zero()) + return (NoTangent(), Zero()) end function angle_pullback(ΔΩ) Δu, Δv = reim(ΔΩ) - return (NO_FIELDS, im*Δu/ifelse(iszero(x), one(x), x)) + return (NoTangent(), im*Δu/ifelse(iszero(x), one(x), x)) # `ifelse` is applied only to denominator to ensure type-stability. end return angle(x), angle_pullback @@ -130,7 +130,7 @@ let Δu, Δv = reim(ΔΩ) # `ifelse` is applied only to denominator to ensure type-stability. n = ifelse(iszero(z), one(real(z)), abs2(z)) - return (NO_FIELDS, (-y + im*x)*Δu/n) + return (NoTangent(), (-y + im*x)*Δu/n) end return angle(z), angle_pullback end @@ -155,7 +155,7 @@ let Ω = hypot(x, y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) - return (NO_FIELDS, c * x, c * y) + return (NoTangent(), c * x, c * y) end return (Ω, hypot_pullback) end @@ -198,7 +198,7 @@ let Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, sign_pullback end @@ -214,7 +214,7 @@ let function rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) - return (NO_FIELDS, ΔΩ * y', x' * ΔΩ) + return (NoTangent(), ΔΩ * y', x' * ΔΩ) end return x * y, times_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 78208434c..e99b30036 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -21,7 +21,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) getindex_add! ) īnds = broadcast(_ -> DoesNotExist(), inds) - return (NO_FIELDS, x̄, īnds...) + return (NoTangent(), x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 2f5b5003a..f4dc2f8a7 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -14,7 +14,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} @thunk(broadcast(last∘tuple, x, ȳ)), x -> x .+= ȳ ) - return (NO_FIELDS, x̄) + return (NoTangent(), x̄) end return y, sum_pullback end @@ -51,7 +51,7 @@ function rrule( @thunk(2 .* real.(ȳ) .* x), dx -> dx .+= 2 .* real.(ȳ) .* x ) - return (NO_FIELDS, DoesNotExist(), x_thunk) + return (NoTangent(), DoesNotExist(), x_thunk) end return y, sum_abs2_pullback end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 1ebf2a0a5..38917459b 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -10,7 +10,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!) - return NO_FIELDS, Δxs, DoesNotExist() + return NoTangent(), Δxs, DoesNotExist() end return ys, partialsort_pullback @@ -28,7 +28,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!) - return NO_FIELDS, Δxs + return NoTangent(), Δxs end return ys, sort_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 7fcfba18a..445e660a7 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -25,7 +25,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) end return Ω, blas_dot_pullback end @@ -48,19 +48,19 @@ end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) function nrm2_pullback(ΔΩ) - return NO_FIELDS, x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) + return NoTangent(), x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) end return Ω, nrm2_pullback end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - nrm2_pullback(::Zero) = (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist()) + nrm2_pullback(::Zero) = (NoTangent(), DoesNotExist(), Zero(), DoesNotExist()) function nrm2_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) ∂X = scal!(n, s, blascopy!(n, X, incx, _zeros(X), incx), incx) - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist()) end return Ω, nrm2_pullback end @@ -78,19 +78,19 @@ end function rrule(::typeof(BLAS.asum), x) function asum_pullback(ΔΩ) - return (NO_FIELDS, _signcomp.(x) .* real(ΔΩ)) + return (NoTangent(), _signcomp.(x) .* real(ΔΩ)) end return BLAS.asum(x), asum_pullback end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - asum_pullback(::Zero) = (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist()) + asum_pullback(::Zero) = (NoTangent(), DoesNotExist(), Zero(), DoesNotExist()) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) - return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) + return (NoTangent(), DoesNotExist(), ∂X, DoesNotExist()) end return Ω, asum_pullback end @@ -135,7 +135,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄) ) end - return (NO_FIELDS, DoesNotExist(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) + return (NoTangent(), DoesNotExist(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) end return y, gemv_pullback end @@ -146,7 +146,7 @@ function rrule( y, inner_pullback = rrule(gemv, tA, one(T), A, x) function gemv_pullback(Ȳ) (_, dtA, _, dA, dx) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dA, dx) + return (NoTangent(), dtA, dA, dx) end return y, gemv_pullback end @@ -249,7 +249,7 @@ function rrule( ) end end - return (NO_FIELDS, DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) + return (NoTangent(), DoesNotExist(), DoesNotExist(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end return C, gemm_pullback end @@ -260,7 +260,7 @@ function rrule( C, inner_pullback = rrule(gemm, tA, tB, one(T), A, B) function gemm_pullback(Ȳ) (_, dtA, dtB, _, dA, dB) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dtB, dA, dB) + return (NoTangent(), dtA, dtB, dA, dB) end return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 5069f4c61..a8b1e3051 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -8,7 +8,7 @@ end function rrule(::typeof(dot), x, y) function dot_pullback(ΔΩ) - return (NO_FIELDS, @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) + return (NoTangent(), @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) end return dot(x, y), dot_pullback end @@ -24,9 +24,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N dx = @thunk conj(ΔΩ) .* Ay dA = @thunk ΔΩ .* x .* adjoint(y) dy = @thunk ΔΩ .* (adjoint(A) * x) - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + dot_pullback(::Zero) = (NoTangent(), Zero(), Zero(), Zero()) return z, dot_pullback end @@ -36,9 +36,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements dy = @thunk ΔΩ .* conj.(A.diag) .* x - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + dot_pullback(::Zero) = (NoTangent(), Zero(), Zero(), Zero()) return z, dot_pullback end @@ -54,7 +54,7 @@ end function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) Ω = cross(a, b) function cross_pullback(ΔΩ) - return (NO_FIELDS, @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) + return (NoTangent(), @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) end return Ω, cross_pullback end @@ -76,7 +76,7 @@ function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ : Ω * ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, det_pullback end @@ -94,7 +94,7 @@ function rrule(::typeof(logdet), x::Union{Number, StridedMatrix{<:Number}}) Ω = logdet(x) function logdet_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ / x' : ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logdet_pullback end @@ -122,7 +122,7 @@ function rrule(::typeof(logabsdet), x::AbstractMatrix) imagf = f - real(f) g = real(Δy) + imagf ∂x = g * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logabsdet_pullback end @@ -139,7 +139,7 @@ function rrule(::typeof(tr), x) # This should really be a FillArray # see https://github.com/JuliaDiff/ChainRules.jl/issues/46 function tr_pullback(ΔΩ) - return (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1)))) + return (NoTangent(), Diagonal(fill(ΔΩ, size(x, 1)))) end return tr(x), tr_pullback end @@ -202,7 +202,7 @@ function rrule( y = pinv(x, tol) function pinv_pullback(Δy) ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end return y, pinv_pullback end @@ -216,7 +216,7 @@ function rrule( function pinv_pullback(Δy) ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end return y, pinv_pullback end @@ -235,7 +235,7 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} ∂A = add!!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A) ∂A = add!!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y' end - return (NO_FIELDS, ∂A) + return (NoTangent(), ∂A) end return Y, pinv_pullback end @@ -280,7 +280,7 @@ function rrule( trans = T <: Complex ? 'C' : 'T' ∂D, scale2 = LAPACK.trsyl!(trans, trans, RA, RB, ∂Y) ∂Z = rmul!(QA * (∂D * QB'), -inv(scale2)) - return NO_FIELDS, @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) end return Ω, sylvester_pullback end @@ -318,7 +318,7 @@ function rrule( ∂Y = Q' * (∂Ω * Q) ∂D, scale2 = LAPACK.trsyl!(T <: Complex ? 'C' : 'T', 'N', R, R, ∂Y) ∂Z = rmul!(Q * (∂D * Q'), -inv(scale2)) - return NO_FIELDS, @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) end return Ω, lyap_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 5f00320a6..6553f1f30 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -77,7 +77,7 @@ function rrule( F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Composite) Δfactors = ΔF.factors - Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, DoesNotExist()) + Δfactors isa AbstractZero && return (NoTangent(), Δfactors, DoesNotExist()) factors = F.factors ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors ∂A = similar(factors) @@ -127,7 +127,7 @@ function rrule( if pivot === Val(true) ∂A = ∂A[invperm(F.p), :] end - return NO_FIELDS, ∂A, DoesNotExist() + return NoTangent(), ∂A, DoesNotExist() end return F, lu_pullback end @@ -151,10 +151,10 @@ function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:Stri elseif x === :factors Matrix(ΔY) else - return (NO_FIELDS, DoesNotExist(), DoesNotExist()) + return (NoTangent(), DoesNotExist(), DoesNotExist()) end ∂F = Composite{TF}(; factors=∂factors) - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_LU_pullback end @@ -194,7 +194,7 @@ function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix}) triu!(rdiv!(∂factors, U')) ∂factors .+= ∂L ∂F = Composite{typeof(F)}(; factors=∂factors) - return NO_FIELDS, ∂F + return NoTangent(), ∂F end return inv(F), inv_LU_pullback end @@ -208,7 +208,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) function svd_pullback(Ȳ::Composite) # `getproperty` on `Composite`s ensures we have no thunks. ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') - return (NO_FIELDS, ∂X) + return (NoTangent(), ∂X) end return F, svd_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :Vt C(Vt=Ȳ,) end - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_svd_pullback end @@ -301,7 +301,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ function eigen_pullback(ΔF::Composite) λ, V = F.values, F.vectors Δλ, ΔV = ΔF.values, ΔF.vectors - ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + ΔV isa AbstractZero && Δλ isa AbstractZero && return (NoTangent(), Δλ + ΔV) if eltype(λ) <: Real && ishermitian(A) hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) @@ -319,9 +319,9 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ ∂K[diagind(∂K)] .= Δλ ∂A = mul!(∂K, V' \ ∂K, V') end - return NO_FIELDS, T <: Real ? real(∂A) : ∂A + return NoTangent(), T <: Real ? real(∂A) : ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -409,7 +409,7 @@ function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Unio function eigvals_pullback(Δλ) ∂F = Composite{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -432,7 +432,7 @@ end function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Composite) - return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() + return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() end return C, cholesky_pullback end @@ -441,7 +441,7 @@ function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Boo C = cholesky(A, Val(false); check=check) function cholesky_pullback(ΔC::Composite) Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NO_FIELDS, Ā, DoesNotExist() + return NoTangent(), Ā, DoesNotExist() end return C, cholesky_pullback end @@ -459,7 +459,7 @@ function rrule( function cholesky_pullback(ΔC::Composite) Ā, U = _cholesky_pullback_shared_code(C, ΔC) Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) - return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist() + return NoTangent(), _symhermtype(A)(Ā), DoesNotExist() end return C, cholesky_pullback end @@ -476,7 +476,7 @@ function rrule( Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist()) + return (NoTangent(), UpperTriangular(Ā), DoesNotExist()) end return C, cholesky_pullback end @@ -507,7 +507,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C(U=UpperTriangular(Ȳ'),) end end - return NO_FIELDS, ∂F, DoesNotExist() + return NoTangent(), ∂F, DoesNotExist() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 28402e5b5..f28c0a527 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -122,7 +122,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates) - return NO_FIELDS, Matrix(∂hermA) + return NoTangent(), Matrix(∂hermA) end return Matrix(hermX), exp_pullback_hermitian else @@ -133,7 +133,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) # the default _matfun_frechet_adjoint! ∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')' ∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return X, exp_pullback end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 41792ad4d..e6c85bf77 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -52,9 +52,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) end ) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) + norm_pullback_p(::Zero) = (NoTangent(), Zero(), Zero()) return y, norm_pullback_p end @@ -74,9 +74,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) _norm2_back!(dx, x, y, Δy) end ) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) + norm_pullback_2(::Zero) = (NoTangent(), Zero()) return y, norm_pullback_2 end @@ -100,9 +100,9 @@ function rrule(::typeof(norm), x::Number, p::Real) signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) end - return (NO_FIELDS, ∂x, Zero()) + return (NoTangent(), ∂x, Zero()) end - norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + norm_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, norm_pullback end @@ -115,9 +115,9 @@ function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normp_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + normp_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, normp_pullback end @@ -158,15 +158,15 @@ end function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) - normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) + normMinusInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normMinusInf_pullback(::Zero) = (NoTangent(), Zero()) return y, normMinusInf_pullback end function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) - normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normInf_pullback(::Zero) = (NO_FIELDS, Zero()) + normInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normInf_pullback(::Zero) = (NoTangent(), Zero()) return y, normInf_pullback end @@ -189,11 +189,11 @@ end function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) - norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm1_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm1_back(x, y, Δy)), dx -> _norm1_back!(dx, x, y, Δy), )) - norm1_pullback(::Zero) = (NO_FIELDS, Zero()) + norm1_pullback(::Zero) = (NoTangent(), Zero()) return y, norm1_pullback end @@ -221,11 +221,11 @@ end function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) - norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm2_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm2_back(x, y, Δy)), dx -> _norm2_back!(dx, x, y, Δy), )) - norm2_pullback(::Zero) = (NO_FIELDS, Zero()) + norm2_pullback(::Zero) = (NoTangent(), Zero()) return y, norm2_pullback end @@ -261,9 +261,9 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) ∂nrm = -dot(y, Δy) * invnrm (_, ∂xnorm, ∂p) = inner_pullback(∂nrm) ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) + normalize_pullback(::Zero) = (NoTangent(), Zero(), Zero()) return y, normalize_pullback end @@ -274,8 +274,8 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) LinearAlgebra.__normalize!(y, nrm) function normalize_pullback(Δy) ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - normalize_pullback(::Zero) = (NO_FIELDS, Zero()) + normalize_pullback(::Zero) = (NoTangent(), Zero()) return y, normalize_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 04f359acb..89096bece 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -10,7 +10,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr function slash_pullback(Ȳ) ∂A = @thunk Ȳ / B' ∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B')) - return (NO_FIELDS, ∂A, ∂B) + return (NoTangent(), ∂A, ∂B) end return Y, slash_pullback end @@ -20,7 +20,7 @@ function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMa function backslash_pullback(Ȳ) ∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y') ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback end @@ -31,42 +31,42 @@ end function rrule(::Type{<:Diagonal}, d::AbstractVector) function Diagonal_pullback(ȳ::AbstractMatrix) - return (NO_FIELDS, diag(ȳ)) + return (NoTangent(), diag(ȳ)) end function Diagonal_pullback(ȳ::Composite) # TODO: Assert about the primal type in the Composite, It should be Diagonal # infact it should be exactly the type of `Diagonal(d)` # but right now Zygote loses primal type information so we can't use it. # See https://github.com/FluxML/Zygote.jl/issues/603 - return (NO_FIELDS, ȳ.diag) + return (NoTangent(), ȳ.diag) end return Diagonal(d), Diagonal_pullback end function rrule(::typeof(diag), A::AbstractMatrix) function diag_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ)) + return (NoTangent(), Diagonal(ȳ)) end return diag(A), diag_pullback end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, diagm(size(A)..., k => ȳ), DoesNotExist()) + return (NoTangent(), diagm(size(A)..., k => ȳ), DoesNotExist()) end return diag(A, k), diag_pullback end function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, DoesNotExist(), DoesNotExist(), _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), DoesNotExist(), DoesNotExist(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(m, n, kv...), diagm_pullback end end function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(kv...), diagm_pullback end @@ -79,7 +79,7 @@ end function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) + return (NoTangent(), @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) end return D * V, times_pullback end @@ -89,26 +89,26 @@ end ##### function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) - Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + Adjoint_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) - Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + Adjoint_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + adjoint_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) - adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + adjoint_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -117,26 +117,26 @@ end ##### function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) + Transpose_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) - Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) + Transpose_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) - transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) + transpose_pullback(ȳ::Composite) = (NoTangent(), ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector{<:Number}) - transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) - transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) + transpose_pullback(ȳ::Composite) = (NoTangent(), vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) return transpose(A), transpose_pullback end @@ -146,40 +146,40 @@ end function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) function UpperTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return UpperTriangular(A), UpperTriangular_pullback end function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) function LowerTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return LowerTriangular(A), LowerTriangular_pullback end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ, k), DoesNotExist()) + return (NoTangent(), triu(ȳ, k), DoesNotExist()) end return triu(A, k), triu_pullback end function rrule(::typeof(triu), A::AbstractMatrix) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ)) + return (NoTangent(), triu(ȳ)) end return triu(A), triu_pullback end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ, k), DoesNotExist()) + return (NoTangent(), tril(ȳ, k), DoesNotExist()) end return tril(A, k), tril_pullback end function rrule(::typeof(tril), A::AbstractMatrix) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ)) + return (NoTangent(), tril(ȳ)) end return tril(A), tril_pullback end @@ -191,7 +191,7 @@ function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) y = det(X) s = conj!(y ./ _diag_view(X)) function det_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, det_pullback end @@ -200,7 +200,7 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) y = logdet(X) s = conj!(one(eltype(X)) ./ _diag_view(X)) function logdet_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, logdet_pullback end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index d4e6d7d68..7178c337b 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) @inline function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist()) + return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist()) end return Ω, HermOrSym_pullback end @@ -27,7 +27,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} uplo = A.uplo ∂A = T∂A(_symherm_back(typeof(A), ΔΩ, Symbol(uplo)), uplo) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return TM(A), Matrix_pullback end @@ -133,9 +133,9 @@ function rrule( Δλ, ΔU = ΔF.values, ΔF.vectors ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU) ∂A = eigen_rev!(A, λ, U, Δλ, ΔU) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -226,7 +226,7 @@ function rrule( function eigvals_pullback(Δλ) ∂F = Composite{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -251,9 +251,9 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla ∂U = ΔF.U .+ (ΔF.V .+ ΔF.Vt') .* c' end ∂A = eigen_rev!(A, λ, U, ∂λ, ∂U) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + svd_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, svd_pullback end @@ -286,9 +286,9 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS invpermute!(∂λ, p) ∂λ .*= sign.(λ) _, ∂A = back(∂λ) - return NO_FIELDS, unthunk(∂A) + return NoTangent(), unthunk(∂A) end - svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS) + svdvals_pullback(ΔS::AbstractZero) = (NoTangent(), ΔS) return S, svdvals_pullback end @@ -324,7 +324,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, $(Symbol(func, :_pullback)) end @@ -351,7 +351,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo) Y = (sinA, cosA) function sincos_pullback((ΔsinA, ΔcosA)::Composite) - ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA + ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NoTangent(), ΔsinA + ΔcosA if eltype(A) <: Real ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end @@ -369,7 +369,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U')) end ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, sincos_pullback end diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index 8da0b1f22..29997c2a1 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -2,7 +2,7 @@ frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), Zero function rrule(::Type{MersenneTwister}, args...) function MersenneTwister_pullback(ΔΩ) - return (NO_FIELDS, map(_ -> Zero(), args)...) + return (NoTangent(), map(_ -> Zero(), args)...) end return MersenneTwister(args...), MersenneTwister_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 6b15b10fa..6f3f72037 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -14,7 +14,7 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) function mean_pullback(ȳ) _, ∂sum_x = sum_pullback(ȳ) ∂x = extern(∂sum_x) / n - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return y_sum / n, mean_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a3ab3ca0d..6abd7dfc2 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -165,12 +165,12 @@ end @testset "type" begin - @test frule((NO_FIELDS, DoesNotExist(), DoesNotExist()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing + @test frule((NoTangent(), DoesNotExist(), DoesNotExist()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing @test rrule(typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing end @testset "Logging" begin - @test frule((NO_FIELDS, DoesNotExist(), DoesNotExist()), Base.depwarn, "message", :f) !== nothing + @test frule((NoTangent(), DoesNotExist(), DoesNotExist()), Base.depwarn, "message", :f) !== nothing @test rrule(Base.depwarn, "message", :f) !== nothing end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index f4b88fead..f398de7ea 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -187,12 +187,12 @@ end @test X̄_vectors_ad ≈ j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1] rtol=1e-4 F̄ = CT(values = λ̄, vectors = V̄) s̄elf, X̄_ad = @inferred back(F̄) - @test s̄elf === NO_FIELDS + @test s̄elf === NoTangent() X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd rtol=1e-4 - @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) === (NoTangent(), Zero()) F̄zero = CT(values = Zero(), vectors = Zero()) - @test @inferred(back(F̄zero)) === (NO_FIELDS, Zero()) + @test @inferred(back(F̄zero)) === (NoTangent(), Zero()) T <: Real && @testset "cotangent is real when input is" begin X = randn(T, n, n) @@ -286,7 +286,7 @@ end end ∂self, ∂A = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] @test ∂A ≈ ∂A_fd @@ -318,7 +318,7 @@ end λ, back = rrule(eigvals, randn(T, n, n)) _, X̄ = @inferred back(rand_tangent(λ)) - @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) === (NoTangent(), Zero()) T <: Real && @testset "cotangent is real when input is" begin @test eltype(X̄) <: Real @@ -343,10 +343,10 @@ end λ_ad, back = rrule(eigvals, A) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂A = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end end @@ -375,7 +375,7 @@ end Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS + @test dself === NoTangent() @test dp === DoesNotExist() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index bb2d20bfd..88c2a80a8 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -159,7 +159,7 @@ test_rrule(norm, randn(T), p) _, back = rrule(norm, randn(T), p) - @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) + @test back(Zero()) == (NoTangent(), Zero(), Zero()) end @testset "p = 0" begin p = 0.0 @@ -172,8 +172,8 @@ @test iszero(ẏ) y_rev, back = rrule(norm, x, p) @test y_rev == y - @test back(ȳ) == (NO_FIELDS, zero(x), Zero()) - @test back(Zero()) == (NO_FIELDS, Zero(), Zero()) + @test back(ȳ) == (NoTangent(), zero(x), Zero()) + @test back(Zero()) == (NoTangent(), Zero(), Zero()) end end end @@ -185,12 +185,12 @@ end @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) - @test rrule(normalize, x)[2](Zero()) === (NO_FIELDS, Zero()) + @test rrule(normalize, x)[2](Zero()) === (NoTangent(), Zero()) end @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) - @test rrule(normalize, x, p)[2](Zero()) === (NO_FIELDS, Zero(), Zero()) + @test rrule(normalize, x, p)[2](Zero()) === (NoTangent(), Zero(), Zero()) end end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 4bf3fbdda..fee31007b 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -23,9 +23,9 @@ # TODO: replace this with a `rrule_test` once we have that working # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 res, pb = rrule(Diagonal, [1, 4]) - @test pb(10*res) == (NO_FIELDS, [10, 40]) + @test pb(10*res) == (NoTangent(), [10, 40]) comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal - @test pb(comp) == (NO_FIELDS, [10, 40]) + @test pb(comp) == (NoTangent(), [10, 40]) end @testset "dot(x, ::Diagonal, y)" begin N = 4 @@ -57,7 +57,7 @@ y, back = rrule(diagm, ps...) @test y == diagm(ps...) ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) ∂px = unthunk(∂px) @@ -76,7 +76,7 @@ y, back = rrule(diagm, M, N, ps...) @test y == diagm(M, N, ps...) ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂M === DoesNotExist() @test ∂N === DoesNotExist() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 30943bb96..a873fe611 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -133,7 +133,7 @@ end ∂self, ∂symA = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -141,8 +141,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) - @test @inferred(back(CT())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) + @test @inferred(back(CT())) == (NoTangent(), Zero()) end # when value used to determine phase convention is low, the usual derivatives @@ -207,7 +207,7 @@ λ_ad, back = @inferred rrule(eigvals, symA) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂symA = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -215,7 +215,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end end @@ -264,7 +264,7 @@ VERSION ≥ v"1.6.0-DEV.1686" && @inferred back(∂F) ∂self, ∂symA = back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -272,8 +272,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) - @test @inferred(back(CT())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) + @test @inferred(back(CT())) == (NoTangent(), Zero()) end @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), @@ -287,7 +287,7 @@ S_ad, back = @inferred rrule(svdvals, symA) @test S_ad ≈ S # inexact because rrule uses svd not svdvals ∂self, ∂symA = @inferred back(ΔS) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -295,7 +295,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(Zero())) == (NoTangent(), Zero()) end end @@ -412,7 +412,7 @@ @test typeof(Y_ad) === typeof(Y) hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A.uplo == A.uplo # check pullback composes correctly @@ -496,11 +496,11 @@ ΔY = Composite{typeof(Y)}(ΔsinA, ΔcosA) ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2] ΔY2 = Composite{typeof(Y)}(Zero(), Zero()) - @test @inferred(back(ΔY2)) === (NO_FIELDS, Zero()) + @test @inferred(back(ΔY2)) === (NoTangent(), Zero()) ΔY3 = Composite{typeof(Y)}(ΔsinA, Zero()) @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA) From adfb5affd318df136e3fbb8ef861c6858ab3d2e7 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 14:31:58 +0100 Subject: [PATCH 2/7] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c12e8c0eb..1409d13a4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.65" +version = "0.7.68" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From f878883c47ba3dffd10626904f472e25a9ec8ef1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:39:37 +0100 Subject: [PATCH 3/7] replace NO_FIELDS by NoTangent() --- src/rulesets/Base/array.jl | 12 +++--- src/rulesets/Base/arraymath.jl | 6 +-- src/rulesets/Base/fastmath_able.jl | 2 +- src/rulesets/Base/indexing.jl | 2 +- src/rulesets/Base/mapreduce.jl | 4 +- src/rulesets/Base/sort.jl | 2 +- src/rulesets/LinearAlgebra/blas.jl | 14 +++---- src/rulesets/LinearAlgebra/dense.jl | 8 ++-- src/rulesets/LinearAlgebra/factorization.jl | 22 +++++------ src/rulesets/LinearAlgebra/norm.jl | 26 ++++++------- src/rulesets/LinearAlgebra/structured.jl | 40 ++++++++++---------- src/rulesets/LinearAlgebra/symmetric.jl | 8 ++-- src/rulesets/Random/random.jl | 2 +- test/rulesets/Base/base.jl | 4 +- test/rulesets/LinearAlgebra/factorization.jl | 10 ++--- test/rulesets/LinearAlgebra/norm.jl | 10 ++--- test/rulesets/LinearAlgebra/structured.jl | 6 +-- test/rulesets/LinearAlgebra/symmetric.jl | 14 +++---- test/rulesets/Random/random.jl | 2 +- 19 files changed, 97 insertions(+), 97 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 8a1a88f63..b1e1207e9 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -5,7 +5,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) A_dims = size(A) function reshape_pullback(Ȳ) - return (NO_FIELDS, reshape(Ȳ, A_dims), NoTangent()) + return (NoTangent(), reshape(Ȳ, A_dims), NoTangent()) end return reshape(A, dims), reshape_pullback end @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) ∂dims = broadcast(_ -> NoTangent(), dims) - return (NO_FIELDS, ∂A, ∂dims...) + return (NoTangent(), ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[:, pre:post] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(hcat, As), reduce_hcat_pullback end @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[pre:post, :] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(vcat, As), reduce_vcat_pullback end @@ -92,14 +92,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), NoTangent()) + return (NoTangent(), sum(Ȳ), NoTangent()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) + return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index d5e13d8eb..5b9baf716 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -127,7 +127,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, matmul..., addon) + (NoTangent(), matmul..., addon) end return muladd(A, B, z), muladd_pullback_1 end @@ -148,7 +148,7 @@ function rrule( @thunk(ut' .* dy), dv -> dv .+= ut' .* dy ) - (NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) + (NoTangent(), ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) end return muladd(ut, v, z), muladd_pullback_2 end @@ -175,7 +175,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, proj..., addon) + (NoTangent(), proj..., addon) end return muladd(u, vt, z), muladd_pullback_3 end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 2c5fa7d4c..2ddf1b4b1 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -115,7 +115,7 @@ let function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) - return (NO_FIELDS, ZeroTangent()) + return (NoTangent(), ZeroTangent()) end function angle_pullback(ΔΩ) Δu, Δv = reim(ΔΩ) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index bc945f8f6..5f8610ff5 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -21,7 +21,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) getindex_add! ) īnds = broadcast(_ -> NoTangent(), inds) - return (NO_FIELDS, x̄, īnds...) + return (NoTangent(), x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index d44e98608..33d9fa563 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -51,7 +51,7 @@ function rrule( @thunk(2 .* real.(ȳ) .* x), dx -> dx .+= 2 .* real.(ȳ) .* x ) - return (NO_FIELDS, NoTangent(), x_thunk) + return (NoTangent(), NoTangent(), x_thunk) end return y, sum_abs2_pullback end @@ -85,7 +85,7 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ dx .+= conj.(y ./ x) .* dy end ) - return (NO_FIELDS, x_thunk) + return (NoTangent(), x_thunk) end return y, prod_pullback end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 8372995bd..c803dedb0 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -10,7 +10,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!) - return NO_FIELDS, Δxs, NoTangent() + return NoTangent(), Δxs, NoTangent() end return ys, partialsort_pullback diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index d3eda6738..d40e39701 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -25,7 +25,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) end return Ω, blas_dot_pullback end @@ -55,12 +55,12 @@ end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - nrm2_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + nrm2_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function nrm2_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) ∂X = scal!(n, s, blascopy!(n, X, incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, nrm2_pullback end @@ -85,12 +85,12 @@ end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - asum_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + asum_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, asum_pullback end @@ -135,7 +135,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄) ) end - return (NO_FIELDS, NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) + return (NoTangent(), NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) end return y, gemv_pullback end @@ -249,7 +249,7 @@ function rrule( ) end end - return (NO_FIELDS, NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) + return (NoTangent(), NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 4bb4b73bb..30769eca3 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -26,7 +26,7 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N dy = @thunk ΔΩ .* (adjoint(A) * x) return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -38,7 +38,7 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} dy = @thunk ΔΩ .* conj.(A.diag) .* x return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -202,7 +202,7 @@ function rrule( y = pinv(x, tol) function pinv_pullback(Δy) ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end @@ -216,7 +216,7 @@ function rrule( function pinv_pullback(Δy) ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index abbd975ee..109b16f80 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -77,7 +77,7 @@ function rrule( F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Tangent) Δfactors = ΔF.factors - Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, NoTangent()) + Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent()) factors = F.factors ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors ∂A = similar(factors) @@ -127,7 +127,7 @@ function rrule( if pivot === Val(true) ∂A = ∂A[invperm(F.p), :] end - return NO_FIELDS, ∂A, NoTangent() + return NoTangent(), ∂A, NoTangent() end return F, lu_pullback end @@ -151,10 +151,10 @@ function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:Stri elseif x === :factors Matrix(ΔY) else - return (NO_FIELDS, NoTangent(), NoTangent()) + return (NoTangent(), NoTangent(), NoTangent()) end ∂F = Tangent{TF}(; factors=∂factors) - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_LU_pullback end @@ -194,7 +194,7 @@ function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix}) triu!(rdiv!(∂factors, U')) ∂factors .+= ∂L ∂F = Tangent{typeof(F)}(; factors=∂factors) - return NO_FIELDS, ∂F + return NoTangent(), ∂F end return inv(F), inv_LU_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :Vt C(Vt=Ȳ,) end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_svd_pullback end @@ -432,7 +432,7 @@ end function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Tangent) - return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() + return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() end return C, cholesky_pullback end @@ -441,7 +441,7 @@ function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Boo C = cholesky(A, Val(false); check=check) function cholesky_pullback(ΔC::Tangent) Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NO_FIELDS, Ā, NoTangent() + return NoTangent(), Ā, NoTangent() end return C, cholesky_pullback end @@ -459,7 +459,7 @@ function rrule( function cholesky_pullback(ΔC::Tangent) Ā, U = _cholesky_pullback_shared_code(C, ΔC) Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) - return NO_FIELDS, _symhermtype(A)(Ā), NoTangent() + return NoTangent(), _symhermtype(A)(Ā), NoTangent() end return C, cholesky_pullback end @@ -476,7 +476,7 @@ function rrule( Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - return (NO_FIELDS, UpperTriangular(Ā), NoTangent()) + return (NoTangent(), UpperTriangular(Ā), NoTangent()) end return C, cholesky_pullback end @@ -507,7 +507,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C(U=UpperTriangular(Ȳ'),) end end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 8a120c28f..4c90cbfe1 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -54,7 +54,7 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NoTangent(), ∂x, ∂p) end - norm_pullback_p(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback_p(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback_p end @@ -76,7 +76,7 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) ) return (NoTangent(), ∂x) end - norm_pullback_2(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm_pullback_2(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm_pullback_2 end @@ -100,9 +100,9 @@ function rrule(::typeof(norm), x::Number, p::Real) signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) end - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end - norm_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback end @@ -117,7 +117,7 @@ function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NoTangent(), ∂x, ∂p) end - normp_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normp_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normp_pullback end @@ -158,15 +158,15 @@ end function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) - normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normMinusInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normMinusInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normMinusInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normMinusInf_pullback end function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) - normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normInf_pullback end @@ -193,7 +193,7 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) @thunk(_norm1_back(x, y, Δy)), dx -> _norm1_back!(dx, x, y, Δy), )) - norm1_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm1_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm1_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) @thunk(_norm2_back(x, y, Δy)), dx -> _norm2_back!(dx, x, y, Δy), )) - norm2_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm2_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm2_pullback end @@ -263,7 +263,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm return (NoTangent(), ∂x, ∂p) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normalize_pullback end @@ -276,6 +276,6 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) return (NoTangent(), ∂x) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normalize_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 7687baaa5..5203f46c1 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -52,14 +52,14 @@ end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, diagm(size(A)..., k => ȳ), NoTangent()) + return (NoTangent(), diagm(size(A)..., k => ȳ), NoTangent()) end return diag(A, k), diag_pullback end function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(m, n, kv...), diagm_pullback end @@ -89,26 +89,26 @@ end ##### function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -117,26 +117,26 @@ end ##### function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) + transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) + transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) return transpose(A), transpose_pullback end @@ -160,7 +160,7 @@ end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ, k), NoTangent()) + return (NoTangent(), triu(ȳ, k), NoTangent()) end return triu(A, k), triu_pullback end @@ -173,7 +173,7 @@ end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ, k), NoTangent()) + return (NoTangent(), tril(ȳ, k), NoTangent()) end return tril(A, k), tril_pullback end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 2ecd73a79..6576a1820 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) @inline function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) + return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) end return Ω, HermOrSym_pullback end @@ -324,7 +324,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, $(Symbol(func, :_pullback)) end @@ -351,7 +351,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo) Y = (sinA, cosA) function sincos_pullback((ΔsinA, ΔcosA)::Tangent) - ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA + ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NoTangent(), ΔsinA + ΔcosA if eltype(A) <: Real ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end @@ -369,7 +369,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U')) end ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, sincos_pullback end diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index c68741987..20fa10971 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -2,7 +2,7 @@ frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), Zero function rrule(::Type{MersenneTwister}, args...) function MersenneTwister_pullback(ΔΩ) - return (NO_FIELDS, map(_ -> ZeroTangent(), args)...) + return (NoTangent(), map(_ -> ZeroTangent(), args)...) end return MersenneTwister(args...), MersenneTwister_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 74a634238..91203fb86 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -165,12 +165,12 @@ end @testset "type" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing @test rrule(typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing end @testset "Logging" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing @test rrule(Base.depwarn, "message", :f) !== nothing end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 72a47cb77..397f1622f 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -190,9 +190,9 @@ end @test s̄elf === NoTangent() X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd rtol=1e-4 - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) F̄zero = CT(values = ZeroTangent(), vectors = ZeroTangent()) - @test @inferred(back(F̄zero)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(F̄zero)) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin X = randn(T, n, n) @@ -318,7 +318,7 @@ end λ, back = rrule(eigvals, randn(T, n, n)) _, X̄ = @inferred back(rand_tangent(λ)) - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin @test eltype(X̄) <: Real @@ -346,7 +346,7 @@ end @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -375,7 +375,7 @@ end Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS + @test dself === NoTangent() @test dp === NoTangent() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index ebd7b0908..4420b34c7 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -159,7 +159,7 @@ test_rrule(norm, randn(T), p) _, back = rrule(norm, randn(T), p) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end @testset "p = 0" begin p = 0.0 @@ -172,8 +172,8 @@ @test iszero(ẏ) y_rev, back = rrule(norm, x, p) @test y_rev == y - @test back(ȳ) == (NO_FIELDS, zero(x), ZeroTangent()) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ȳ) == (NoTangent(), zero(x), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end end end @@ -185,12 +185,12 @@ end @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) - @test rrule(normalize, x)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent()) + @test rrule(normalize, x)[2](ZeroTangent()) === (NoTangent(), ZeroTangent()) end @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) - @test rrule(normalize, x, p)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test rrule(normalize, x, p)[2](ZeroTangent()) === (NoTangent(), ZeroTangent(), ZeroTangent()) end end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 7d574db09..fb4335975 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -23,9 +23,9 @@ # TODO: replace this with a `rrule_test` once we have that working # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 res, pb = rrule(Diagonal, [1, 4]) - @test pb(10*res) == (NO_FIELDS, [10, 40]) + @test pb(10*res) == (NoTangent(), [10, 40]) comp = Tangent{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal - @test pb(comp) == (NO_FIELDS, [10, 40]) + @test pb(comp) == (NoTangent(), [10, 40]) end @testset "dot(x, ::Diagonal, y)" begin N = 4 @@ -76,7 +76,7 @@ y, back = rrule(diagm, M, N, ps...) @test y == diagm(M, N, ps...) ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂M === NoTangent() @test ∂N === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index b748eed02..27e13c3ee 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -141,8 +141,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end # when value used to determine phase convention is low, the usual derivatives @@ -215,7 +215,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -272,8 +272,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), @@ -295,7 +295,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end @@ -500,7 +500,7 @@ @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2] ΔY2 = Tangent{typeof(Y)}(ZeroTangent(), ZeroTangent()) - @test @inferred(back(ΔY2)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ΔY2)) === (NoTangent(), ZeroTangent()) ΔY3 = Tangent{typeof(Y)}(ΔsinA, ZeroTangent()) @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA) diff --git a/test/rulesets/Random/random.jl b/test/rulesets/Random/random.jl index 66aee3631..bde25c7b3 100644 --- a/test/rulesets/Random/random.jl +++ b/test/rulesets/Random/random.jl @@ -14,7 +14,7 @@ Random.rand(d::NormalDistribution) = d.μ + d.σ*randn() rng, pb = rrule(MersenneTwister) @test rng isa MersenneTwister - @test first(pb(10)) isa typeof(NO_FIELDS) + @test first(pb(10)) isa typeof(NoTangent()) end @testset "unary" begin rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123) From 695dc939d818c2c1ee8f915ba439229d6208b9cb Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:40:51 +0100 Subject: [PATCH 4/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 85bf06c2a..2f6e5b554 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.70" +version = "0.7.71" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 23a8ee9144f922d9799ff07d3c6b549e8352106c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 12:44:03 +0100 Subject: [PATCH 5/7] change compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f6e5b554..b62d7981d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.10" ChainRulesTestUtils = "0.6.8" Compat = "3.30" FiniteDifferences = "0.12.8" From 40ed6d2e24eac19ac4ecc9f9f8107840b0a18b58 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 13:41:02 +0100 Subject: [PATCH 6/7] breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b62d7981d..9471c5559 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.71" +version = "0.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 0fda69a279c81e616f6de60e46acccdf244003a2 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 13:46:04 +0100 Subject: [PATCH 7/7] remove NanMath and SpecialFunctions.jl --- src/ChainRules.jl | 17 --- src/rulesets/packages/NaNMath.jl | 38 ------- src/rulesets/packages/README.md | 7 -- src/rulesets/packages/SpecialFunctions.jl | 123 --------------------- test/rulesets/packages/NaNMath.jl | 0 test/rulesets/packages/SpecialFunctions.jl | 109 ------------------ test/runtests.jl | 10 -- 7 files changed, 304 deletions(-) delete mode 100644 src/rulesets/packages/NaNMath.jl delete mode 100644 src/rulesets/packages/README.md delete mode 100644 src/rulesets/packages/SpecialFunctions.jl delete mode 100644 test/rulesets/packages/NaNMath.jl delete mode 100644 test/rulesets/packages/SpecialFunctions.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 94289d502..7c43c8710 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -54,21 +54,4 @@ include("rulesets/LinearAlgebra/factorization.jl") include("rulesets/Random/random.jl") -# Note: The following is only required because package authors sometimes do not -# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons. -# So we define them here for them. -function __init__() - @require NaNMath="77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" begin - include("rulesets/packages/NaNMath.jl") - end - - # Note: drop SpecialFunctions dependency in next breaking release - # https://github.com/JuliaDiff/ChainRules.jl/issues/319 - @require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin - if !isdefined(SpecialFunctions, :ChainRulesCore) - include("rulesets/packages/SpecialFunctions.jl") - end - end -end - end # module diff --git a/src/rulesets/packages/NaNMath.jl b/src/rulesets/packages/NaNMath.jl deleted file mode 100644 index 56aea6055..000000000 --- a/src/rulesets/packages/NaNMath.jl +++ /dev/null @@ -1,38 +0,0 @@ -@scalar_rule(NaNMath.sin(x), NaNMath.cos(x)) -@scalar_rule(NaNMath.cos(x), -NaNMath.sin(x)) -@scalar_rule(NaNMath.asin(x), inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acos(x), -inv(NaNMath.sqrt(1 - NaNMath.pow(x, 2)))) -@scalar_rule(NaNMath.acosh(x), inv(NaNMath.sqrt(NaNMath.pow(x, 2) - 1))) -@scalar_rule(NaNMath.tan(x), 1 + Ω^2) -@scalar_rule(NaNMath.atanh(x), inv(1 - NaNMath.pow(x, 2))) -@scalar_rule(NaNMath.log(x), inv(x)) -@scalar_rule(NaNMath.log2(x), inv(x) / NaNMath.log(oftype(x, 2))) -@scalar_rule(NaNMath.log10(x), inv(x) / NaNMath.log(oftype(x, 10))) -@scalar_rule(NaNMath.log1p(x), inv(x + 1)) -@scalar_rule(NaNMath.lgamma(x), SpecialFunctions.digamma(x)) -@scalar_rule(NaNMath.sqrt(x), inv(2 * Ω)) -@scalar_rule(NaNMath.pow(x, y), (y * NaNMath.pow(x, y - 1), Ω * NaNMath.log(x))) -@scalar_rule( - NaNMath.max(x, y), - (ifelse( - (y > x) | (signbit(y) < signbit(x)), - ifelse(isnan(y), true, ZeroTangent()), - ifelse(isnan(x), ZeroTangent(), true)), - ifelse( - (y > x) | (signbit(y) < signbit(x)), - ifelse(isnan(y), ZeroTangent(), true), - ifelse(isnan(x), true, ZeroTangent())), - ) -) -@scalar_rule( - NaNMath.min(x, y), - (ifelse( - (y < x) | (signbit(y) > signbit(x)), - ifelse(isnan(y), true, ZeroTangent()), - ifelse(isnan(x), ZeroTangent(), true)), - ifelse( - (y < x) | (signbit(y) > signbit(x)), - ifelse(isnan(y), ZeroTangent(), true), - ifelse(isnan(x), true, ZeroTangent())), - ) -) diff --git a/src/rulesets/packages/README.md b/src/rulesets/packages/README.md deleted file mode 100644 index 6e989d902..000000000 --- a/src/rulesets/packages/README.md +++ /dev/null @@ -1,7 +0,0 @@ -## Package Glue Code - -In the ideal world, everyone would write ChainRules for their functions -in the packages where they are defined. -By depending only on [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) -We do not live in an ideal world, so some of those definitions live here. -In the long-term the plan is to move them out of this repo. diff --git a/src/rulesets/packages/SpecialFunctions.jl b/src/rulesets/packages/SpecialFunctions.jl deleted file mode 100644 index 55f15cf82..000000000 --- a/src/rulesets/packages/SpecialFunctions.jl +++ /dev/null @@ -1,123 +0,0 @@ -const BESSEL_ORDER_INFO = """ -derivatives of Bessel functions with respect to the order are not implemented currently: -https://github.com/JuliaMath/SpecialFunctions.jl/issues/160 -""" - -@scalar_rule(SpecialFunctions.airyai(x), SpecialFunctions.airyaiprime(x)) -@scalar_rule(SpecialFunctions.airyaiprime(x), x * SpecialFunctions.airyai(x)) -@scalar_rule(SpecialFunctions.airybi(x), SpecialFunctions.airybiprime(x)) -@scalar_rule(SpecialFunctions.airybiprime(x), x * SpecialFunctions.airybi(x)) -@scalar_rule(SpecialFunctions.besselj0(x), -SpecialFunctions.besselj1(x)) -@scalar_rule( - SpecialFunctions.besselj1(x), - (SpecialFunctions.besselj0(x) - SpecialFunctions.besselj(2, x)) / 2, -) -@scalar_rule(SpecialFunctions.bessely0(x), -SpecialFunctions.bessely1(x)) -@scalar_rule( - SpecialFunctions.bessely1(x), - (SpecialFunctions.bessely0(x) - SpecialFunctions.bessely(2, x)) / 2, -) -@scalar_rule(SpecialFunctions.dawson(x), 1 - (2 * x * Ω)) -@scalar_rule(SpecialFunctions.digamma(x), SpecialFunctions.trigamma(x)) -@scalar_rule(SpecialFunctions.erf(x), (2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfc(x), -(2 / sqrt(π)) * exp(-x * x)) -@scalar_rule(SpecialFunctions.erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -@scalar_rule(SpecialFunctions.erfi(x), (2 / sqrt(π)) * exp(x * x)) -@scalar_rule(SpecialFunctions.erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) -@scalar_rule(SpecialFunctions.gamma(x), Ω * SpecialFunctions.digamma(x)) -@scalar_rule( - SpecialFunctions.invdigamma(x), - inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma(x))), -) -@scalar_rule(SpecialFunctions.trigamma(x), SpecialFunctions.polygamma(2, x)) - -# binary -@scalar_rule( - SpecialFunctions.besselj(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.besselj(ν - 1, x) - SpecialFunctions.besselj(ν + 1, x)) / 2 - ), -) -@scalar_rule( - SpecialFunctions.besseli(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.besseli(ν - 1, x) + SpecialFunctions.besseli(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.bessely(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.bessely(ν - 1, x) - SpecialFunctions.bessely(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.besselk(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - -(SpecialFunctions.besselk(ν - 1, x) + SpecialFunctions.besselk(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.hankelh1(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.hankelh1(ν - 1, x) - SpecialFunctions.hankelh1(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.hankelh2(ν, x), - ( - @not_implemented(BESSEL_ORDER_INFO), - (SpecialFunctions.hankelh2(ν - 1, x) - SpecialFunctions.hankelh2(ν + 1, x)) / 2, - ), -) -@scalar_rule( - SpecialFunctions.polygamma(m, x), - ( - NoTangent(), - SpecialFunctions.polygamma(m + 1, x), - ), -) -# todo: setup for common expr -@scalar_rule( - SpecialFunctions.beta(a, b), - (Ω*(SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)), - Ω*(SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b)),) -) - -# Changes between SpecialFunctions 0.7 and 0.8 -if isdefined(SpecialFunctions, :lgamma) - # actually is the absolute value of the logorithm of gamma - @scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x)) -end - -if isdefined(SpecialFunctions, :logabsgamma) - # actually is the absolute value of the logorithm of gamma paired with sign gamma - @scalar_rule(SpecialFunctions.logabsgamma(x), SpecialFunctions.digamma(x), ZeroTangent()) -end - -if isdefined(SpecialFunctions, :loggamma) - @scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x)) -end - -if isdefined(SpecialFunctions, :lbeta) - # todo: setup for common expr - @scalar_rule( - SpecialFunctions.lbeta(a, b), - (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), - SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),) - ) -end - -if isdefined(SpecialFunctions, :logbeta) - # todo: setup for common expr - @scalar_rule( - SpecialFunctions.logbeta(a, b), - (SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b), - SpecialFunctions.digamma(b) - SpecialFunctions.digamma(a + b),) - ) -end diff --git a/test/rulesets/packages/NaNMath.jl b/test/rulesets/packages/NaNMath.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/rulesets/packages/SpecialFunctions.jl b/test/rulesets/packages/SpecialFunctions.jl deleted file mode 100644 index 6cd9c8c0b..000000000 --- a/test/rulesets/packages/SpecialFunctions.jl +++ /dev/null @@ -1,109 +0,0 @@ -@testset "general: single input" begin - for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im) - test_scalar(SpecialFunctions.erf, x) - test_scalar(SpecialFunctions.erfc, x) - test_scalar(SpecialFunctions.erfi, x) - - test_scalar(SpecialFunctions.airyai, x) - test_scalar(SpecialFunctions.airyaiprime, x) - test_scalar(SpecialFunctions.airybi, x) - test_scalar(SpecialFunctions.airybiprime, x) - - test_scalar(SpecialFunctions.erfcx, x) - test_scalar(SpecialFunctions.dawson, x) - - if x isa Real - test_scalar(SpecialFunctions.invdigamma, x) - end - - if x isa Real && 0 < x < 1 - test_scalar(SpecialFunctions.erfinv, x) - test_scalar(SpecialFunctions.erfcinv, x) - end - - if x isa Real && x > 0 || x isa Complex - test_scalar(SpecialFunctions.gamma, x) - test_scalar(SpecialFunctions.digamma, x) - test_scalar(SpecialFunctions.trigamma, x) - end - end -end - -@testset "Bessel functions" begin - for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) - test_scalar(SpecialFunctions.besselj0, x) - test_scalar(SpecialFunctions.besselj1, x) - - isreal(x) && x < 0 && continue - - test_scalar(SpecialFunctions.bessely0, x) - test_scalar(SpecialFunctions.bessely1, x) - - for nu in (-1.5, 2.2, 4.0) - test_frule(SpecialFunctions.besseli, nu, x) - test_rrule(SpecialFunctions.besseli, nu, x) - - test_frule(SpecialFunctions.besselj, nu, x) - test_rrule(SpecialFunctions.besselj, nu, x) - - test_frule(SpecialFunctions.besselk, nu, x) - test_rrule(SpecialFunctions.besselk, nu, x) - - test_frule(SpecialFunctions.bessely, nu, x) - test_rrule(SpecialFunctions.bessely, nu, x) - - # use complex numbers in `rrule` for FiniteDifferences - test_frule(SpecialFunctions.hankelh1, nu, x) - test_rrule(SpecialFunctions.hankelh1, nu, complex(x)) - - # use complex numbers in `rrule` for FiniteDifferences - test_frule(SpecialFunctions.hankelh2, nu, x) - test_rrule(SpecialFunctions.hankelh2, nu, complex(x)) - end - end -end - -@testset "beta and logbeta" begin - test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) - for _x in test_points, _y in test_points - # ensure all complex if any complex for FiniteDifferences - x, y = promote(_x, _y) - test_frule(SpecialFunctions.beta, x, y) - test_rrule(SpecialFunctions.beta, x, y) - - if isdefined(SpecialFunctions, :lbeta) - test_frule(SpecialFunctions.lbeta, x, y) - test_rrule(SpecialFunctions.lbeta, x, y) - end - - if isdefined(SpecialFunctions, :logbeta) - test_frule(SpecialFunctions.logbeta, x, y) - test_rrule(SpecialFunctions.logbeta, x, y) - end - end -end - -@testset "log gamma and co" begin - # It is important that we have negative numbers with both odd and even integer parts - for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) - for m in (0, 1, 2, 3) - test_frule(SpecialFunctions.polygamma, m, x) - test_rrule(SpecialFunctions.polygamma, m, x) - end - - if isdefined(SpecialFunctions, :lgamma) - test_scalar(SpecialFunctions.lgamma, x) - end - - if isdefined(SpecialFunctions, :loggamma) - isreal(x) && x < 0 && continue - test_scalar(SpecialFunctions.loggamma, x) - end - - if isdefined(SpecialFunctions, :logabsgamma) - isreal(x) || continue - test_frule(SpecialFunctions.logabsgamma, x) - test_rrule(SpecialFunctions.logabsgamma, x; output_tangent=(randn(), randn())) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 420892ed1..2adb2f02f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,15 +57,5 @@ println("Testing ChainRules.jl") include_test("rulesets/Random/random.jl") end println() - - @testset "packages" begin - include_test("rulesets/packages/NaNMath.jl") - # Note: drop SpecialFunctions dependency in next breaking release - # https://github.com/JuliaDiff/ChainRules.jl/issues/319 - if !isdefined(SpecialFunctions, :ChainRulesCore) - include_test("rulesets/packages/SpecialFunctions.jl") - end - end - println() end end