diff --git a/Project.toml b/Project.toml index 85bf06c2a..b5200218b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.70" +version = "0.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,8 +12,8 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.9.44" -ChainRulesTestUtils = "0.6.8" +ChainRulesCore = "0.10" +ChainRulesTestUtils = "0.7" Compat = "3.30" FiniteDifferences = "0.12.8" Reexport = "0.2, 1" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 94289d502..203031ce1 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -2,7 +2,6 @@ module ChainRules using Reexport @reexport using ChainRulesCore -export Zero, DoesNotExist, Composite, AbstractDifferential using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using Compat diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index f56dc9a00..b1e1207e9 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -5,7 +5,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) A_dims = size(A) function reshape_pullback(Ȳ) - return (NO_FIELDS, reshape(Ȳ, A_dims), NoTangent()) + return (NoTangent(), reshape(Ȳ, A_dims), NoTangent()) end return reshape(A, dims), reshape_pullback end @@ -15,7 +15,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) ∂dims = broadcast(_ -> NoTangent(), dims) - return (NO_FIELDS, ∂A, ∂dims...) + return (NoTangent(), ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end @@ -28,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) function hcat_pullback(Ȳ) Xs = (A, Bs...) ntuple(length(Bs) + 2) do full_i - full_i == 1 && return NO_FIELDS + full_i == 1 && return NoTangent() i = full_i - 1 l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) @@ -50,7 +50,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[:, pre:post] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(hcat, As), reduce_hcat_pullback end @@ -68,7 +68,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) u = l + size(Bs[i], 1) copy(selectdim(Ȳ, 1, l+1:u)) end - return (NO_FIELDS, ∂A, ∂Bs...) + return (NoTangent(), ∂A, ∂Bs...) end return vcat(A, Bs...), vcat_pullback end @@ -81,7 +81,7 @@ function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVe pre = post - diff + 1 return ΔY[pre:post, :] end - return (NO_FIELDS, NoTangent(), ∂As) + return (NoTangent(), NoTangent(), ∂As) end return reduce(vcat, As), reduce_vcat_pullback end @@ -92,14 +92,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), NoTangent()) + return (NoTangent(), sum(Ȳ), NoTangent()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) + return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 5959005fe..a0c14a4c0 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -10,7 +10,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) function inv_pullback(ΔΩ) - return NO_FIELDS, -Ω' * ΔΩ * Ω' + return NoTangent(), -Ω' * ΔΩ * Ω' end return Ω, inv_pullback end @@ -26,7 +26,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * B'), X̄ -> mul!(X̄, Ȳ, B', true, true) @@ -48,7 +48,7 @@ function rrule( function times_pullback(Ȳ) @assert size(B, 1) === 1 # otherwise primal would have failed. return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(Ȳ * vec(B')), X̄ -> mul!(X̄, Ȳ, vec(B'), true, true) @@ -67,7 +67,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), @thunk(dot(Ȳ, B)'), InplaceableThunk( @thunk(A' * Ȳ), @@ -83,7 +83,7 @@ function rrule( ) function times_pullback(Ȳ) return ( - NO_FIELDS, + NoTangent(), InplaceableThunk( @thunk(A' * Ȳ), X̄ -> mul!(X̄, conj(A), Ȳ, true, true) @@ -118,7 +118,7 @@ function rrule( ) ) addon = if z isa Bool - DoesNotExist() + NoTangent() elseif z isa Number @thunk(sum(Ȳ)) else @@ -127,7 +127,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, matmul..., addon) + (NoTangent(), matmul..., addon) end return muladd(A, B, z), muladd_pullback_1 end @@ -148,7 +148,7 @@ function rrule( @thunk(ut' .* dy), dv -> dv .+= ut' .* dy ) - (NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy) + (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy) end return muladd(ut, v, z), muladd_pullback_2 end @@ -166,7 +166,7 @@ function rrule( @thunk(vec(sum(u .* conj.(Ȳ), dims=1))'), ) addon = if z isa Bool - DoesNotExist() + NoTangent() elseif z isa Number @thunk(sum(Ȳ)) else @@ -175,7 +175,7 @@ function rrule( dz -> sum!(dz, Ȳ; init=false) ) end - (NO_FIELDS, proj..., addon) + (NoTangent(), proj..., addon) end return muladd(u, vt, z), muladd_pullback_3 end @@ -197,7 +197,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R ∂A = last(dA_pb(unthunk(dAᵀ))) ∂B = last(dA_pb(unthunk(dBᵀ))) - (NO_FIELDS, ∂A, ∂B) + (NoTangent(), ∂A, ∂B) end return C, slash_pullback end @@ -217,7 +217,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā end ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback @@ -230,7 +230,7 @@ end function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real) Y = A/b function slash_pullback(Ȳ) - return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) + return (NoTangent(), @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b)) end return Y, slash_pullback end @@ -238,7 +238,7 @@ end function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real}) Y = b\A function backslash_pullback(Ȳ) - return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) + return (NoTangent(), @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b)) end return Y, backslash_pullback end @@ -249,7 +249,7 @@ end function rrule(::typeof(-), x::AbstractArray) function negation_pullback(ȳ) - return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) + return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ) end return -x, negation_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9ae6bf9dd..509508012 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -10,7 +10,7 @@ frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') function rrule(::typeof(adjoint), z::Number) - adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ') + adjoint_pullback(ΔΩ) = (NoTangent(), ΔΩ') return (z', adjoint_pullback) end @@ -22,7 +22,7 @@ frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz)) function rrule(::typeof(real), z::Number) # add zero(z) to embed the real number in the same number type as z - real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z)) + real_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) + zero(z)) return (real(z), real_pullback) end @@ -33,7 +33,7 @@ end frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz)) function rrule(::typeof(imag), z::Complex) - imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) + imag_pullback(ΔΩ) = (NoTangent(), real(ΔΩ) * im) return (imag(z), imag_pullback) end @@ -45,15 +45,15 @@ function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex end function rrule(::Type{T}, z::Complex) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), Complex(ΔΩ)) return (T(z), Complex_pullback) end function rrule(::Type{T}, x::Real) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ)) return (T(x), Complex_pullback) end function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} - Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ)) + Complex_pullback(ΔΩ) = (NoTangent(), real(ΔΩ), imag(ΔΩ)) return (T(x, y), Complex_pullback) end @@ -70,7 +70,7 @@ end function rrule(::typeof(hypot), z::Complex) Ω = hypot(z) function hypot_pullback(ΔΩ) - return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) + return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) end return (Ω, hypot_pullback) end @@ -148,7 +148,7 @@ end function rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return (x, identity_pullback) end diff --git a/src/rulesets/Base/evalpoly.jl b/src/rulesets/Base/evalpoly.jl index 8ebedb8fd..59a4e2018 100644 --- a/src/rulesets/Base/evalpoly.jl +++ b/src/rulesets/Base/evalpoly.jl @@ -15,7 +15,7 @@ if VERSION ≥ v"1.4" y, ys = _evalpoly_intermediates(x, p) function evalpoly_pullback(Δy) ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) - return NO_FIELDS, ∂x, ∂p + return NoTangent(), ∂x, ∂p end return y, evalpoly_pullback end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 2730f9f25..2ddf1b4b1 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -11,7 +11,7 @@ let ## sin function rrule(::typeof(sin), x::Number) sinx, cosx = sincos(x) - sin_pullback(Δy) = (NO_FIELDS, cosx' * Δy) + sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end @@ -23,7 +23,7 @@ let ## cos function rrule(::typeof(cos), x::Number) sinx, cosx = sincos(x) - cos_pullback(Δy) = (NO_FIELDS, -sinx' * Δy) + cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end @@ -75,7 +75,7 @@ let Ω = abs(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return (NO_FIELDS, signx * real(ΔΩ)) + return (NoTangent(), signx * real(ΔΩ)) end return Ω, abs_pullback end @@ -88,7 +88,7 @@ let function rrule(::typeof(abs2), z::Union{Real, Complex}) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NO_FIELDS, 2real(ΔΩ)*z) + return (NoTangent(), 2real(ΔΩ)*z) end return abs2(z), abs2_pullback end @@ -99,7 +99,7 @@ let end function rrule(::typeof(conj), z::Union{Real, Complex}) function conj_pullback(ΔΩ) - return (NO_FIELDS, conj(ΔΩ)) + return (NoTangent(), conj(ΔΩ)) end return conj(z), conj_pullback end @@ -115,11 +115,11 @@ let function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) - return (NO_FIELDS, ZeroTangent()) + return (NoTangent(), ZeroTangent()) end function angle_pullback(ΔΩ) Δu, Δv = reim(ΔΩ) - return (NO_FIELDS, im*Δu/ifelse(iszero(x), one(x), x)) + return (NoTangent(), im*Δu/ifelse(iszero(x), one(x), x)) # `ifelse` is applied only to denominator to ensure type-stability. end return angle(x), angle_pullback @@ -130,7 +130,7 @@ let Δu, Δv = reim(ΔΩ) # `ifelse` is applied only to denominator to ensure type-stability. n = ifelse(iszero(z), one(real(z)), abs2(z)) - return (NO_FIELDS, (-y + im*x)*Δu/n) + return (NoTangent(), (-y + im*x)*Δu/n) end return angle(z), angle_pullback end @@ -155,7 +155,7 @@ let Ω = hypot(x, y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) - return (NO_FIELDS, c * x, c * y) + return (NoTangent(), c * x, c * y) end return (Ω, hypot_pullback) end @@ -198,7 +198,7 @@ let Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, sign_pullback end @@ -214,7 +214,7 @@ let function rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) - return (NO_FIELDS, ΔΩ * y', x' * ΔΩ) + return (NoTangent(), ΔΩ * y', x' * ΔΩ) end return x * y, times_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index bc945f8f6..5f8610ff5 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -21,7 +21,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) getindex_add! ) īnds = broadcast(_ -> NoTangent(), inds) - return (NO_FIELDS, x̄, īnds...) + return (NoTangent(), x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index f15f223bf..33d9fa563 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -14,7 +14,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} @thunk(broadcast(last∘tuple, x, ȳ)), x -> x .+= ȳ ) - return (NO_FIELDS, x̄) + return (NoTangent(), x̄) end return y, sum_pullback end @@ -51,7 +51,7 @@ function rrule( @thunk(2 .* real.(ȳ) .* x), dx -> dx .+= 2 .* real.(ȳ) .* x ) - return (NO_FIELDS, NoTangent(), x_thunk) + return (NoTangent(), NoTangent(), x_thunk) end return y, sum_abs2_pullback end @@ -85,7 +85,7 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ dx .+= conj.(y ./ x) .* dy end ) - return (NO_FIELDS, x_thunk) + return (NoTangent(), x_thunk) end return y, prod_pullback end diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 9e1a45283..c803dedb0 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -10,7 +10,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!) - return NO_FIELDS, Δxs, NoTangent() + return NoTangent(), Δxs, NoTangent() end return ys, partialsort_pullback @@ -28,7 +28,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!) - return NO_FIELDS, Δxs + return NoTangent(), Δxs end return ys, sort_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 159675a25..d40e39701 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -25,7 +25,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent(), ∂Y, NoTangent()) end return Ω, blas_dot_pullback end @@ -48,19 +48,19 @@ end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) function nrm2_pullback(ΔΩ) - return NO_FIELDS, x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) + return NoTangent(), x .* (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) end return Ω, nrm2_pullback end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - nrm2_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + nrm2_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function nrm2_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) ∂X = scal!(n, s, blascopy!(n, X, incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, nrm2_pullback end @@ -78,19 +78,19 @@ end function rrule(::typeof(BLAS.asum), x) function asum_pullback(ΔΩ) - return (NO_FIELDS, _signcomp.(x) .* real(ΔΩ)) + return (NoTangent(), _signcomp.(x) .* real(ΔΩ)) end return BLAS.asum(x), asum_pullback end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - asum_pullback(::ZeroTangent) = (NO_FIELDS, NoTangent(), ZeroTangent(), NoTangent()) + asum_pullback(::ZeroTangent) = (NoTangent(), NoTangent(), ZeroTangent(), NoTangent()) function asum_pullback(ΔΩ) # BLAS.scal! requires s has the same eltype as X s = eltype(X)(real(ΔΩ)) ∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx) - return (NO_FIELDS, NoTangent(), ∂X, NoTangent()) + return (NoTangent(), NoTangent(), ∂X, NoTangent()) end return Ω, asum_pullback end @@ -135,7 +135,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄) ) end - return (NO_FIELDS, NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) + return (NoTangent(), NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x) end return y, gemv_pullback end @@ -146,7 +146,7 @@ function rrule( y, inner_pullback = rrule(gemv, tA, one(T), A, x) function gemv_pullback(Ȳ) (_, dtA, _, dA, dx) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dA, dx) + return (NoTangent(), dtA, dA, dx) end return y, gemv_pullback end @@ -249,7 +249,7 @@ function rrule( ) end end - return (NO_FIELDS, NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) + return (NoTangent(), NoTangent(), NoTangent(), @thunk(dot(C, C̄) / α'), ∂A, ∂B) end return C, gemm_pullback end @@ -260,7 +260,7 @@ function rrule( C, inner_pullback = rrule(gemm, tA, tB, one(T), A, B) function gemm_pullback(Ȳ) (_, dtA, dtB, _, dA, dB) = inner_pullback(Ȳ) - return (NO_FIELDS, dtA, dtB, dA, dB) + return (NoTangent(), dtA, dtB, dA, dB) end return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 5ff1764d7..30769eca3 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -8,7 +8,7 @@ end function rrule(::typeof(dot), x, y) function dot_pullback(ΔΩ) - return (NO_FIELDS, @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) + return (NoTangent(), @thunk(y .* ΔΩ'), @thunk(x .* ΔΩ)) end return dot(x, y), dot_pullback end @@ -24,9 +24,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N dx = @thunk conj(ΔΩ) .* Ay dA = @thunk ΔΩ .* x .* adjoint(y) dy = @thunk ΔΩ .* (adjoint(A) * x) - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -36,9 +36,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number} dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements dy = @thunk ΔΩ .* conj.(A.diag) .* x - return (NO_FIELDS, dx, dA, dy) + return (NoTangent(), dx, dA, dy) end - dot_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent(), ZeroTangent()) + dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) return z, dot_pullback end @@ -54,7 +54,7 @@ end function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) Ω = cross(a, b) function cross_pullback(ΔΩ) - return (NO_FIELDS, @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) + return (NoTangent(), @thunk(cross(b, ΔΩ)), @thunk(cross(ΔΩ, a))) end return Ω, cross_pullback end @@ -76,7 +76,7 @@ function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ : Ω * ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, det_pullback end @@ -94,7 +94,7 @@ function rrule(::typeof(logdet), x::Union{Number, StridedMatrix{<:Number}}) Ω = logdet(x) function logdet_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ / x' : ΔΩ * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logdet_pullback end @@ -122,7 +122,7 @@ function rrule(::typeof(logabsdet), x::AbstractMatrix) imagf = f - real(f) g = real(Δy) + imagf ∂x = g * inv(x)' - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return Ω, logabsdet_pullback end @@ -139,7 +139,7 @@ function rrule(::typeof(tr), x) # This should really be a FillArray # see https://github.com/JuliaDiff/ChainRules.jl/issues/46 function tr_pullback(ΔΩ) - return (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1)))) + return (NoTangent(), Diagonal(fill(ΔΩ, size(x, 1)))) end return tr(x), tr_pullback end @@ -202,7 +202,7 @@ function rrule( y = pinv(x, tol) function pinv_pullback(Δy) ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end @@ -216,7 +216,7 @@ function rrule( function pinv_pullback(Δy) ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end return y, pinv_pullback end @@ -235,7 +235,7 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} ∂A = add!!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A) ∂A = add!!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y' end - return (NO_FIELDS, ∂A) + return (NoTangent(), ∂A) end return Y, pinv_pullback end @@ -280,7 +280,7 @@ function rrule( trans = T <: Complex ? 'C' : 'T' ∂D, scale2 = LAPACK.trsyl!(trans, trans, RA, RB, ∂Y) ∂Z = rmul!(QA * (∂D * QB'), -inv(scale2)) - return NO_FIELDS, @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) end return Ω, sylvester_pullback end @@ -318,7 +318,7 @@ function rrule( ∂Y = Q' * (∂Ω * Q) ∂D, scale2 = LAPACK.trsyl!(T <: Complex ? 'C' : 'T', 'N', R, R, ∂Y) ∂Z = rmul!(Q * (∂D * Q'), -inv(scale2)) - return NO_FIELDS, @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) + return NoTangent(), @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) end return Ω, lyap_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 1601bb245..109b16f80 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -77,7 +77,7 @@ function rrule( F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Tangent) Δfactors = ΔF.factors - Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, NoTangent()) + Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent()) factors = F.factors ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors ∂A = similar(factors) @@ -127,7 +127,7 @@ function rrule( if pivot === Val(true) ∂A = ∂A[invperm(F.p), :] end - return NO_FIELDS, ∂A, NoTangent() + return NoTangent(), ∂A, NoTangent() end return F, lu_pullback end @@ -151,10 +151,10 @@ function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:Stri elseif x === :factors Matrix(ΔY) else - return (NO_FIELDS, NoTangent(), NoTangent()) + return (NoTangent(), NoTangent(), NoTangent()) end ∂F = Tangent{TF}(; factors=∂factors) - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_LU_pullback end @@ -194,7 +194,7 @@ function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix}) triu!(rdiv!(∂factors, U')) ∂factors .+= ∂L ∂F = Tangent{typeof(F)}(; factors=∂factors) - return NO_FIELDS, ∂F + return NoTangent(), ∂F end return inv(F), inv_LU_pullback end @@ -208,7 +208,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) function svd_pullback(Ȳ::Tangent) # `getproperty` on `Tangent`s ensures we have no thunks. ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') - return (NO_FIELDS, ∂X) + return (NoTangent(), ∂X) end return F, svd_pullback end @@ -225,7 +225,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :Vt C(Vt=Ȳ,) end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_svd_pullback end @@ -301,7 +301,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ function eigen_pullback(ΔF::Tangent) λ, V = F.values, F.vectors Δλ, ΔV = ΔF.values, ΔF.vectors - ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + ΔV isa AbstractZero && Δλ isa AbstractZero && return (NoTangent(), Δλ + ΔV) if eltype(λ) <: Real && ishermitian(A) hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) @@ -319,9 +319,9 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ ∂K[diagind(∂K)] .= Δλ ∂A = mul!(∂K, V' \ ∂K, V') end - return NO_FIELDS, T <: Real ? real(∂A) : ∂A + return NoTangent(), T <: Real ? real(∂A) : ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -409,7 +409,7 @@ function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Unio function eigvals_pullback(Δλ) ∂F = Tangent{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -432,7 +432,7 @@ end function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Tangent) - return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() + return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() end return C, cholesky_pullback end @@ -441,7 +441,7 @@ function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Boo C = cholesky(A, Val(false); check=check) function cholesky_pullback(ΔC::Tangent) Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NO_FIELDS, Ā, NoTangent() + return NoTangent(), Ā, NoTangent() end return C, cholesky_pullback end @@ -459,7 +459,7 @@ function rrule( function cholesky_pullback(ΔC::Tangent) Ā, U = _cholesky_pullback_shared_code(C, ΔC) Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) - return NO_FIELDS, _symhermtype(A)(Ā), NoTangent() + return NoTangent(), _symhermtype(A)(Ā), NoTangent() end return C, cholesky_pullback end @@ -476,7 +476,7 @@ function rrule( Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - return (NO_FIELDS, UpperTriangular(Ā), NoTangent()) + return (NoTangent(), UpperTriangular(Ā), NoTangent()) end return C, cholesky_pullback end @@ -507,7 +507,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C(U=UpperTriangular(Ȳ'),) end end - return NO_FIELDS, ∂F, NoTangent() + return NoTangent(), ∂F, NoTangent() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 28402e5b5..f28c0a527 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -122,7 +122,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates) - return NO_FIELDS, Matrix(∂hermA) + return NoTangent(), Matrix(∂hermA) end return Matrix(hermX), exp_pullback_hermitian else @@ -133,7 +133,7 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) # the default _matfun_frechet_adjoint! ∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')' ∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return X, exp_pullback end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 810da189b..4c90cbfe1 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -52,9 +52,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) end ) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - norm_pullback_p(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback_p(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback_p end @@ -74,9 +74,9 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) _norm2_back!(dx, x, y, Δy) end ) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - norm_pullback_2(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm_pullback_2(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm_pullback_2 end @@ -100,9 +100,9 @@ function rrule(::typeof(norm), x::Number, p::Real) signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) end - return (NO_FIELDS, ∂x, ZeroTangent()) + return (NoTangent(), ∂x, ZeroTangent()) end - norm_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + norm_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, norm_pullback end @@ -115,9 +115,9 @@ function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) ∂p = @thunk _normp_back_p(x, p, y, Δy) - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normp_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normp_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normp_pullback end @@ -158,15 +158,15 @@ end function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) - normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normMinusInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normMinusInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normMinusInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normMinusInf_pullback end function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) - normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) - normInf_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normInf_pullback(Δy) = (NoTangent(), _normInf_back(x, y, Δy)) + normInf_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normInf_pullback end @@ -189,11 +189,11 @@ end function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) - norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm1_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm1_back(x, y, Δy)), dx -> _norm1_back!(dx, x, y, Δy), )) - norm1_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm1_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm1_pullback end @@ -221,11 +221,11 @@ end function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) - norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( + norm2_pullback(Δy) = (NoTangent(), InplaceableThunk( @thunk(_norm2_back(x, y, Δy)), dx -> _norm2_back!(dx, x, y, Δy), )) - norm2_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + norm2_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, norm2_pullback end @@ -261,9 +261,9 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) ∂nrm = -dot(y, Δy) * invnrm (_, ∂xnorm, ∂p) = inner_pullback(∂nrm) ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm - return (NO_FIELDS, ∂x, ∂p) + return (NoTangent(), ∂x, ∂p) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent(), ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent()) return y, normalize_pullback end @@ -274,8 +274,8 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) LinearAlgebra.__normalize!(y, nrm) function normalize_pullback(Δy) ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end - normalize_pullback(::ZeroTangent) = (NO_FIELDS, ZeroTangent()) + normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) return y, normalize_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index acfc7cdb0..5203f46c1 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -10,7 +10,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr function slash_pullback(Ȳ) ∂A = @thunk Ȳ / B' ∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B')) - return (NO_FIELDS, ∂A, ∂B) + return (NoTangent(), ∂A, ∂B) end return Y, slash_pullback end @@ -20,7 +20,7 @@ function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMa function backslash_pullback(Ȳ) ∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y') ∂B = @thunk A' \ Ȳ - return NO_FIELDS, ∂A, ∂B + return NoTangent(), ∂A, ∂B end return Y, backslash_pullback end @@ -31,42 +31,42 @@ end function rrule(::Type{<:Diagonal}, d::AbstractVector) function Diagonal_pullback(ȳ::AbstractMatrix) - return (NO_FIELDS, diag(ȳ)) + return (NoTangent(), diag(ȳ)) end function Diagonal_pullback(ȳ::Tangent) # TODO: Assert about the primal type in the Tangent, It should be Diagonal # infact it should be exactly the type of `Diagonal(d)` # but right now Zygote loses primal type information so we can't use it. # See https://github.com/FluxML/Zygote.jl/issues/603 - return (NO_FIELDS, ȳ.diag) + return (NoTangent(), ȳ.diag) end return Diagonal(d), Diagonal_pullback end function rrule(::typeof(diag), A::AbstractMatrix) function diag_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ)) + return (NoTangent(), Diagonal(ȳ)) end return diag(A), diag_pullback end if VERSION ≥ v"1.3" function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) function diag_pullback(ȳ) - return (NO_FIELDS, diagm(size(A)..., k => ȳ), NoTangent()) + return (NoTangent(), diagm(size(A)..., k => ȳ), NoTangent()) end return diag(A, k), diag_pullback end function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(m, n, kv...), diagm_pullback end end function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) - return (NO_FIELDS, _diagm_back.(kv, Ref(ȳ))...) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end return diagm(kv...), diagm_pullback end @@ -79,7 +79,7 @@ end function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) function times_pullback(Ȳ) - return (NO_FIELDS, @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) + return (NoTangent(), @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) end return D * V, times_pullback end @@ -89,26 +89,26 @@ end ##### function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return Adjoint(A), Adjoint_pullback end function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) - Adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + Adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return Adjoint(A), Adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + adjoint_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), adjoint(ȳ)) return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) - adjoint_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) + adjoint_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + adjoint_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(adjoint(ȳ))) return adjoint(A), adjoint_pullback end @@ -117,26 +117,26 @@ end ##### function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + Transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), Transpose(ȳ)) return Transpose(A), Transpose_pullback end function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) - Transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) + Transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + Transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(Transpose(ȳ))) return Transpose(A), Transpose_pullback end function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, ȳ.parent) - transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) + transpose_pullback(ȳ::Tangent) = (NoTangent(), ȳ.parent) + transpose_pullback(ȳ::AbstractVecOrMat) = (NoTangent(), transpose(ȳ)) return transpose(A), transpose_pullback end function rrule(::typeof(transpose), A::AbstractVector{<:Number}) - transpose_pullback(ȳ::Tangent) = (NO_FIELDS, vec(ȳ.parent)) - transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) + transpose_pullback(ȳ::Tangent) = (NoTangent(), vec(ȳ.parent)) + transpose_pullback(ȳ::AbstractMatrix) = (NoTangent(), vec(transpose(ȳ))) return transpose(A), transpose_pullback end @@ -146,40 +146,40 @@ end function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) function UpperTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return UpperTriangular(A), UpperTriangular_pullback end function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) function LowerTriangular_pullback(ȳ) - return (NO_FIELDS, Matrix(ȳ)) + return (NoTangent(), Matrix(ȳ)) end return LowerTriangular(A), LowerTriangular_pullback end function rrule(::typeof(triu), A::AbstractMatrix, k::Integer) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ, k), NoTangent()) + return (NoTangent(), triu(ȳ, k), NoTangent()) end return triu(A, k), triu_pullback end function rrule(::typeof(triu), A::AbstractMatrix) function triu_pullback(ȳ) - return (NO_FIELDS, triu(ȳ)) + return (NoTangent(), triu(ȳ)) end return triu(A), triu_pullback end function rrule(::typeof(tril), A::AbstractMatrix, k::Integer) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ, k), NoTangent()) + return (NoTangent(), tril(ȳ, k), NoTangent()) end return tril(A, k), tril_pullback end function rrule(::typeof(tril), A::AbstractMatrix) function tril_pullback(ȳ) - return (NO_FIELDS, tril(ȳ)) + return (NoTangent(), tril(ȳ)) end return tril(A), tril_pullback end @@ -191,7 +191,7 @@ function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) y = det(X) s = conj!(y ./ _diag_view(X)) function det_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, det_pullback end @@ -200,7 +200,7 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) y = logdet(X) s = conj!(one(eltype(X)) ./ _diag_view(X)) function logdet_pullback(ȳ) - return (NO_FIELDS, Diagonal(ȳ .* s)) + return (NoTangent(), Diagonal(ȳ .* s)) end return y, logdet_pullback end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 0444993e1..6576a1820 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) @inline function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) + return (NoTangent(), _symherm_back(typeof(Ω), ΔΩ, uplo), NoTangent()) end return Ω, HermOrSym_pullback end @@ -27,7 +27,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} uplo = A.uplo ∂A = T∂A(_symherm_back(typeof(A), ΔΩ, Symbol(uplo)), uplo) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return TM(A), Matrix_pullback end @@ -133,9 +133,9 @@ function rrule( Δλ, ΔU = ΔF.values, ΔF.vectors ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU) ∂A = eigen_rev!(A, λ, U, Δλ, ΔU) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + eigen_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, eigen_pullback end @@ -226,7 +226,7 @@ function rrule( function eigvals_pullback(Δλ) ∂F = Tangent{typeof(F)}(values = Δλ) _, ∂A = eigen_back(∂F) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return λ, eigvals_pullback end @@ -251,9 +251,9 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla ∂U = ΔF.U .+ (ΔF.V .+ ΔF.Vt') .* c' end ∂A = eigen_rev!(A, λ, U, ∂λ, ∂U) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end - svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + svd_pullback(ΔF::AbstractZero) = (NoTangent(), ΔF) return F, svd_pullback end @@ -286,9 +286,9 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS invpermute!(∂λ, p) ∂λ .*= sign.(λ) _, ∂A = back(∂λ) - return NO_FIELDS, unthunk(∂A) + return NoTangent(), unthunk(∂A) end - svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS) + svdvals_pullback(ΔS::AbstractZero) = (NoTangent(), ΔS) return S, svdvals_pullback end @@ -324,7 +324,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, $(Symbol(func, :_pullback)) end @@ -351,7 +351,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo) Y = (sinA, cosA) function sincos_pullback((ΔsinA, ΔcosA)::Tangent) - ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA + ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NoTangent(), ΔsinA + ΔcosA if eltype(A) <: Real ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end @@ -369,7 +369,7 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U')) end ∂A = _hermitrizelike!(Ā, A) - return NO_FIELDS, ∂A + return NoTangent(), ∂A end return Y, sincos_pullback end diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index c68741987..20fa10971 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -2,7 +2,7 @@ frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), Zero function rrule(::Type{MersenneTwister}, args...) function MersenneTwister_pullback(ΔΩ) - return (NO_FIELDS, map(_ -> ZeroTangent(), args)...) + return (NoTangent(), map(_ -> ZeroTangent(), args)...) end return MersenneTwister(args...), MersenneTwister_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 6b15b10fa..6f3f72037 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -14,7 +14,7 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) function mean_pullback(ȳ) _, ∂sum_x = sum_pullback(ȳ) ∂x = extern(∂sum_x) / n - return (NO_FIELDS, ∂x) + return (NoTangent(), ∂x) end return y_sum / n, mean_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 74a634238..91203fb86 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -165,12 +165,12 @@ end @testset "type" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing @test rrule(typejoin, Array{Float32,4}, Array{Float32,3}) !== nothing end @testset "Logging" begin - @test frule((NO_FIELDS, NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing + @test frule((NoTangent(), NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing @test rrule(Base.depwarn, "message", :f) !== nothing end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 805b0ead7..397f1622f 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -187,12 +187,12 @@ end @test X̄_vectors_ad ≈ j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1] rtol=1e-4 F̄ = CT(values = λ̄, vectors = V̄) s̄elf, X̄_ad = @inferred back(F̄) - @test s̄elf === NO_FIELDS + @test s̄elf === NoTangent() X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd rtol=1e-4 - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) F̄zero = CT(values = ZeroTangent(), vectors = ZeroTangent()) - @test @inferred(back(F̄zero)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(F̄zero)) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin X = randn(T, n, n) @@ -286,7 +286,7 @@ end end ∂self, ∂A = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] @test ∂A ≈ ∂A_fd @@ -318,7 +318,7 @@ end λ, back = rrule(eigvals, randn(T, n, n)) _, X̄ = @inferred back(rand_tangent(λ)) - @test @inferred(back(ZeroTangent())) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) === (NoTangent(), ZeroTangent()) T <: Real && @testset "cotangent is real when input is" begin @test eltype(X̄) <: Real @@ -343,10 +343,10 @@ end λ_ad, back = rrule(eigvals, A) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂A = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -375,7 +375,7 @@ end Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS + @test dself === NoTangent() @test dp === NoTangent() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index ebd7b0908..4420b34c7 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -159,7 +159,7 @@ test_rrule(norm, randn(T), p) _, back = rrule(norm, randn(T), p) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end @testset "p = 0" begin p = 0.0 @@ -172,8 +172,8 @@ @test iszero(ẏ) y_rev, back = rrule(norm, x, p) @test y_rev == y - @test back(ȳ) == (NO_FIELDS, zero(x), ZeroTangent()) - @test back(ZeroTangent()) == (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test back(ȳ) == (NoTangent(), zero(x), ZeroTangent()) + @test back(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent()) end end end @@ -185,12 +185,12 @@ end @testset "x::Vector{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) - @test rrule(normalize, x)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent()) + @test rrule(normalize, x)[2](ZeroTangent()) === (NoTangent(), ZeroTangent()) end @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) - @test rrule(normalize, x, p)[2](ZeroTangent()) === (NO_FIELDS, ZeroTangent(), ZeroTangent()) + @test rrule(normalize, x, p)[2](ZeroTangent()) === (NoTangent(), ZeroTangent(), ZeroTangent()) end end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index cb92b9119..fb4335975 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -23,9 +23,9 @@ # TODO: replace this with a `rrule_test` once we have that working # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 res, pb = rrule(Diagonal, [1, 4]) - @test pb(10*res) == (NO_FIELDS, [10, 40]) + @test pb(10*res) == (NoTangent(), [10, 40]) comp = Tangent{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal - @test pb(comp) == (NO_FIELDS, [10, 40]) + @test pb(comp) == (NoTangent(), [10, 40]) end @testset "dot(x, ::Diagonal, y)" begin N = 4 @@ -57,7 +57,7 @@ y, back = rrule(diagm, ps...) @test y == diagm(ps...) ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) ∂px = unthunk(∂px) @@ -76,7 +76,7 @@ y, back = rrule(diagm, M, N, ps...) @test y == diagm(M, N, ps...) ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂M === NoTangent() @test ∂N === NoTangent() ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> diagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 1b5a006ba..27e13c3ee 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -133,7 +133,7 @@ end ∂self, ∂symA = @inferred back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -141,8 +141,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end # when value used to determine phase convention is low, the usual derivatives @@ -207,7 +207,7 @@ λ_ad, back = @inferred rrule(eigvals, symA) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂symA = @inferred back(Δλ) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -215,7 +215,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end end @@ -264,7 +264,7 @@ VERSION ≥ v"1.6.0-DEV.1686" && @inferred back(∂F) ∂self, ∂symA = back(∂F) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @@ -272,8 +272,8 @@ @test ∂A ≈ ∂A_fd end - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) - @test @inferred(back(CT())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) + @test @inferred(back(CT())) == (NoTangent(), ZeroTangent()) end @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), @@ -287,7 +287,7 @@ S_ad, back = @inferred rrule(svdvals, symA) @test S_ad ≈ S # inexact because rrule uses svd not svdvals ∂self, ∂symA = @inferred back(ΔS) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo @@ -295,7 +295,7 @@ ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] - @test @inferred(back(ZeroTangent())) == (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end @@ -412,7 +412,7 @@ @test typeof(Y_ad) === typeof(Y) hasproperty(Y, :uplo) && @test Y_ad.uplo == Y.uplo ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A isa typeof(A) @test ∂A.uplo == A.uplo # check pullback composes correctly @@ -496,11 +496,11 @@ ΔY = Tangent{typeof(Y)}(ΔsinA, ΔcosA) ∂self, ∂A = @inferred back(ΔY) - @test ∂self === NO_FIELDS + @test ∂self === NoTangent() @test ∂A ≈ rrule(sin, A)[2](ΔsinA)[2] + rrule(cos, A)[2](ΔcosA)[2] ΔY2 = Tangent{typeof(Y)}(ZeroTangent(), ZeroTangent()) - @test @inferred(back(ΔY2)) === (NO_FIELDS, ZeroTangent()) + @test @inferred(back(ΔY2)) === (NoTangent(), ZeroTangent()) ΔY3 = Tangent{typeof(Y)}(ΔsinA, ZeroTangent()) @test @inferred(back(ΔY3)) == rrule(sin, A)[2](ΔsinA) diff --git a/test/rulesets/Random/random.jl b/test/rulesets/Random/random.jl index 66aee3631..bde25c7b3 100644 --- a/test/rulesets/Random/random.jl +++ b/test/rulesets/Random/random.jl @@ -14,7 +14,7 @@ Random.rand(d::NormalDistribution) = d.μ + d.σ*randn() rng, pb = rrule(MersenneTwister) @test rng isa MersenneTwister - @test first(pb(10)) isa typeof(NO_FIELDS) + @test first(pb(10)) isa typeof(NoTangent()) end @testset "unary" begin rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123)