Skip to content

Commit 4fa1add

Browse files
committed
implement in-place tuple derivative function
1 parent d01378a commit 4fa1add

File tree

2 files changed

+71
-25
lines changed

2 files changed

+71
-25
lines changed

src/config.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22
# Tag #
33
#######
44

5-
struct Tag{F,M} end
5+
struct Tag{F,H} end
6+
7+
# Here, we could've just as easily used `hash`; however, this
8+
# is unsafe/undefined behavior if `hash(::Type{V})` is overloaded
9+
# in a module loaded after ForwardDiff. Thus, we instead use
10+
# `hash(Symbol(V))`, which is somewhat safer since it's far less
11+
# likely that somebody would overwrite the Base definition for
12+
# `Symbol(::DataType)` or `hash(::Symbol)`.
13+
@generated function Tag(::Type{F}, ::Type{V}) where {F,V}
14+
H = hash(Symbol(V))
15+
return quote
16+
$(Expr(:meta, :inline))
17+
Tag{F,$H}()
18+
end
19+
end
620

721
#########
822
# Chunk #
@@ -37,9 +51,9 @@ end
3751

3852
abstract type AbstractConfig{T<:Tag,N} end
3953

40-
struct ConfigMismatchError{F,G,M} <: Exception
54+
struct ConfigMismatchError{F,G,H} <: Exception
4155
f::F
42-
cfg::AbstractConfig{Tag{G,M}}
56+
cfg::AbstractConfig{Tag{G,H}}
4357
end
4458

4559
function Base.showerror{F,G}(io::IO, e::ConfigMismatchError{F,G})
@@ -67,7 +81,7 @@ end
6781
function GradientConfig{V,N,F,T}(::F,
6882
x::AbstractArray{V},
6983
::Chunk{N} = Chunk(x),
70-
::T = Tag{F,order(V)}())
84+
::T = Tag(F, V))
7185
seeds = construct_seeds(Partials{N,V})
7286
duals = similar(x, Dual{T,V,N})
7387
return GradientConfig{T,V,N,typeof(duals)}(seeds, duals)
@@ -85,7 +99,7 @@ end
8599
function JacobianConfig{V,N,F,T}(::F,
86100
x::AbstractArray{V},
87101
::Chunk{N} = Chunk(x),
88-
::T = Tag{F,order(V)}())
102+
::T = Tag(F, V))
89103
seeds = construct_seeds(Partials{N,V})
90104
duals = similar(x, Dual{T,V,N})
91105
return JacobianConfig{T,V,N,typeof(duals)}(seeds, duals)
@@ -95,7 +109,7 @@ function JacobianConfig{Y,X,N,F,T}(::F,
95109
y::AbstractArray{Y},
96110
x::AbstractArray{X},
97111
::Chunk{N} = Chunk(x),
98-
::T = Tag{F,order(X)}())
112+
::T = Tag(F, X))
99113
seeds = construct_seeds(Partials{N,X})
100114
yduals = similar(y, Dual{T,Y,N})
101115
xduals = similar(x, Dual{T,X,N})
@@ -107,15 +121,15 @@ end
107121
# HessianConfig #
108122
#################
109123

110-
struct HessianConfig{T,V,N,D,MJ,DJ} <: AbstractConfig{T,N}
111-
jacobian_config::JacobianConfig{Tag{Void,MJ},V,N,DJ}
112-
gradient_config::GradientConfig{T,Dual{Tag{Void,MJ},V,N},D}
124+
struct HessianConfig{T,V,N,D,H,DJ} <: AbstractConfig{T,N}
125+
jacobian_config::JacobianConfig{Tag{Void,H},V,N,DJ}
126+
gradient_config::GradientConfig{T,Dual{Tag{Void,H},V,N},D}
113127
end
114128

115129
function HessianConfig{F,V}(f::F,
116130
x::AbstractArray{V},
117131
chunk::Chunk = Chunk(x),
118-
tag::Tag = Tag{F,order(Dual{Void,V,0})}())
132+
tag::Tag = Tag(F, Dual{Void,V,0}))
119133
jacobian_config = JacobianConfig(nothing, x, chunk)
120134
gradient_config = GradientConfig(f, jacobian_config.duals, chunk, tag)
121135
return HessianConfig(jacobian_config, gradient_config)
@@ -125,7 +139,7 @@ function HessianConfig{F,V}(result::DiffResult,
125139
f::F,
126140
x::AbstractArray{V},
127141
chunk::Chunk = Chunk(x),
128-
tag::Tag = Tag{F,order(Dual{Void,V,0})}())
142+
tag::Tag = Tag(F, Dual{Void,V,0}))
129143
jacobian_config = JacobianConfig(nothing, DiffBase.gradient(result), x, chunk)
130144
gradient_config = GradientConfig(f, jacobian_config.duals[2], chunk, tag)
131145
return HessianConfig(jacobian_config, gradient_config)

