diff --git a/src/compute_jacobian_ad.jl b/src/compute_jacobian_ad.jl index fbfee2c5..fde48049 100644 --- a/src/compute_jacobian_ad.jl +++ b/src/compute_jacobian_ad.jl @@ -1,4 +1,4 @@ -using ForwardDiff: Dual, jacobian, partials +using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD struct ForwardColorJacCache{T,T2,T3,T4,T5} t::T @@ -8,23 +8,50 @@ struct ForwardColorJacCache{T,T2,T3,T4,T5} color::T5 end -function ForwardColorJacCache(f,x; +function default_chunk_size(maxcolor) + if maxcolor < DEFAULT_CHUNK_THRESHOLD + Val(maxcolor) + else + Val(DEFAULT_CHUNK_THRESHOLD) + end +end + +getsize(::Val{N}) where N = N +getsize(N::Integer) = N + +function ForwardColorJacCache(f,x,_chunksize = nothing; dx = nothing, color=1:length(x)) - t = zeros(Dual{typeof(f), eltype(x), maximum(color)},length(x)) + if _chunksize === nothing + chunksize = default_chunk_size(maximum(color)) + else + chunksize = _chunksize + end + + t = zeros(Dual{typeof(f), eltype(x), getsize(chunksize)},length(x)) if dx === nothing fx = similar(t) _dx = similar(x) else - fx = zeros(Dual{typeof(f), eltype(dx), maximum(color)},length(dx)) + fx = zeros(Dual{typeof(f), eltype(dx), getsize(chunksize)},length(dx)) _dx = dx end - partials_array = Array{Bool}(undef, length(x), maximum(color)) + p = generate_chunked_partials(x,color,chunksize) + ForwardColorJacCache(t,fx,_dx,p,color) +end + +generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N)) +function generate_chunked_partials(x,color,::Val{N}) where N + + # TODO: should only go up to the chunksize each time, and should + # generate p[i] different parts, each with less than the chunksize + + partials_array = BitMatrix(undef, length(x), maximum(color)) for color_i in 1:maximum(color) - for i in 1:length(x) + for i in eachindex(x) if color[i]==color_i partials_array[i,color_i] = true else @@ -33,22 +60,20 @@ function ForwardColorJacCache(f,x; end end p = Tuple.(eachrow(partials_array)) - - ForwardColorJacCache(t,fx,_dx,p,color) end function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}; dx = nothing, - color) + color = eachindex(x)) forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color)) end function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, - jac_cache = ForwardColorJacCache(f,x)) + jac_cache::ForwardColorJacCache) t = jac_cache.t fx = jac_cache.fx @@ -56,6 +81,7 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, p = jac_cache.p color = jac_cache.color + # TODO: Should compute on each p[i] and decompress t .= Dual{typeof(f)}.(x, p) f(fx, t)