src/derivative.jl

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,43 @@
22
# API methods #
33
###############
44

5-
@generated function derivative{F,R<:Real}(f::F, x::R)
6-
T = Tag{F,order(R)}
7-
return quote
8-
$(Expr(:meta, :inline))
9-
return extract_derivative(f(Dual{$T}(x, one(x))))
10-
end
5+
@inline function derivative(f::F, x::R) where {F,R<:Real}
6+
T = Tag(F, R)
7+
return extract_derivative(f(Dual{T}(x, one(x))))
118
end
129

13-
@generated function derivative{F,N}(f::F, x::NTuple{N,Real})
14-
T = Tag{F,maximum(order(R) for R in x.parameters)}
15-
args = [:(Dual{$T}(x[$i], Val{N}, Val{$i})) for i in 1:N]
10+
@generated function derivative(f::F, x::NTuple{N,Real}) where {F,N}
11+
args = [:(Dual{T}(x[$i], Val{N}, Val{$i})) for i in 1:N]
1612
return quote
1713
$(Expr(:meta, :inline))
14+
T = Tag(F, typeof(x))
1815
extract_derivative(f($(args...)))
1916
end
2017
end
2118

22-
@generated function derivative!{F,R<:Real}(out, f::F, x::R)
23-
T = Tag{F,order(R)}
19+
@inline function derivative!(out, f::F, x::R) where {F,R<:Real}
20+
T = Tag(F, typeof(x))
21+
extract_derivative!(out, f(Dual{T}(x, one(x))))
22+
return out
23+
end
24+
25+
@generated function derivative!(out::NTuple{N,Any}, f::F, x::NTuple{N,Real}) where {F,N}
26+
args = [:(Dual{T}(x[$i], Val{N}, Val{$i})) for i in 1:N]
2427
return quote
2528
$(Expr(:meta, :inline))
26-
extract_derivative!(out, f(Dual{$T}(x, one(x))))
27-
return out
29+
T = Tag(F, typeof(x))
30+
extract_derivative!(out, f($(args...)))
2831
end
2932
end
3033

3134
#####################
3235
# result extraction #
3336
#####################
3437

35-
@generated function extract_derivative{T,V,N}(y::Dual{T,V,N})
38+
# non-mutating #
39+
#--------------#
40+
41+
@generated function extract_derivative(y::Dual{T,V,N}) where {T,V,N}
3642
return quote
3743
$(Expr(:meta, :inline))
3844
$(Expr(:tuple, [:(partials(y, $i)) for i in 1:N]...))
@@ -43,10 +49,36 @@ end
4349
@inline extract_derivative(y::Real) = zero(y)
4450
@inline extract_derivative(y::AbstractArray) = extract_derivative!(similar(y, valtype(eltype(y))), y)
4551

52+
# mutating #
53+
#----------#
54+
55+
@generated function extract_derivative!(out::NTuple{N,Any}, y::Dual{T,V,N}) where {T,V,N}
56+
return quote
57+
$(Expr(:meta, :inline))
58+
$(Expr(:block, [:(out[$i][] = partials(y, $i)) for i in 1:N]...))
59+
return out
60+
end
61+
end
62+
63+
@generated function extract_derivative!(out::NTuple{N,Any}, y::AbstractArray) where {N}
64+
return quote
65+
$(Expr(:meta, :inline))
66+
$(Expr(:block, [:(extract_derivative!(out[$i], y, $i)) for i in 1:N]...))
67+
return out
68+
end
69+
end
70+
4671
extract_derivative!(out::AbstractArray, y::AbstractArray) = map!(extract_derivative, out, y)
72+
extract_derivative!(out::AbstractArray, y::AbstractArray, p) = map!(x -> partials(x, p), out, y)
4773

4874
function extract_derivative!(out::DiffResult, y)
4975
DiffBase.value!(value, out, y)
5076
DiffBase.derivative!(extract_derivative, out, y)
5177
return out
5278
end
79+
80+
function extract_derivative!(out::DiffResult, y::AbstractArray, p)
81+
DiffBase.value!(value, out, y)
82+
DiffBase.derivative!(x -> partials(x, p), out, y)
83+
return out
84+
end

0 commit comments

Comments
 (0)