From 57fbfae4dcdd8bb6f80ed5c3d1a006a278dd1357 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 31 May 2024 11:06:57 -0400 Subject: [PATCH 01/33] Fresh attempt at DI integration --- Project.toml | 23 +- ext/OptimizationDIExt.jl | 372 +++++++++++++ ext/OptimizationFiniteDiffExt.jl | 470 ----------------- ext/OptimizationForwardDiffExt.jl | 341 ------------ ext/OptimizationMTKExt.jl | 10 +- ext/OptimizationReverseDiffExt.jl | 581 --------------------- ext/OptimizationSparseDiffExt.jl | 31 -- ext/OptimizationSparseFiniteDiff.jl | 541 ------------------- ext/OptimizationSparseForwardDiff.jl | 459 ---------------- ext/OptimizationSparseReverseDiff.jl | 751 --------------------------- ext/OptimizationTrackerExt.jl | 72 --- ext/OptimizationZygoteExt.jl | 2 +- 12 files changed, 382 insertions(+), 3271 deletions(-) create mode 100644 ext/OptimizationDIExt.jl delete mode 100644 ext/OptimizationFiniteDiffExt.jl delete mode 100644 ext/OptimizationForwardDiffExt.jl delete mode 100644 ext/OptimizationReverseDiffExt.jl delete mode 100644 ext/OptimizationSparseDiffExt.jl delete mode 100644 ext/OptimizationSparseFiniteDiff.jl delete mode 100644 ext/OptimizationSparseForwardDiff.jl delete mode 100644 ext/OptimizationSparseReverseDiff.jl delete mode 100644 ext/OptimizationTrackerExt.jl diff --git a/Project.toml b/Project.toml index 57851d6..0dbcfb2 100644 --- a/Project.toml +++ b/Project.toml @@ -17,41 +17,26 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" [weakdeps] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -OptimizationEnzymeExt = "Enzyme" -OptimizationFiniteDiffExt = "FiniteDiff" -OptimizationForwardDiffExt = "ForwardDiff" +OptimizationDIExt = "DifferentiationInterface" OptimizationMTKExt = "ModelingToolkit" -OptimizationReverseDiffExt = "ReverseDiff" -OptimizationSparseDiffExt = ["SparseDiffTools", "ReverseDiff"] -OptimizationTrackerExt = "Tracker" -OptimizationZygoteExt = "Zygote" +OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"] [compat] ADTypes = "1.3" ArrayInterface = "7.6" DocStringExtensions = "0.9" -Enzyme = "0.12.12" -FiniteDiff = "2.12" -ForwardDiff = "0.10.26" LinearAlgebra = "1.9, 1.10" -Manifolds = "0.9" ModelingToolkit = "9" -PDMats = "0.11" Reexport = "1.2" Requires = "1" -ReverseDiff = "1.14" SciMLBase = "2" SparseDiffTools = "2.14" SymbolicAnalysis = "0.1, 0.2" diff --git a/ext/OptimizationDIExt.jl b/ext/OptimizationDIExt.jl new file mode 100644 index 0000000..186327d --- /dev/null +++ b/ext/OptimizationDIExt.jl @@ -0,0 +1,372 @@ +module OptimizationDIExt + +import OptimizationBase, OptimizationBase.ArrayInterface +import OptimizationBase.SciMLBase: OptimizationFunction +import OptimizationBase.LinearAlgebra: I +import DifferentiationInterface +import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, + gradient!, hessian!, hvp!, 
jacobian!, gradient, hessian, hvp, jacobian +using ADTypes + +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(res, θ) + gradient!(_f, res, adtype, θ, extras_grad) + end + else + grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(res, θ, args...) + hessian!(_f, res, adtype, θ, extras_hess) + end + else + hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (H, θ, v, args...) + if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp!(_f, H, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) + x = cache.u0 + p = cache.p + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(res, θ) + gradient!(_f, res, adtype, θ, extras_grad) + end + else + grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) + function hess(res, θ, args...) + hessian!(_f, res, soadtype, θ, extras_hess) + end + else + hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (H, θ, v, args...) 
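+ # prepare the hvp extras lazily on the first call (prepare_hvp is passed the seed v here)
+ # and reuse them on subsequent calls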
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp!(_f, H, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + + +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(θ) + gradient(_f, adtype, θ, extras_grad) + end + else + grad = (θ, args...) -> f.grad(θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(θ, args...) + hessian(_f, adtype, θ, extras_hess) + end + else + hess = (θ, args...) -> f.hess(θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (θ, v, args...) 
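+ # out-of-place variant: hvp returns the product H*v rather than writing into a buffer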
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp(_f, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (θ) -> f.cons(θ, p) + cons_oop = cons + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (θ) + J = jacobian(cons_oop, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + return J + end + else + cons_j = (θ) -> f.cons_j(θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(θ) + H = map(1:num_cons) do i + hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + end + return H + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) + x = cache.u0 + p = cache.p + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(θ) + gradient(_f, adtype, θ, extras_grad) + end + else + grad = (θ, args...) -> f.grad(θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(θ, args...) + hessian(_f, soadtype, θ, extras_hess) + end + else + hess = (θ, args...) -> f.hess(θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (θ, v, args...) 
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp(_f, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (θ) -> f.cons(θ, p) + cons_oop = cons + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (θ) + J = jacobian(cons_oop, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + return J + end + else + cons_j = (θ) -> f.cons_j(θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(θ) + H = map(1:num_cons) do i + hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + end + return H + end + else + cons_h = (θ) -> f.cons_h(θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +end \ No newline at end of file diff --git a/ext/OptimizationFiniteDiffExt.jl b/ext/OptimizationFiniteDiffExt.jl deleted file mode 100644 index 641f99c..0000000 --- a/ext/OptimizationFiniteDiffExt.jl +++ /dev/null @@ -1,470 +0,0 @@ -module OptimizationFiniteDiffExt - -import OptimizationBase, OptimizationBase.ArrayInterface -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoFiniteDiff -using OptimizationBase.LinearAlgebra -isdefined(Base, :get_extension) ? (using FiniteDiff) : (using ..FiniteDiff) - -const FD = FiniteDiff - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoFiniteDiff, p, - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - updatecache = (cache, x) -> (cache.xmm .= x; cache.xmp .= x; cache.xpm .= x; cache.xpp .= x; return cache) - - if f.grad === nothing - gradcache = FD.GradientCache(x, x, adtype.fdtype) - grad = (res, θ, args...) -> FD.finite_difference_gradient!( - res, x -> _f(x, args...), - θ, gradcache) - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - hesscache = FD.HessianCache(x, adtype.fdhtype) - hess = (res, θ, args...) -> FD.finite_difference_hessian!(res, - x -> _f(x, args...), θ, - updatecache(hesscache, θ)) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - T = eltype(θ) - ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(θ))) - @. θ += ϵ * v - cache2 = similar(θ) - grad(cache2, θ, args...) - @. θ -= 2ϵ * v - cache3 = similar(θ) - grad(cache3, θ, args...) - @. θ += ϵ * v - @. H = (cache2 - cache3) / (2ϵ) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - end - - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? 
(1:length(x)) : - f.cons_jac_colorvec - - if cons !== nothing && f.cons_j === nothing - cons_j = function (J, θ) - y0 = zeros(num_cons) - jaccache = FD.JacobianCache(copy(x), copy(y0), copy(y0), adtype.fdjtype; - colorvec = cons_jac_colorvec, - sparsity = f.cons_jac_prototype) - FD.finite_difference_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - if cons !== nothing && f.cons_h === nothing - hess_cons_cache = [FD.HessianCache(copy(x), adtype.fdhtype) - for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons#note: colorvecs not yet supported by FiniteDiff for Hessians - FD.finite_difference_hessian!(res[i], - (x) -> (_res = zeros(eltype(θ), num_cons); - cons(_res, x); - _res[i]), θ, - updatecache(hess_cons_cache[i], θ)) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_hess_cache = FD.HessianCache(copy(x), adtype.fdhtype) - c = zeros(num_cons) - h = zeros(length(x), length(x)) - lag_h = let c = c, h = h - lag = function (θ, σ, μ) - f.cons(c, θ, p) - l = μ'c - if !iszero(σ) - l += σ * f.f(θ, p) - end - l - end - function (res, θ, σ, μ) - FD.finite_difference_hessian!(res, - (x) -> lag(x, σ, μ), - θ, - updatecache(lag_hess_cache, θ)) - end - end - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - cons_jac_colorvec = cons_jac_colorvec, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoFiniteDiff, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - updatecache = (cache, x) -> (cache.xmm .= x; cache.xmp .= x; cache.xpm .= x; cache.xpp .= x; return cache) - - if f.grad === nothing - gradcache = FD.GradientCache(cache.u0, cache.u0, adtype.fdtype) - grad = (res, θ, args...) -> FD.finite_difference_gradient!( - res, x -> _f(x, args...), - θ, gradcache) - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - if f.hess === nothing - hesscache = FD.HessianCache(cache.u0, adtype.fdhtype) - hess = (res, θ, args...) -> FD.finite_difference_hessian!(res, x -> _f(x, args...), - θ, - updatecache(hesscache, θ)) - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - T = eltype(θ) - ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(θ))) - @. θ += ϵ * v - cache2 = similar(θ) - grad(cache2, θ, args...) - @. θ -= 2ϵ * v - cache3 = similar(θ) - grad(cache3, θ, args...) - @. θ += ϵ * v - @. H = (cache2 - cache3) / (2ϵ) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, cache.p) - end - - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? 
(1:length(cache.u0)) : - f.cons_jac_colorvec - - if cons !== nothing && f.cons_j === nothing - cons_j = function (J, θ) - y0 = zeros(num_cons) - jaccache = FD.JacobianCache(copy(cache.u0), copy(y0), copy(y0), - adtype.fdjtype; - colorvec = cons_jac_colorvec, - sparsity = f.cons_jac_prototype) - FD.finite_difference_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) - end - - if cons !== nothing && f.cons_h === nothing - hess_cons_cache = [FD.HessianCache(copy(cache.u0), adtype.fdhtype) - for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons#note: colorvecs not yet supported by FiniteDiff for Hessians - FD.finite_difference_hessian!(res[i], - (x) -> (_res = zeros(eltype(θ), num_cons); - cons(_res, - x); - _res[i]), - θ, updatecache(hess_cons_cache[i], θ)) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) - end - if f.lag_h === nothing - lag_hess_cache = FD.HessianCache(copy(cache.u0), adtype.fdhtype) - c = zeros(num_cons) - h = zeros(length(cache.u0), length(cache.u0)) - lag_h = let c = c, h = h - lag = function (θ, σ, μ) - f.cons(c, θ, cache.p) - l = μ'c - if !iszero(σ) - l += σ * f.f(θ, cache.p) - end - l - end - function (res, θ, σ, μ) - FD.finite_difference_hessian!(h, - (x) -> lag(x, σ, μ), - θ, - updatecache(lag_hess_cache, θ)) - k = 1 - for i in 1:length(cache.u0), j in i:length(cache.u0) - res[k] = h[i, j] - k += 1 - end - end - end - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - cons_jac_colorvec = cons_jac_colorvec, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoFiniteDiff, p, - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - updatecache = (cache, x) -> (cache.xmm .= x; cache.xmp .= x; cache.xpm .= x; cache.xpp .= x; return cache) - - if f.grad === nothing - gradcache = FD.GradientCache(x, x, adtype.fdtype) - grad = (θ, args...) -> FD.finite_difference_gradient(x -> _f(x, args...), - θ, gradcache) - else - grad = (θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - hesscache = FD.HessianCache(x, adtype.fdhtype) - hess = (θ, args...) -> FD.finite_difference_hessian(x -> _f(x, args...), θ, - updatecache(hesscache, θ)) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - T = eltype(θ) - ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(θ))) - @. θ += ϵ * v - cache2 = similar(θ) - grad(cache2, θ, args...) - @. θ -= 2ϵ * v - cache3 = similar(θ) - grad(cache3, θ, args...) - @. θ += ϵ * v - return @. (cache2 - cache3) / (2ϵ) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? 
(1:length(x)) : - f.cons_jac_colorvec - - if cons !== nothing && f.cons_j === nothing - cons_j = function (θ) - y0 = zeros(eltype(θ), num_cons) - jaccache = FD.JacobianCache(copy(x), copy(y0), copy(y0), adtype.fdjtype; - colorvec = cons_jac_colorvec, - sparsity = f.cons_jac_prototype) - if num_cons > 1 - return FD.finite_difference_jacobian(cons, θ, jaccache) - else - return FD.finite_difference_jacobian(cons, θ, jaccache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - hess_cons_cache = [FD.HessianCache(copy(x), adtype.fdhtype) - for i in 1:num_cons] - cons_h = function (θ) - return map(1:num_cons) do i - FD.finite_difference_hessian(x -> cons(x)[i], θ, - updatecache(hess_cons_cache[i], θ)) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_hess_cache = FD.HessianCache(copy(x), adtype.fdhtype) - c = zeros(num_cons) - h = zeros(length(x), length(x)) - lag_h = let c = c, h = h - lag = function (θ, σ, μ) - f.cons(c, θ, p) - l = μ'c - if !iszero(σ) - l += σ * f.f(θ, p) - end - l - end - function (θ, σ, μ) - FD.finite_difference_hessian((x) -> lag(x, σ, μ), - θ, - updatecache(lag_hess_cache, θ)) - end - end - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - cons_jac_colorvec = cons_jac_colorvec, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoFiniteDiff, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - updatecache = (cache, x) -> (cache.xmm .= x; cache.xmp .= x; cache.xpm .= x; cache.xpp .= x; return cache) - p = cache.p - - if f.grad === nothing - gradcache = FD.GradientCache(x, x, adtype.fdtype) - grad = (θ, args...) -> FD.finite_difference_gradient(x -> _f(x, args...), - θ, gradcache) - else - grad = (θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - hesscache = FD.HessianCache(x, adtype.fdhtype) - hess = (θ, args...) -> FD.finite_difference_hessian!(x -> _f(x, args...), θ, - updatecache(hesscache, θ)) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - T = eltype(θ) - ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(θ))) - @. θ += ϵ * v - cache2 = similar(θ) - grad(cache2, θ, args...) - @. θ -= 2ϵ * v - cache3 = similar(θ) - grad(cache3, θ, args...) - @. θ += ϵ * v - return @. (cache2 - cache3) / (2ϵ) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? 
(1:length(x)) : - f.cons_jac_colorvec - - if cons !== nothing && f.cons_j === nothing - cons_j = function (θ) - y0 = zeros(num_cons) - jaccache = FD.JacobianCache(copy(x), copy(y0), copy(y0), adtype.fdjtype; - colorvec = cons_jac_colorvec, - sparsity = f.cons_jac_prototype) - if num_cons > 1 - return FD.finite_difference_jacobian(cons, θ, jaccache) - else - return FD.finite_difference_jacobian(cons, θ, jaccache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - hess_cons_cache = [FD.HessianCache(copy(x), adtype.fdhtype) - for i in 1:num_cons] - cons_h = function (θ) - return map(1:num_cons) do i - FD.finite_difference_hessian(x -> cons(x)[i], θ, - updatecache(hess_cons_cache[i], θ)) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_hess_cache = FD.HessianCache(copy(x), adtype.fdhtype) - c = zeros(num_cons) - h = zeros(length(x), length(x)) - lag_h = let c = c, h = h - lag = function (θ, σ, μ) - f.cons(c, θ, p) - l = μ'c - if !iszero(σ) - l += σ * f.f(θ, p) - end - l - end - function (θ, σ, μ) - FD.finite_difference_hessian((x) -> lag(x, σ, μ), - θ, - updatecache(lag_hess_cache, θ)) - end - end - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - cons_jac_colorvec = cons_jac_colorvec, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -end diff --git a/ext/OptimizationForwardDiffExt.jl b/ext/OptimizationForwardDiffExt.jl deleted file mode 100644 index f2732c4..0000000 --- a/ext/OptimizationForwardDiffExt.jl +++ /dev/null @@ -1,341 +0,0 @@ -module OptimizationForwardDiffExt - -import OptimizationBase, OptimizationBase.ArrayInterface -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoForwardDiff -isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) - -function default_chunk_size(len) - if len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD - len - else - ForwardDiff.DEFAULT_CHUNK_THRESHOLD - end -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoForwardDiff{_chunksize}, p, - num_cons = 0) where {_chunksize} - chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ, - hesscfg, Val{false}()) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - res = ArrayInterface.zeromatrix(θ) - hess(res, θ, args...) 
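- # Hessian-vector product taken by materializing the dense Hessian and multiplying by v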
- H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - if cons !== nothing && f.cons_j === nothing - cjconfig = ForwardDiff.JacobianConfig(cons_oop, x, ForwardDiff.Chunk{chunksize}()) - cons_j = function (J, θ) - ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x, - ForwardDiff.Chunk{chunksize}()) - for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - ForwardDiff.hessian!(res[i], fncs[i], θ, hess_config_cache[i], Val{true}()) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h, f.lag_hess_prototype) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoForwardDiff{_chunksize}, - num_cons = 0) where {_chunksize} - chunksize = _chunksize === nothing ? default_chunk_size(length(cache.u0)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, cache.u0, ForwardDiff.Chunk{chunksize}()) - grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - if f.hess === nothing - hesscfg = ForwardDiff.HessianConfig(_f, cache.u0, ForwardDiff.Chunk{chunksize}()) - hess = (res, θ, args...) -> (ForwardDiff.hessian!(res, x -> _f(x, args...), θ, - hesscfg, Val{false}())) - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - res = ArrayInterface.zeromatrix(θ) - hess(res, θ, args...) 
- H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, cache.p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - if cons !== nothing && f.cons_j === nothing - cjconfig = ForwardDiff.JacobianConfig(cons_oop, cache.u0, - ForwardDiff.Chunk{chunksize}()) - cons_j = function (J, θ) - ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], cache.u0, - ForwardDiff.Chunk{chunksize}()) - for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - ForwardDiff.hessian!(res[i], fncs[i], θ, hess_config_cache[i], Val{true}()) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoForwardDiff{_chunksize}, p, - num_cons = 0) where {_chunksize} - chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (θ, args...) -> ForwardDiff.gradient(x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - hess = (θ, args...) -> ForwardDiff.hessian(x -> _f(x, args...), θ, - hesscfg, Val{false}()) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - res = ArrayInterface.zeromatrix(θ) - hess(res, θ, args...) 
- return res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - cjconfig = ForwardDiff.JacobianConfig(cons_oop, x, ForwardDiff.Chunk{chunksize}()) - cons_j = function (θ) - if num_cons > 1 - return ForwardDiff.jacobian(cons_oop, θ, cjconfig) - else - return ForwardDiff.jacobian(cons_oop, θ, cjconfig)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x, - ForwardDiff.Chunk{chunksize}()) - for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - ForwardDiff.hessian(fncs[i], θ, hess_config_cache[i], Val{true}()) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h, f.lag_hess_prototype) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoForwardDiff{_chunksize}, - num_cons = 0) where {_chunksize} - chunksize = _chunksize === nothing ? default_chunk_size(length(cache.u0)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - p = cache.p - x = cache.u0 - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (θ, args...) -> ForwardDiff.gradient(x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - hess = (θ, args...) -> ForwardDiff.hessian(x -> _f(x, args...), θ, - hesscfg, Val{false}()) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - res = ArrayInterface.zeromatrix(θ) - hess(res, θ, args...) 
- return res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - cjconfig = ForwardDiff.JacobianConfig(cons_oop, x, ForwardDiff.Chunk{chunksize}()) - cons_j = function (θ) - if num_cons > 1 - return ForwardDiff.jacobian(cons_oop, θ, cjconfig) - else - return ForwardDiff.jacobian(cons_oop, θ, cjconfig)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x, - ForwardDiff.Chunk{chunksize}()) - for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - ForwardDiff.hessian(fncs[i], θ, hess_config_cache[i], Val{true}()) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -end diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 07ead62..1b84f72 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -7,7 +7,7 @@ import OptimizationBase.ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse isdefined(Base, :get_extension) ? (using ModelingToolkit) : (using ..ModelingToolkit) function OptimizationBase.instantiate_function( - f, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p, + f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p, num_cons = 0) where {S, C} p = isnothing(p) ? SciMLBase.NullParameters() : p @@ -52,7 +52,7 @@ function OptimizationBase.instantiate_function( observed = f.observed) end -function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache, +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::AutoSparse{<:AutoSymbolics, S, C}, num_cons = 0) where {S, C} p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p @@ -98,7 +98,7 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit observed = f.observed) end -function OptimizationBase.instantiate_function(f, x, adtype::AutoSymbolics, p, +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p, num_cons = 0) p = isnothing(p) ? SciMLBase.NullParameters() : p @@ -143,7 +143,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoSymbolics, p, observed = f.observed) end -function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache, +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::AutoSymbolics, num_cons = 0) p = isnothing(cache.p) ? 
SciMLBase.NullParameters() : cache.p @@ -189,4 +189,4 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit observed = f.observed) end -end +end \ No newline at end of file diff --git a/ext/OptimizationReverseDiffExt.jl b/ext/OptimizationReverseDiffExt.jl deleted file mode 100644 index 58f1bf3..0000000 --- a/ext/OptimizationReverseDiffExt.jl +++ /dev/null @@ -1,581 +0,0 @@ -module OptimizationReverseDiffExt - -import OptimizationBase -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoReverseDiff -# using SparseDiffTools, Symbolics -isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) : -(using ..ReverseDiff, ..ReverseDiff.ForwardDiff) - -struct OptimizationReverseDiffTag end - -function default_chunk_size(len) - if len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD - len - else - ForwardDiff.DEFAULT_CHUNK_THRESHOLD - end -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoReverseDiff, - p = SciMLBase.NullParameters(), - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - chunksize = default_chunk_size(length(x)) - - if f.grad === nothing - if adtype.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (res, θ, args...) - ReverseDiff.gradient!(res, tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, - x -> _f(x, args...), - θ, - cfg) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - if adtype.compile - T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - chunksize - }.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - res1 = zeros(eltype(θ), length(θ)) - ReverseDiff.gradient!(res1, htape, θ) - end - jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{chunksize}(), T) - hess = function (res, θ, args...) - ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}()) - end - else - hess = function (res, θ, args...) - ReverseDiff.hessian!(res, x -> _f(x, args...), θ) - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(res, θ, args...) 
- H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - if cons !== nothing && f.cons_j === nothing - if adtype.compile - _jac_tape = ReverseDiff.JacobianTape(cons_oop, x) - jac_tape = ReverseDiff.compile(_jac_tape) - cons_j = function (J, θ) - ReverseDiff.jacobian!(J, jac_tape, θ) - end - else - cjconfig = ReverseDiff.JacobianConfig(x) - cons_j = function (J, θ) - ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig) - end - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - if adtype.compile - consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual)) - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - res1 = zeros(eltype(θ), length(θ)) - ReverseDiff.gradient!(res1, htape, θ) - end - gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardDiff.JacobianConfig(gs[i], - x, - ForwardDiff.Chunk{chunksize}(), - T) for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}()) - end - end - else - cons_h = function (res, θ) - for i in 1:num_cons - ReverseDiff.hessian!(res[i], fncs[i], θ) - end - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoReverseDiff, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - chunksize = default_chunk_size(length(cache.u0)) - - if f.grad === nothing - if adtype.compile - _tape = ReverseDiff.GradientTape(_f, cache.u0) - tape = ReverseDiff.compile(_tape) - grad = function (res, θ, args...) - ReverseDiff.gradient!(res, tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(cache.u0) - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, - x -> _f(x, args...), - θ, - cfg) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - if f.hess === nothing - if adtype.compile - T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(cache.u0)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(cache.u0), - chunksize - }.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), chunksize)...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - res1 = zeros(eltype(θ), length(θ)) - ReverseDiff.gradient!(res1, htape, θ) - end - jaccfg = ForwardDiff.JacobianConfig(g, - cache.u0, - ForwardDiff.Chunk{chunksize}(), - T) - hess = function (res, θ, args...) - ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}()) - end - else - hess = function (res, θ, args...) - ReverseDiff.hessian!(res, x -> _f(x, args...), θ) - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) 
- end - - if f.hv === nothing - hv = function (H, θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(res, θ, args...) - H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, cache.p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - if cons !== nothing && f.cons_j === nothing - if adtype.compile - _jac_tape = ReverseDiff.JacobianTape(cons_oop, cache.u0) - jac_tape = ReverseDiff.compile(_jac_tape) - cons_j = function (J, θ) - ReverseDiff.jacobian!(J, jac_tape, θ) - end - else - cjconfig = ReverseDiff.JacobianConfig(cache.u0) - cons_j = function (J, θ) - ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig) - end - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - if adtype.compile - consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual)) - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - res1 = zeros(eltype(θ), length(θ)) - ReverseDiff.gradient!(res1, htape, θ) - end - gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardDiff.JacobianConfig(gs[i], - cache.u0, - ForwardDiff.Chunk{chunksize}(), - T) for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}()) - end - end - else - cons_h = function (res, θ) - for i in 1:num_cons - ReverseDiff.hessian!(res[i], fncs[i], θ) - end - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoReverseDiff, - p = SciMLBase.NullParameters(), - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - chunksize = default_chunk_size(length(x)) - - if f.grad === nothing - if adtype.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (θ, args...) - ReverseDiff.gradient!(tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (θ, args...) -> ReverseDiff.gradient(x -> _f(x, args...), - θ, - cfg) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - if adtype.compile - T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - chunksize - }.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - ReverseDiff.gradient!(htape, θ) - end - jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{chunksize}(), T) - hess = function (θ, args...) - ForwardDiff.jacobian(g, θ, jaccfg, Val{false}()) - end - else - hess = function (θ, args...) 
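- # uncompiled path: dense Hessian of the objective via ReverseDiff.hessian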
- ReverseDiff.hessian(x -> _f(x, args...), θ) - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - return hess(θ, args...) * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - if adtype.compile - _jac_tape = ReverseDiff.JacobianTape(cons_oop, x) - jac_tape = ReverseDiff.compile(_jac_tape) - cons_j = function (θ) - if num_cons > 1 - ReverseDiff.jacobian!(jac_tape, θ) - else - ReverseDiff.jacobian!(jac_tape, θ)[1, :] - end - end - else - cjconfig = ReverseDiff.JacobianConfig(x) - cons_j = function (θ) - if num_cons > 1 - return ReverseDiff.jacobian(cons_oop, θ, cjconfig) - else - return ReverseDiff.jacobian(cons_oop, θ, cjconfig)[1, :] - end - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - if adtype.compile - consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual)) - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - ReverseDiff.gradient!(htape, θ) - end - gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardDiff.JacobianConfig(gs[i], - x, - ForwardDiff.Chunk{chunksize}(), - T) for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - ForwardDiff.jacobian(gs[i], θ, jaccfgs[i], Val{false}()) - end - end - else - cons_h = function (θ) - map(1:num_cons) do i - ReverseDiff.hessian(fncs[i], θ) - end - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoReverseDiff, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - chunksize = default_chunk_size(length(cache.u0)) - p = cache.p - - if f.grad === nothing - if adtype.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (θ, args...) - ReverseDiff.gradient!(tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (θ, args...) -> ReverseDiff.gradient(x -> _f(x, args...), - θ, - cfg) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - if adtype.compile - T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - chunksize - }.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - ReverseDiff.gradient!(htape, θ) - end - jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk{chunksize}(), T) - hess = function (θ, args...) 
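- # compiled path (forward-over-reverse): ForwardDiff Jacobian of the taped ReverseDiff gradient g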
- ForwardDiff.jacobian(g, θ, jaccfg, Val{false}()) - end - else - hess = function (θ, args...) - ReverseDiff.hessian(x -> _f(x, args...), θ) - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - return hess(θ, args...) * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - if adtype.compile - _jac_tape = ReverseDiff.JacobianTape(cons_oop, x) - jac_tape = ReverseDiff.compile(_jac_tape) - cons_j = function (θ) - if num_cons > 1 - ReverseDiff.jacobian!(jac_tape, θ) - else - ReverseDiff.jacobian!(jac_tape, θ)[1, :] - end - end - else - cjconfig = ReverseDiff.JacobianConfig(x) - cons_j = function (θ) - if num_cons > 1 - return ReverseDiff.jacobian(cons_oop, θ, cjconfig) - else - return ReverseDiff.jacobian(cons_oop, θ, cjconfig)[1, :] - end - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - if adtype.compile - consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual)) - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - ReverseDiff.gradient!(htape, θ) - end - gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardDiff.JacobianConfig(gs[i], - x, - ForwardDiff.Chunk{chunksize}(), - T) for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - ForwardDiff.jacobian(gs[i], θ, jaccfgs[i], Val{false}()) - end - end - else - cons_h = function (θ) - map(1:num_cons) do i - ReverseDiff.hessian(fncs[i], θ) - end - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -end diff --git a/ext/OptimizationSparseDiffExt.jl b/ext/OptimizationSparseDiffExt.jl deleted file mode 100644 index cc0ffd6..0000000 --- a/ext/OptimizationSparseDiffExt.jl +++ /dev/null @@ -1,31 +0,0 @@ -module OptimizationSparseDiffExt - -import OptimizationBase, OptimizationBase.ArrayInterface -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoSparse, AutoFiniteDiff, AutoForwardDiff, - AutoReverseDiff -using OptimizationBase.LinearAlgebra, ReverseDiff -isdefined(Base, :get_extension) ? 
-(using SparseDiffTools, - SparseDiffTools.ForwardDiff, SparseDiffTools.FiniteDiff, Symbolics) : -(using ..SparseDiffTools, - ..SparseDiffTools.ForwardDiff, ..SparseDiffTools.FiniteDiff, ..Symbolics) - -function default_chunk_size(len) - if len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD - len - else - ForwardDiff.DEFAULT_CHUNK_THRESHOLD - end -end - -include("OptimizationSparseForwardDiff.jl") - -const FD = FiniteDiff - -include("OptimizationSparseFiniteDiff.jl") - -struct OptimizationSparseReverseTag end - -include("OptimizationSparseReverseDiff.jl") -end diff --git a/ext/OptimizationSparseFiniteDiff.jl b/ext/OptimizationSparseFiniteDiff.jl deleted file mode 100644 index 686d614..0000000 --- a/ext/OptimizationSparseFiniteDiff.jl +++ /dev/null @@ -1,541 +0,0 @@ - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoSparse{<:AutoFiniteDiff, S, C}, p, - num_cons = 0) where {S, C} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcache = FD.GradientCache(x, x) - grad = (res, θ, args...) -> FD.finite_difference_gradient!( - res, x -> _f(x, args...), - θ, gradcache) - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = isnothing(f.hess_prototype) ? Symbolics.hessian_sparsity(_f, x) : - f.hess_prototype - hess_colors = matrix_colors(hess_sparsity) - hess = (res, θ, args...) -> numauto_color_hessian!(res, x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (res, θ) -> grad(res, - θ, - args...))) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - num_hesvec!(H, x -> _f(x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = f.cons_jac_prototype === nothing ? - Symbolics.jacobian_sparsity(cons, - zeros(eltype(x), num_cons), - x) : - f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? - matrix_colors(cons_jac_prototype) : - f.cons_jac_colorvec - cons_j = function (J, θ) - y0 = zeros(num_cons) - jaccache = FD.JacobianCache(copy(x), copy(y0), copy(y0); - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - FD.finite_difference_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x, i) - conshess_sparsity = isnothing(f.cons_hess_prototype) ? 
- copy(Symbolics.hessian_sparsity(_f, x)) : - f.cons_hess_prototype[i] - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, conshess_sparsity) - return hesscache - end - - fcons = [(x) -> (_res = zeros(eltype(x), num_cons); - cons(_res, x); - _res[i]) for i in 1:num_cons] - conshess_caches = [gen_conshess_cache(fcons[i], x, i) for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - numauto_color_hessian!(res[i], fcons[i], θ, conshess_caches[i]) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - # lag_hess_cache = FD.HessianCache(copy(x)) - # c = zeros(num_cons) - # h = zeros(length(x), length(x)) - # lag_h = let c = c, h = h - # lag = function (θ, σ, μ) - # f.cons(c, θ, p) - # l = μ'c - # if !iszero(σ) - # l += σ * f.f(θ, p) - # end - # l - # end - # function (res, θ, σ, μ) - # FD.finite_difference_hessian!(res, - # (x) -> lag(x, σ, μ), - # θ, - # updatecache(lag_hess_cache, θ)) - # end - # end - lag_h = nothing - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = getfield.(conshess_caches, :sparsity), - cons_hess_colorvec = getfield.(conshess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoFiniteDiff, S, C}, num_cons = 0) where {S, C} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - x = cache.u0 - p = cache.p - - if f.grad === nothing - gradcache = FD.GradientCache(cache.u0, cache.u0) - grad = (res, θ, args...) -> FD.finite_difference_gradient!( - res, x -> _f(x, args...), - θ, gradcache) - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = isnothing(f.hess_prototype) ? Symbolics.hessian_sparsity(_f, x) : - f.hess_prototype - hess_colors = matrix_colors(hess_sparsity) - hess = (res, θ, args...) -> numauto_color_hessian!(res, x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (res, θ) -> grad(res, - θ, - args...))) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - num_hesvec!(H, x -> _f(x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = f.cons_jac_prototype === nothing ? - Symbolics.jacobian_sparsity(cons, - zeros(eltype(x), num_cons), - x) : - f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? 
- matrix_colors(cons_jac_prototype) : - f.cons_jac_colorvec - cons_j = function (J, θ) - y0 = zeros(num_cons) - jaccache = FD.JacobianCache(copy(x), copy(y0), copy(y0); - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - FD.finite_difference_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x, i) - conshess_sparsity = isnothing(f.cons_hess_prototype) ? - copy(Symbolics.hessian_sparsity(_f, x)) : - f.cons_hess_prototype[i] - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, conshess_sparsity) - return hesscache - end - - fcons = [(x) -> (_res = zeros(eltype(x), num_cons); - cons(_res, x); - _res[i]) for i in 1:num_cons] - conshess_caches = [gen_conshess_cache(fcons[i], x, i) for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - numauto_color_hessian!(res[i], fcons[i], θ, conshess_caches[i]) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - # lag_hess_cache = FD.HessianCache(copy(cache.u0)) - # c = zeros(num_cons) - # h = zeros(length(cache.u0), length(cache.u0)) - # lag_h = let c = c, h = h - # lag = function (θ, σ, μ) - # f.cons(c, θ, cache.p) - # l = μ'c - # if !iszero(σ) - # l += σ * f.f(θ, cache.p) - # end - # l - # end - # function (res, θ, σ, μ) - # FD.finite_difference_hessian!(h, - # (x) -> lag(x, σ, μ), - # θ, - # updatecache(lag_hess_cache, θ)) - # k = 1 - # for i in 1:length(cache.u0), j in i:length(cache.u0) - # res[k] = h[i, j] - # k += 1 - # end - # end - # end - lag_h = nothing - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = getfield.(conshess_caches, :sparsity), - cons_hess_colorvec = getfield.(conshess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoSparse{<:AutoFiniteDiff, S, C}, p, - num_cons = 0) where {S, C} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcache = FD.GradientCache(x, x) - grad = (θ, args...) -> FD.finite_difference_gradient(x -> _f(x, args...), - θ, gradcache) - else - grad = (θ, args...) -> f.grad(θ, cache.p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = matrix_colors(tril(hess_sparsity)) - hess = (θ, args...) -> numauto_color_hessian(x -> _f(x, args...), θ, - ForwardColorHesCache(_f, θ, - hess_colors, - hess_sparsity, - (res, θ) -> (res .= grad(θ, - args...)))) - else - hess = (θ, args...) -> f.hess(θ, cache.p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) 
- return num_hesvec(x -> _f(x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = f.cons_jac_prototype === nothing ? - Symbolics.jacobian_sparsity((res, x) -> (res .= cons(x)), - zeros(eltype(x), num_cons), - x) : - f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? - matrix_colors(cons_jac_prototype) : - f.cons_jac_colorvec - cons_j = function (θ) - y0 = zeros(eltype(θ), num_cons) - jaccache = FD.JacobianCache(copy(θ), copy(y0), copy(y0); - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - if num_cons > 1 - return FD.finite_difference_jacobian(cons, θ, jaccache) - else - return FD.finite_difference_jacobian(cons, θ, jaccache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x) - conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x)) - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> cons(x)[i] for i in 1:num_cons] - conshess_caches = [gen_conshess_cache(fcons[i], x) for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - numauto_color_hessian(fcons[i], θ, conshess_caches[i]) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - if f.lag_h === nothing - # lag_hess_cache = FD.HessianCache(copy(cache.u0)) - # c = zeros(num_cons) - # h = zeros(length(cache.u0), length(cache.u0)) - # lag_h = let c = c, h = h - # lag = function (θ, σ, μ) - # f.cons(c, θ, cache.p) - # l = μ'c - # if !iszero(σ) - # l += σ * f.f(θ, cache.p) - # end - # l - # end - # function (res, θ, σ, μ) - # FD.finite_difference_hessian!(h, - # (x) -> lag(x, σ, μ), - # θ, - # updatecache(lag_hess_cache, θ)) - # k = 1 - # for i in 1:length(cache.u0), j in i:length(cache.u0) - # res[k] = h[i, j] - # k += 1 - # end - # end - # end - lag_h = nothing - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = getfield.(conshess_caches, :sparsity), - cons_hess_colorvec = getfield.(conshess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoFiniteDiff, S, C}, num_cons = 0) where {S, C} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcache = FD.GradientCache(x, x) - grad = (θ, args...) -> FD.finite_difference_gradient(x -> _f(x, args...), - θ, gradcache) - else - grad = (θ, args...) -> f.grad(θ, p, args...) 
- end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = matrix_colors(tril(hess_sparsity)) - hess = (θ, args...) -> numauto_color_hessian(x -> _f(x, args...), θ, - ForwardColorHesCache(_f, θ, - hess_colors, - hess_sparsity, - (res, θ) -> (res .= grad(θ, - args...)))) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - return num_hesvec(x -> _f(x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = f.cons_jac_prototype === nothing ? - Symbolics.jacobian_sparsity((res, x) -> (res .= cons(x)), - zeros(eltype(x), num_cons), - x) : - f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec === nothing ? - matrix_colors(cons_jac_prototype) : - f.cons_jac_colorvec - cons_j = function (θ) - y0 = zeros(eltype(θ), num_cons) - jaccache = FD.JacobianCache(copy(θ), copy(y0), copy(y0); - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - return FD.finite_difference_jacobian(cons, θ, jaccache) - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x) - conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x)) - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> cons(x)[i] for i in 1:num_cons] - conshess_caches = [gen_conshess_cache(fcons[i], x) for i in 1:num_cons] - cons_h = function (θ) - map(1:num_cons) do i - numauto_color_hessian(fcons[i], θ, conshess_caches[i]) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - if f.lag_h === nothing - # lag_hess_cache = FD.HessianCache(copy(x)) - # c = zeros(num_cons) - # h = zeros(length(x), length(x)) - # lag_h = let c = c, h = h - # lag = function (θ, σ, μ) - # f.cons(c, θ, p) - # l = μ'c - # if !iszero(σ) - # l += σ * f.f(θ, p) - # end - # l - # end - # function (res, θ, σ, μ) - # FD.finite_difference_hessian!(h, - # (x) -> lag(x, σ, μ), - # θ, - # updatecache(lag_hess_cache, θ)) - # k = 1 - # for i in 1:length(x), j in i:length(x) - # res[k] = h[i, j] - # k += 1 - # end - # end - # end - lag_h = nothing - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = getfield.(conshess_caches, :sparsity), - cons_hess_colorvec = getfield.(conshess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end diff --git a/ext/OptimizationSparseForwardDiff.jl b/ext/OptimizationSparseForwardDiff.jl deleted file mode 100644 index bb8de67..0000000 --- a/ext/OptimizationSparseForwardDiff.jl +++ /dev/null @@ -1,459 +0,0 @@ -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoSparse{<:AutoForwardDiff{_chunksize}}, p, - num_cons = 
0) where {_chunksize} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = isnothing(f.hess_prototype) ? Symbolics.hessian_sparsity(_f, x) : - f.hess_prototype - hess_colors = matrix_colors(hess_sparsity) - hess = (res, θ, args...) -> numauto_color_hessian!(res, x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (res, θ) -> grad(res, - θ, - args...))) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - num_hesvecgrad!(H, (res, x) -> grad(res, x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = isnothing(f.cons_jac_prototype) ? - Symbolics.jacobian_sparsity(cons, zeros(eltype(x), num_cons), - x) : f.cons_jac_prototype - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, - x, - chunksize; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype, - dx = zeros(eltype(x), num_cons)) - cons_j = function (J, θ) - forwarddiff_color_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x, i) - conshess_sparsity = isnothing(f.cons_hess_prototype) ? 
- copy(Symbolics.hessian_sparsity(_f, x)) : - f.cons_hess_prototype[i] - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> (_res = zeros(eltype(x), num_cons); - cons(_res, x); - _res[i]) for i in 1:num_cons] - cons_hess_caches = [gen_conshess_cache(fcons[i], x, i) for i in 1:num_cons] - cons_h = function (res, θ) - fetch.([Threads.@spawn numauto_color_hessian!( - res[i], fcons[i], θ, cons_hess_caches[i]) for i in 1:num_cons]) - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_colorvec = cons_jac_colorvec, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = getfield.(cons_hess_caches, :sparsity), - cons_hess_colorvec = getfield.(cons_hess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoForwardDiff{_chunksize}}, - num_cons = 0) where {_chunksize} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - chunksize = _chunksize === nothing ? default_chunk_size(length(cache.u0)) : _chunksize - - x = cache.u0 - p = cache.p - - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, cache.u0, ForwardDiff.Chunk{chunksize}()) - grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = isnothing(f.hess_prototype) ? Symbolics.hessian_sparsity(_f, x) : - f.hess_prototype - hess_colors = matrix_colors(hess_sparsity) - hess = (res, θ, args...) -> numauto_color_hessian!(res, x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (res, θ) -> grad(res, - θ, - args...))) - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - num_hesvecgrad!(H, (res, x) -> grad(res, x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - cons_jac_prototype = isnothing(f.cons_jac_prototype) ? 
- Symbolics.jacobian_sparsity(cons, zeros(eltype(x), num_cons), - x) : f.cons_jac_prototype - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, - x, - chunksize; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype, - dx = zeros(eltype(x), num_cons)) - cons_j = function (J, θ) - forwarddiff_color_jacobian!(J, cons, θ, jaccache) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x, i) - conshess_sparsity = isnothing(f.cons_hess_prototype) ? - copy(Symbolics.hessian_sparsity(_f, x)) : - f.cons_hess_prototype[i] - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> (_res = zeros(eltype(x), num_cons); - cons(_res, x); - _res[i]) for i in 1:num_cons] - cons_hess_caches = [gen_conshess_cache(fcons[i], x, i) for i in 1:num_cons] - cons_h = function (res, θ) - fetch.([Threads.@spawn numauto_color_hessian!( - res[i], fcons[i], θ, cons_hess_caches[i]) for i in 1:num_cons]) - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = getfield.(cons_hess_caches, :sparsity), - cons_hess_colorvec = getfield.(cons_hess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoSparse{<:AutoForwardDiff{_chunksize}}, p, - num_cons = 0) where {_chunksize} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - chunksize = _chunksize === nothing ? default_chunk_size(length(x)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (θ, args...) -> ForwardDiff.gradient(x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = matrix_colors(tril(hess_sparsity)) - hess = (θ, args...) -> numauto_color_hessian(x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (G, θ) -> (G .= grad(θ, - args...)))) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) 
- num_hesvecgrad((x) -> grad(x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - res = zeros(eltype(x), num_cons) - cons_jac_prototype = Symbolics.jacobian_sparsity((res, x) -> (res .= cons(x)), - res, - x) - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, - x, - chunksize; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - cons_j = function (θ) - if num_cons > 1 - return forwarddiff_color_jacobian(cons, θ, jaccache) - else - return forwarddiff_color_jacobian(cons, θ, jaccache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x) - conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x)) - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> cons(x)[i] for i in 1:num_cons] - cons_hess_caches = gen_conshess_cache.(fcons, Ref(x)) - cons_h = function (θ) - map(1:num_cons) do i - numauto_color_hessian(fcons[i], θ, cons_hess_caches[i]) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_colorvec = cons_jac_colorvec, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = getfield.(cons_hess_caches, :sparsity), - cons_hess_colorvec = getfield.(cons_hess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoForwardDiff{_chunksize}}, - num_cons = 0) where {_chunksize} - if maximum(getfield.(methods(f.f), :nargs)) > 3 - error("$(string(adtype)) with SparseDiffTools does not support functions with more than 2 arguments") - end - chunksize = _chunksize === nothing ? default_chunk_size(length(cache.u0)) : _chunksize - - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - p = cache.p - - if f.grad === nothing - gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}()) - grad = (θ, args...) -> ForwardDiff.gradient(x -> _f(x, args...), θ, - gradcfg, Val{false}()) - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = matrix_colors(tril(hess_sparsity)) - hess = (θ, args...) -> numauto_color_hessian(x -> _f(x, args...), θ, - ForwardColorHesCache(_f, x, - hess_colors, - hess_sparsity, - (G, θ) -> (G .= grad(θ, - args...)))) - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) 
- num_hesvecgrad((x) -> grad(res, x, args...), θ, v) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - res = zeros(eltype(x), num_cons) - cons_jac_prototype = Symbolics.jacobian_sparsity((res, x) -> (res .= cons(x)), - res, - x) - cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = ForwardColorJacCache(cons, - x, - chunksize; - colorvec = cons_jac_colorvec, - sparsity = cons_jac_prototype) - cons_j = function (θ) - if num_cons > 1 - return forwarddiff_color_jacobian(cons, θ, jaccache) - else - return forwarddiff_color_jacobian(cons, θ, jaccache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - cons_hess_caches = [(; sparsity = f.cons_hess_prototype, colors = f.cons_hess_colorvec)] - if cons !== nothing && f.cons_h === nothing - function gen_conshess_cache(_f, x) - conshess_sparsity = copy(Symbolics.hessian_sparsity(_f, x)) - conshess_colors = matrix_colors(conshess_sparsity) - hesscache = ForwardColorHesCache(_f, x, conshess_colors, - conshess_sparsity) - return hesscache - end - - fcons = [(x) -> cons(x)[i] for i in 1:num_cons] - cons_hess_caches = gen_conshess_cache.(fcons, Ref(x)) - cons_h = function (θ) - map(1:num_cons) do i - numauto_color_hessian(fcons[i], θ, cons_hess_caches[i]) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_colorvec = cons_jac_colorvec, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = getfield.(cons_hess_caches, :sparsity), - cons_hess_colorvec = getfield.(cons_hess_caches, :colors), - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end diff --git a/ext/OptimizationSparseReverseDiff.jl b/ext/OptimizationSparseReverseDiff.jl deleted file mode 100644 index ac15215..0000000 --- a/ext/OptimizationSparseReverseDiff.jl +++ /dev/null @@ -1,751 +0,0 @@ -function OptimizationBase.ADTypes.AutoSparseReverseDiff(compile::Bool) - return AutoSparse(AutoReverseDiff(; compile)) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoSparse{<:AutoReverseDiff}, - p = SciMLBase.NullParameters(), - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - chunksize = default_chunk_size(length(x)) - - if f.grad === nothing - if adtype.dense_ad.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (res, θ, args...) - ReverseDiff.gradient!(res, tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, - x -> _f(x, args...), - θ, - cfg) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) 
- end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity)) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(hess_colors)) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(hess_colors)))...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(res1, θ) - ReverseDiff.gradient!(res1, htape, θ) - end - jaccfg = ForwardColorJacCache(g, - x; - tag = typeof(T), - colorvec = hess_colors, - sparsity = hess_sparsity) - hess = function (res, θ, args...) - SparseDiffTools.forwarddiff_color_jacobian!(res, g, θ, jaccfg) - end - else - hess = function (res, θ, args...) - res .= SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = hess_colors, - sparsity = hess_sparsity) do θ - ReverseDiff.gradient(x -> _f(x, args...), θ) - end - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(res, θ, args...) - H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), - SparseDiffTools.SymbolicsSparsityDetection(), - cons_oop, - x, - fx = zeros(eltype(x), num_cons)) - # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons - # ForwardColorJacCache(cons, θ; - # colorvec = cons_jac_colorvec, - # sparsity = cons_jac_prototype, - # dx = zeros(eltype(θ), num_cons)) - # end - cons_jac_prototype = jaccache.jac_prototype - cons_jac_colorvec = jaccache.coloring - cons_j = function (J, θ, args...; cons = cons, cache = jaccache.cache) - forwarddiff_color_jacobian!(J, cons, θ, cache) - return - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x)) - conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xduals = [ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(conshess_colors[i])) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(conshess_colors[i])))...,)))) - for i in 1:num_cons] - consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons] - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(res1, θ, htape) - ReverseDiff.gradient!(res1, htape, θ) - end - gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardColorJacCache(gs[i], - x; - tag = typeof(T), - colorvec = conshess_colors[i], - sparsity = 
conshess_sparsity[i]) for i in 1:num_cons] - cons_h = function (res, θ, args...) - for i in 1:num_cons - SparseDiffTools.forwarddiff_color_jacobian!(res[i], - gs[i], - θ, - jaccfgs[i]) - end - end - else - cons_h = function (res, θ) - for i in 1:num_cons - res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) do θ - ReverseDiff.gradient(fncs[i], θ) - end - end - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoReverseDiff}, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - chunksize = default_chunk_size(length(cache.u0)) - - if f.grad === nothing - if adtype.dense_ad.compile - _tape = ReverseDiff.GradientTape(_f, cache.u0) - tape = ReverseDiff.compile(_tape) - grad = function (res, θ, args...) - ReverseDiff.gradient!(res, tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(cache.u0) - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, - x -> _f(x, args...), - θ, - cfg) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0) - hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity)) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(cache.u0)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(cache.u0), - min(chunksize, maximum(hess_colors)) - }.(cache.u0, - Ref(ForwardDiff.Partials((ones(eltype(cache.u0), - min(chunksize, maximum(hess_colors)))...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(res1, θ) - ReverseDiff.gradient!(res1, htape, θ) - end - jaccfg = ForwardColorJacCache(g, - cache.u0; - tag = typeof(T), - colorvec = hess_colors, - sparsity = hess_sparsity) - hess = function (res, θ, args...) - SparseDiffTools.forwarddiff_color_jacobian!(res, g, θ, jaccfg) - end - else - hess = function (res, θ, args...) - res .= SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = hess_colors, - sparsity = hess_sparsity) do θ - ReverseDiff.gradient(x -> _f(x, args...), θ) - end - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(res, θ, args...) 
- H .= res * v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = function (res, θ) - f.cons(res, θ, cache.p) - return - end - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - # cons_jac_prototype = Symbolics.jacobian_sparsity(cons, - # zeros(eltype(cache.u0), num_cons), - # cache.u0) - # cons_jac_colorvec = matrix_colors(cons_jac_prototype) - jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), - SparseDiffTools.SymbolicsSparsityDetection(), - cons_oop, - cache.u0, - fx = zeros(eltype(cache.u0), num_cons)) - # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons - # ForwardColorJacCache(cons, θ; - # colorvec = cons_jac_colorvec, - # sparsity = cons_jac_prototype, - # dx = zeros(eltype(θ), num_cons)) - # end - cons_jac_prototype = jaccache.jac_prototype - cons_jac_colorvec = jaccache.coloring - cons_j = function (J, θ) - forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache) - return - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = map(1:num_cons) do i - function (x) - res = zeros(eltype(x), num_cons) - f.cons(res, x, cache.p) - return res[i] - end - end - conshess_sparsity = map(1:num_cons) do i - let fnc = fncs[i], θ = cache.u0 - Symbolics.hessian_sparsity(fnc, θ) - end - end - conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(cache.u0)) - xduals = [ForwardDiff.Dual{ - typeof(T), - eltype(cache.u0), - min(chunksize, maximum(conshess_colors[i])) - }.(cache.u0, - Ref(ForwardDiff.Partials((ones(eltype(cache.u0), - min(chunksize, maximum(conshess_colors[i])))...,)))) - for i in 1:num_cons] - consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons] - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(res1, θ, htape) - ReverseDiff.gradient!(res1, htape, θ) - end - gs = let conshtapes = conshtapes - map(1:num_cons) do i - function (res1, x) - grad_cons(res1, x, conshtapes[i]) - end - end - end - jaccfgs = [ForwardColorJacCache(gs[i], - cache.u0; - tag = typeof(T), - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - SparseDiffTools.forwarddiff_color_jacobian!(res[i], - gs[i], - θ, - jaccfgs[i]) - end - end - else - cons_h = function (res, θ) - for i in 1:num_cons - res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) do θ - ReverseDiff.gradient(fncs[i], θ) - end - end - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = 
conshess_colors, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoSparse{<:AutoReverseDiff}, - p = SciMLBase.NullParameters(), - num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - chunksize = default_chunk_size(length(x)) - - if f.grad === nothing - if adtype.dense_ad.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (θ, args...) - ReverseDiff.gradient!(tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (θ, args...) -> ReverseDiff.gradient(x -> _f(x, args...), - θ, - cfg) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity)) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(hess_colors)) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(hess_colors)))...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - ReverseDiff.gradient!(htape, θ) - end - jaccfg = ForwardColorJacCache(g, - x; - tag = typeof(T), - colorvec = hess_colors, - sparsity = hess_sparsity) - hess = function (θ, args...) - return SparseDiffTools.forwarddiff_color_jacobian(g, θ, jaccfg) - end - else - hess = function (θ, args...) - return SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = hess_colors, - sparsity = hess_sparsity) do θ - ReverseDiff.gradient(x -> _f(x, args...), θ) - end - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(θ, args...) 
* v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), - SparseDiffTools.SymbolicsSparsityDetection(), - cons_oop, - x, - fx = zeros(eltype(x), num_cons)) - # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons - # ForwardColorJacCache(cons, θ; - # colorvec = cons_jac_colorvec, - # sparsity = cons_jac_prototype, - # dx = zeros(eltype(θ), num_cons)) - # end - cons_jac_prototype = jaccache.jac_prototype - cons_jac_colorvec = jaccache.coloring - cons_j = function (θ, args...; cons = cons, cache = jaccache.cache) - if num_cons > 1 - return forwarddiff_color_jacobian(cons, θ, cache) - else - return forwarddiff_color_jacobian(cons, θ, cache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x)) - conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xduals = [ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(conshess_colors[i])) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(conshess_colors[i])))...,)))) - for i in 1:num_cons] - consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons] - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - ReverseDiff.gradient!(htape, θ) - end - gs = [(x) -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardColorJacCache(gs[i], - x; - tag = typeof(T), - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) for i in 1:num_cons] - cons_h = function (θ, args...) - map(1:num_cons) do i - SparseDiffTools.forwarddiff_color_jacobian(gs[i], - θ, - jaccfgs[i]) - end - end - else - cons_h = function (θ) - map(1:num_cons) do i - SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) do θ - ReverseDiff.gradient(fncs[i], θ) - end - end - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoReverseDiff}, num_cons = 0) - _f = (θ, args...) 
-> first(f.f(θ, cache.p, args...)) - - chunksize = default_chunk_size(length(cache.u0)) - p = cache.p - x = cache.u0 - - if f.grad === nothing - if adtype.dense_ad.compile - _tape = ReverseDiff.GradientTape(_f, x) - tape = ReverseDiff.compile(_tape) - grad = function (θ, args...) - ReverseDiff.gradient!(tape, θ) - end - else - cfg = ReverseDiff.GradientConfig(x) - grad = (θ, args...) -> ReverseDiff.gradient(x -> _f(x, args...), - θ, - cfg) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - hess_sparsity = Symbolics.hessian_sparsity(_f, x) - hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity)) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xdual = ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(hess_colors)) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(hess_colors)))...,)))) - h_tape = ReverseDiff.GradientTape(_f, xdual) - htape = ReverseDiff.compile(h_tape) - function g(θ) - ReverseDiff.gradient!(htape, θ) - end - jaccfg = ForwardColorJacCache(g, - x; - tag = typeof(T), - colorvec = hess_colors, - sparsity = hess_sparsity) - hess = function (θ, args...) - return SparseDiffTools.forwarddiff_color_jacobian(g, θ, jaccfg) - end - else - hess = function (θ, args...) - return SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = hess_colors, - sparsity = hess_sparsity) do θ - ReverseDiff.gradient(x -> _f(x, args...), θ) - end - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - # _θ = ForwardDiff.Dual.(θ, v) - # res = similar(_θ) - # grad(res, _θ, args...) - # H .= getindex.(ForwardDiff.partials.(res), 1) - res = zeros(length(θ), length(θ)) - hess(θ, args...) 
* v - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), - SparseDiffTools.SymbolicsSparsityDetection(), - cons_oop, - x, - fx = zeros(eltype(x), num_cons)) - # let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons - # ForwardColorJacCache(cons, θ; - # colorvec = cons_jac_colorvec, - # sparsity = cons_jac_prototype, - # dx = zeros(eltype(θ), num_cons)) - # end - cons_jac_prototype = jaccache.jac_prototype - cons_jac_colorvec = jaccache.coloring - cons_j = function (θ, args...; cons = cons, cache = jaccache.cache) - if num_cons > 1 - return forwarddiff_color_jacobian(cons, θ, cache) - else - return forwarddiff_color_jacobian(cons, θ, cache)[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x)) - conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity) - if adtype.dense_ad.compile - T = ForwardDiff.Tag(OptimizationSparseReverseTag(), eltype(x)) - xduals = [ForwardDiff.Dual{ - typeof(T), - eltype(x), - min(chunksize, maximum(conshess_colors[i])) - }.(x, - Ref(ForwardDiff.Partials((ones(eltype(x), - min(chunksize, maximum(conshess_colors[i])))...,)))) - for i in 1:num_cons] - consh_tapes = [ReverseDiff.GradientTape(fncs[i], xduals[i]) for i in 1:num_cons] - conshtapes = ReverseDiff.compile.(consh_tapes) - function grad_cons(θ, htape) - ReverseDiff.gradient!(htape, θ) - end - gs = [(x) -> grad_cons(x, conshtapes[i]) for i in 1:num_cons] - jaccfgs = [ForwardColorJacCache(gs[i], - x; - tag = typeof(T), - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) for i in 1:num_cons] - cons_h = function (θ, args...) 
- map(1:num_cons) do i - SparseDiffTools.forwarddiff_color_jacobian(gs[i], - θ, - jaccfgs[i]) - end - end - else - cons_h = function (θ) - map(1:num_cons) do i - SparseDiffTools.forwarddiff_color_jacobian(θ, - colorvec = conshess_colors[i], - sparsity = conshess_sparsity[i]) do θ - ReverseDiff.gradient(fncs[i], θ) - end - end - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end diff --git a/ext/OptimizationTrackerExt.jl b/ext/OptimizationTrackerExt.jl deleted file mode 100644 index f315d54..0000000 --- a/ext/OptimizationTrackerExt.jl +++ /dev/null @@ -1,72 +0,0 @@ -module OptimizationTrackerExt - -import OptimizationBase -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoTracker -isdefined(Base, :get_extension) ? (using Tracker) : (using ..Tracker) - -function OptimizationBase.instantiate_function(f, x, adtype::AutoTracker, p, - num_cons = 0) - num_cons != 0 && error("AutoTracker does not currently support constraints") - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if f.grad === nothing - grad = (res, θ, args...) -> res .= Tracker.data(Tracker.gradient( - x -> _f(x, args...), - θ)[1]) - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg") - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs") - else - hv = f.hv - end - - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, - hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) -end - -function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache, - adtype::AutoTracker, num_cons = 0) - num_cons != 0 && error("AutoTracker does not currently support constraints") - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) - - if f.grad === nothing - grad = (res, θ, args...) -> res .= Tracker.data(Tracker.gradient( - x -> _f(x, args...), - θ)[1]) - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - if f.hess === nothing - hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg") - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) - end - - if f.hv === nothing - hv = (res, θ, args...) 
-> error("Hessian based methods not supported with Tracker backend, pass in the `hess` and `hv` kwargs") - else - hv = f.hv - end - - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, - hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) -end - -end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 867472a..fe348ca 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -342,4 +342,4 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cons_expr = f.cons_expr) end -end +end \ No newline at end of file From e771c6c27808eb9f02aa8cb0520058a57de503bf Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 31 May 2024 16:30:15 -0400 Subject: [PATCH 02/33] DI to latest --- Project.toml | 1 + test/Project.toml | 2 ++ test/adtests.jl | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0dbcfb2..5aec7a9 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"] [compat] ADTypes = "1.3" ArrayInterface = "7.6" +DifferentiationInterface = "0.5.2" DocStringExtensions = "0.9" LinearAlgebra = "1.9, 1.10" ModelingToolkit = "9" diff --git a/test/Project.toml b/test/Project.toml index 8cc5a79..6ad1f79 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -30,6 +31,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" ComponentArrays = ">= 0.13.9" +DifferentiationInterface = "0.5.2" DiffEqFlux = ">= 2" Flux = "0.13, 0.14" IterTools = ">= 1.3.0" diff --git a/test/adtests.jl b/test/adtests.jl index 2558eb6..4dc2e00 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -1,4 +1,4 @@ -using OptimizationBase, Test, SparseArrays, Symbolics +using OptimizationBase, Test, DifferentiationInterface,SparseArrays, Symbolics using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker using ModelingToolkit, Enzyme, Random From 3b41754d04f85fc266c2c59539afc995c357a7fe Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 6 Jun 2024 19:55:29 -0400 Subject: [PATCH 03/33] Flesh out latest ideas --- Project.toml | 6 +- ext/OptimizationDIExt.jl | 8 +- ext/OptimizationDISparseExt.jl | 380 +++++++++++++++++++++++++++++++++ ext/OptimizationMTKExt.jl | 2 +- ext/OptimizationZygoteExt.jl | 3 +- 5 files changed, 394 insertions(+), 5 deletions(-) create mode 100644 ext/OptimizationDISparseExt.jl diff --git a/Project.toml b/Project.toml index 5aec7a9..fc524c2 100644 --- a/Project.toml +++ b/Project.toml @@ -18,14 +18,18 @@ SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Enzyme = 
"7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -OptimizationDIExt = "DifferentiationInterface" +OptimizationDIExt = ["DifferentiationInterface", "ForwardDiff", "ReverseDiff"] OptimizationMTKExt = "ModelingToolkit" OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"] diff --git a/ext/OptimizationDIExt.jl b/ext/OptimizationDIExt.jl index 186327d..a8c1c42 100644 --- a/ext/OptimizationDIExt.jl +++ b/ext/OptimizationDIExt.jl @@ -7,10 +7,16 @@ import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes +import ForwardDiff, ReverseDiff function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if adtype isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) + elseif adtype isa ADTypes.ReverseMode + soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + end if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) diff --git a/ext/OptimizationDISparseExt.jl b/ext/OptimizationDISparseExt.jl new file mode 100644 index 0000000..6cff16b --- /dev/null +++ b/ext/OptimizationDISparseExt.jl @@ -0,0 +1,380 @@ +module OptimizationDIExt + +import OptimizationBase, OptimizationBase.ArrayInterface +import OptimizationBase.SciMLBase: OptimizationFunction +import OptimizationBase.LinearAlgebra: I +import DifferentiationInterface +import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, + gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian +using ADTypes +using SparseConnectivityTracer, SparseMatrixColorings + +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse, p = SciMLBase.NullParameters(), num_cons = 0) + _f = (θ, args...) -> first(f.f(θ, p, args...)) + + if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa AbstractADTypes.NoColoringAlgorithm) + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + end + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(res, θ) + gradient!(_f, res, adtype, θ, extras_grad) + end + else + grad = (G, θ, args...) -> f.grad(G, θ, p, args...) 
+ end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(res, θ, args...) + hessian!(_f, res, adtype, θ, extras_hess) + end + else + hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (H, θ, v, args...) + if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp!(_f, H, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) + x = cache.u0 + p = cache.p + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(res, θ) + gradient!(_f, res, adtype, θ, extras_grad) + end + else + grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) + function hess(res, θ, args...) + hessian!(_f, res, soadtype, θ, extras_hess) + end + else + hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (H, θ, v, args...) 
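+            # HVP preparation is deferred to the first call since `prepare_hvp` needs a concrete seed vector v;
+            # note that `global` below rebinds a module-level variable rather than the `extras_hvp` local above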
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp!(_f, H, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + + +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(θ) + gradient(_f, adtype, θ, extras_grad) + end + else + grad = (θ, args...) -> f.grad(θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(θ, args...) + hessian(_f, adtype, θ, extras_hess) + end + else + hess = (θ, args...) -> f.hess(θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (θ, v, args...) 
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp(_f, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (θ) -> f.cons(θ, p) + cons_oop = cons + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (θ) + J = jacobian(cons_oop, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + return J + end + else + cons_j = (θ) -> f.cons_j(θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(θ) + H = map(1:num_cons) do i + hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + end + return H + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) + x = cache.u0 + p = cache.p + _f = (θ, args...) -> first(f.f(θ, p, args...)) + soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(θ) + gradient(_f, adtype, θ, extras_grad) + end + else + grad = (θ, args...) -> f.grad(θ, p, args...) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(θ, args...) + hessian(_f, soadtype, θ, extras_hess) + end + else + hess = (θ, args...) -> f.hess(θ, p, args...) + end + + if f.hv === nothing + extras_hvp = nothing + hv = function (θ, v, args...) 
+ if extras_hvp === nothing + global extras_hvp = prepare_hvp(_f, soadtype, x, v) + end + hvp(_f, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + cons = (θ) -> f.cons(θ, p) + cons_oop = cons + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + cons_j = function (θ) + J = jacobian(cons_oop, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + return J + end + else + cons_j = (θ) -> f.cons_j(θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(θ) + H = map(1:num_cons) do i + hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + end + return H + end + else + cons_h = (θ) -> f.cons_h(θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, f.lag_hess_prototype) +end + +end \ No newline at end of file diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 1b84f72..716fd87 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -4,7 +4,7 @@ import OptimizationBase, OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse -isdefined(Base, :get_extension) ? (using ModelingToolkit) : (using ..ModelingToolkit) +using ModelingToolkit function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p, diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index fe348ca..b72bc71 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -3,8 +3,7 @@ module OptimizationZygoteExt import OptimizationBase import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.ADTypes: AutoZygote -isdefined(Base, :get_extension) ? 
(using Zygote, Zygote.ForwardDiff) : -(using ..Zygote, ..Zygote.ForwardDiff) +using Zygote, Zygote.ForwardDiff function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoZygote, p, From 3951c88aa264f8277acef9a3b6681b85e574c938 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 6 Jun 2024 22:41:09 -0400 Subject: [PATCH 04/33] EnzymeExt and fix mode checking --- Project.toml | 1 + ext/OptimizationDIExt.jl | 6 +++--- test/adtests.jl | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index fc524c2..7f3a805 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] OptimizationDIExt = ["DifferentiationInterface", "ForwardDiff", "ReverseDiff"] +OptimizationEnzymeExt = "Enzyme" OptimizationMTKExt = "ModelingToolkit" OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"] diff --git a/ext/OptimizationDIExt.jl b/ext/OptimizationDIExt.jl index a8c1c42..51d42a3 100644 --- a/ext/OptimizationDIExt.jl +++ b/ext/OptimizationDIExt.jl @@ -12,9 +12,9 @@ import ForwardDiff, ReverseDiff function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) - if adtype isa ADTypes.ForwardMode + if ADTypes.mode(adtype) isa ADTypes.ForwardMode soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) - elseif adtype isa ADTypes.ReverseMode + elseif ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) end @@ -30,7 +30,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + extras_hess = prepare_hessian(_f, soadtype, x) function hess(res, θ, args...) 
hessian!(_f, res, adtype, θ, extras_hess) end diff --git a/test/adtests.jl b/test/adtests.jl index 4dc2e00..cb2f67b 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -1,4 +1,4 @@ -using OptimizationBase, Test, DifferentiationInterface,SparseArrays, Symbolics +using OptimizationBase, Test, DifferentiationInterface, SparseArrays, Symbolics using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker using ModelingToolkit, Enzyme, Random From ad7ca08f8f8de5cc1363b97ce1602368f5ef4920 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 8 Jul 2024 22:31:34 -0400 Subject: [PATCH 05/33] Use the changes from main in enzyme --- ext/OptimizationEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 54e1140..c0f6e1a 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -535,4 +535,4 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cons_hess_prototype = f.cons_hess_prototype) end -end +end \ No newline at end of file From 4fe1f54f313c587d71114621fb68e13307ae4d15 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 15 Jul 2024 15:42:19 -0400 Subject: [PATCH 06/33] Flesh out sparsity and secondorder things more, change the extension structure --- Project.toml | 11 ++-- ext/OptimizationFiniteDiffExt.jl | 5 ++ ext/OptimizationForwardDiffExt.jl | 5 ++ ext/OptimizationMTKExt.jl | 6 +- ext/OptimizationReverseDiffExt.jl | 5 ++ src/OptimizationBase.jl | 2 + {ext => src}/OptimizationDIExt.jl | 49 ++++++++++------- {ext => src}/OptimizationDISparseExt.jl | 73 +++++++++++++++++-------- src/function.jl | 59 +------------------- test/adtests.jl | 10 ++-- 10 files changed, 112 insertions(+), 113 deletions(-) create mode 100644 ext/OptimizationFiniteDiffExt.jl create mode 100644 ext/OptimizationForwardDiffExt.jl create mode 100644 ext/OptimizationReverseDiffExt.jl rename {ext => src}/OptimizationDIExt.jl (86%) rename {ext => src}/OptimizationDISparseExt.jl (76%) diff --git a/Project.toml b/Project.toml index 7f3a805..97bec92 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "1.3.3" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -21,18 +22,20 @@ SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -OptimizationDIExt = ["DifferentiationInterface", "ForwardDiff", "ReverseDiff"] +OptimizationForwardDiffExt = "ForwardDiff" +OptimizationFiniteDiffExt = "FiniteDiff" +OptimizationReverseDiffExt = "ReverseDiff" OptimizationEnzymeExt = "Enzyme" OptimizationMTKExt = "ModelingToolkit" -OptimizationZygoteExt = ["Zygote", "DifferentiationInterface"] +OptimizationZygoteExt = "Zygote" [compat] ADTypes = "1.3" @@ -44,11 +47,9 @@ ModelingToolkit = "9" 
Reexport = "1.2" Requires = "1" SciMLBase = "2" -SparseDiffTools = "2.14" SymbolicAnalysis = "0.1, 0.2" SymbolicIndexingInterface = "0.3" Symbolics = "5.12" -Tracker = "0.2.29" Zygote = "0.6.67" julia = "1.10" diff --git a/ext/OptimizationFiniteDiffExt.jl b/ext/OptimizationFiniteDiffExt.jl new file mode 100644 index 0000000..b0a95a6 --- /dev/null +++ b/ext/OptimizationFiniteDiffExt.jl @@ -0,0 +1,5 @@ +module OptimizationFiniteDiffExt + +using DifferentiationInterface, FiniteDiff + +end \ No newline at end of file diff --git a/ext/OptimizationForwardDiffExt.jl b/ext/OptimizationForwardDiffExt.jl new file mode 100644 index 0000000..0da99fb --- /dev/null +++ b/ext/OptimizationForwardDiffExt.jl @@ -0,0 +1,5 @@ +module OptimizationForwardDiffExt + +using DifferentiationInterface, ForwardDiff + +end \ No newline at end of file diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 716fd87..08ebef9 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -7,8 +7,8 @@ import OptimizationBase.ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse using ModelingToolkit function OptimizationBase.instantiate_function( - f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p, - num_cons = 0) where {S, C} + f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics}, p, + num_cons = 0) p = isnothing(p) ? SciMLBase.NullParameters() : p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p; @@ -53,7 +53,7 @@ function OptimizationBase.instantiate_function( end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoSymbolics, S, C}, num_cons = 0) where {S, C} + adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0) p = isnothing(cache.p) ? 
SciMLBase.NullParameters() : cache.p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0, diff --git a/ext/OptimizationReverseDiffExt.jl b/ext/OptimizationReverseDiffExt.jl new file mode 100644 index 0000000..23ec82d --- /dev/null +++ b/ext/OptimizationReverseDiffExt.jl @@ -0,0 +1,5 @@ +module OptimizationReverseDiffExt + +using DifferentiationInterface, ReverseDiff + +end \ No newline at end of file diff --git a/src/OptimizationBase.jl b/src/OptimizationBase.jl index 33f738d..ee9486a 100644 --- a/src/OptimizationBase.jl +++ b/src/OptimizationBase.jl @@ -32,6 +32,8 @@ Base.length(::NullData) = 0 include("adtypes.jl") include("cache.jl") include("function.jl") +include("OptimizationDIExt.jl") +include("OptimizationDISparseExt.jl") export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA diff --git a/ext/OptimizationDIExt.jl b/src/OptimizationDIExt.jl similarity index 86% rename from ext/OptimizationDIExt.jl rename to src/OptimizationDIExt.jl index 51d42a3..66bd88d 100644 --- a/ext/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -1,20 +1,18 @@ -module OptimizationDIExt - -import OptimizationBase, OptimizationBase.ArrayInterface +using OptimizationBase +import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian -using ADTypes -import ForwardDiff, ReverseDiff +using ADTypes, SciMLBase function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) - if ADTypes.mode(adtype) isa ADTypes.ForwardMode - soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) - elseif ADTypes.mode(adtype) isa ADTypes.ReverseMode + if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) end @@ -32,7 +30,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) function hess(res, θ, args...) - hessian!(_f, res, adtype, θ, extras_hess) + hessian!(_f, res, soadtype, θ, extras_hess) end else hess = (H, θ, args...) -> f.hess(H, θ, p, args...) @@ -79,7 +77,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end else @@ -106,7 +104,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? 
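+        # pair the user's backend with the complementary mode so the Hessian/HVP preparations below
+        # use a mixed (forward-over-reverse) second-order combination instead of stacking one mode twice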
+ elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + end if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -169,7 +172,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end else @@ -195,7 +198,12 @@ end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + end if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -211,7 +219,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(θ, args...) - hessian(_f, adtype, θ, extras_hess) + hessian(_f, soadtype, θ, extras_hess) end else hess = (θ, args...) -> f.hess(θ, p, args...) @@ -259,7 +267,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x function cons_h(θ) H = map(1:num_cons) do i - hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end @@ -287,7 +295,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode + soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? 
+ elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + end if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -351,7 +364,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c function cons_h(θ) H = map(1:num_cons) do i - hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end @@ -374,5 +387,3 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c cons_hess_colorvec = conshess_colors, lag_h, f.lag_hess_prototype) end - -end \ No newline at end of file diff --git a/ext/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl similarity index 76% rename from ext/OptimizationDISparseExt.jl rename to src/OptimizationDISparseExt.jl index 6cff16b..0b22f7c 100644 --- a/ext/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -1,6 +1,5 @@ -module OptimizationDIExt - -import OptimizationBase, OptimizationBase.ArrayInterface +using OptimizationBase +import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface @@ -9,21 +8,48 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, using ADTypes using SparseConnectivityTracer, SparseMatrixColorings -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse, p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - +function generate_sparse_adtype(adtype) if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) - elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa AbstractADTypes.NoColoringAlgorithm) + if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? 
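+            # same mixed second-order pairing as the dense path, but kept inside AutoSparse so the
+            # Hessian backend also carries the sparsity detector and coloring algorithm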
+ elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + end + elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + end elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + end + else + if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) + end end + return adtype,soadtype +end + + +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) + _f = (θ, args...) -> first(f.f(θ, p, args...)) + + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(res, θ) - gradient!(_f, res, adtype, θ, extras_grad) + gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) @@ -34,7 +60,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(res, θ, args...) - hessian!(_f, res, adtype, θ, extras_hess) + hessian!(_f, res, soadtype, θ, extras_hess) end else hess = (H, θ, args...) -> f.hess(H, θ, p, args...) 
@@ -81,7 +107,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end else @@ -104,11 +130,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) +function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -171,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], adtype, θ, extras_cons_hess[i]) + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end else @@ -195,9 +222,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca end -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -213,7 +241,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(θ, args...) - hessian(_f, adtype, θ, extras_hess) + hessian(_f, soadtype, θ, extras_hess) end else hess = (θ, args...) -> f.hess(θ, p, args...) @@ -261,7 +289,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x function cons_h(θ) H = map(1:num_cons) do i - hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end @@ -285,11 +313,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) +function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) 
-> first(f.f(θ, p, args...)) - soadtype = DifferentiationInterface.SecondOrder(adtype, adtype) + + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -353,7 +382,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c function cons_h(θ) H = map(1:num_cons) do i - hessian(fncs[i], adtype, θ, extras_cons_hess[i]) + hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end @@ -376,5 +405,3 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c cons_hess_colorvec = conshess_colors, lag_h, f.lag_hess_prototype) end - -end \ No newline at end of file diff --git a/src/function.jl b/src/function.jl index 8f9dc46..53ab75e 100644 --- a/src/function.jl +++ b/src/function.jl @@ -43,65 +43,8 @@ function that is not defined, an error is thrown. For more information on the use of automatic differentiation, see the documentation of the `AbstractADType` types. """ -function instantiate_function(f, x, ::SciMLBase.NoAD, - p, num_cons = 0) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) - cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) - cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) - cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) - hess_prototype = f.hess_prototype === nothing ? nothing : - convert.(eltype(x), f.hess_prototype) - cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : - convert.(eltype(x), f.cons_jac_prototype) - cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : - [convert.(eltype(x), f.cons_hess_prototype[i]) - for i in 1:num_cons] - expr = symbolify(f.expr) - cons_expr = symbolify.(f.cons_expr) - - return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, - hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_prototype, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = cons_hess_prototype, - expr = expr, cons_expr = cons_expr, - sys = f.sys, - observed = f.observed) -end - -function instantiate_function(f, cache::ReInitCache, ::SciMLBase.NoAD, - num_cons = 0) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) - cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p) - cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p) - cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p) - hess_prototype = f.hess_prototype === nothing ? nothing : - convert.(eltype(cache.u0), f.hess_prototype) - cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : - convert.(eltype(cache.u0), f.cons_jac_prototype) - cons_hess_prototype = f.cons_hess_prototype === nothing ? 
nothing : - [convert.(eltype(cache.u0), f.cons_hess_prototype[i]) - for i in 1:num_cons] - expr = symbolify(f.expr) - cons_expr = symbolify.(f.cons_expr) - - return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, - hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_prototype, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = cons_hess_prototype, - expr = expr, cons_expr = cons_expr, - sys = f.sys, - observed = f.observed) -end -function instantiate_function(f, x, adtype::ADTypes.AbstractADType, +function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType, p, num_cons = 0) adtypestr = string(adtype) _strtind = findfirst('.', adtypestr) diff --git a/test/adtests.jl b/test/adtests.jl index cb2f67b..9caff7b 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -232,7 +232,7 @@ optprob = OptimizationBase.instantiate_function(optf, nothing) optprob.grad(G2, x0) @test G1 == G2 -@test_throws ErrorException optprob.hess(H2, x0) +@test_broken optprob.hess(H2, x0) prob = OptimizationProblem(optf, x0) @@ -270,8 +270,8 @@ optprob.cons_h(H3, x0) H4 = Array{Float64}(undef, 2, 2) μ = randn(1) σ = rand() -optprob.lag_h(H4, x0, σ, μ) -@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6 +# optprob.lag_h(H4, x0, σ, μ) +# @test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6 cons_jac_proto = Float64.(sparse([1 1])) # Things break if you only use [1 1]; see FiniteDiff.jl cons_jac_colors = 1:2 @@ -361,10 +361,10 @@ optprob2.hess(sH, [5.0, 3.0]) optprob2.cons_j(sJ, [5.0, 3.0]) @test all(isapprox(sJ, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) optprob2.cons_h(sH3, [5.0, 3.0]) -@test sH3 ≈ [ +@test Array.(sH3) ≈ [ [2.0 0.0; 0.0 2.0], [2.8767727327346804 0.2836621681849162; 0.2836621681849162 -6.622738308376736e-9] -] +] rtol = 1e-4 using SparseDiffTools From f0b8401980c4d8f53089a44f8f03ebdde929cccd Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 15 Jul 2024 15:53:50 -0400 Subject: [PATCH 07/33] bump adtypes --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 97bec92..c0ffd7d 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ OptimizationMTKExt = "ModelingToolkit" OptimizationZygoteExt = "Zygote" [compat] -ADTypes = "1.3" +ADTypes = "1.5" ArrayInterface = "7.6" DifferentiationInterface = "0.5.2" DocStringExtensions = "0.9" From 48ad2f46bcf73adde2e3f1a7fb1fc71396f72a21 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 15 Jul 2024 19:00:49 -0400 Subject: [PATCH 08/33] Use only the dense_ad in gradient and misc ups --- src/OptimizationDISparseExt.jl | 20 ++++++++++---------- test/adtests.jl | 18 +++++++++--------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 0b22f7c..d11e98d 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -13,28 +13,28 @@ function generate_sparse_adtype(adtype) adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? 
- elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) end elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) end elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) end else if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), 
adtype.dense_ad), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) end end return adtype,soadtype @@ -228,9 +228,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(θ) - gradient(_f, adtype, θ, extras_grad) + gradient(_f, adtype.dense_ad, θ, extras_grad) end else grad = (θ, args...) -> f.grad(θ, p, args...) diff --git a/test/adtests.jl b/test/adtests.jl index 9caff7b..51c1dcb 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -387,7 +387,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -420,7 +420,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -453,7 +453,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -477,7 +477,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -675,7 +675,7 @@ optprob.hess(H2, x0) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0, 0.0] - @test optprob.cons_j([5.0, 3.0])≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 + @test Array(optprob.cons_j([5.0, 3.0]))≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 @test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] cons = (x, p) -> [x[1]^2 + x[2]^2] @@ -686,7 +686,7 @@ optprob.hess(H2, x0) OptimizationBase.AutoSparseFiniteDiff(), nothing, 1) - @test optprob.grad(x0) ≈ G1 + @test optprob.grad(x0) ≈ G1 rtol=1e-4 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0] @@ -706,7 +706,7 @@ optprob.hess(H2, x0) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0, 0.0] - @test optprob.cons_j([5.0, 3.0])≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 + @test Array(optprob.cons_j([5.0, 3.0]))≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 @test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] cons = (x, p) -> [x[1]^2 + x[2]^2] @@ -737,7 +737,7 @@ optprob.hess(H2, x0) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0, 0.0] - 
@test optprob.cons_j([5.0, 3.0])≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 + @test Array(optprob.cons_j([5.0, 3.0]))≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 @test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] cons = (x, p) -> [x[1]^2 + x[2]^2] @@ -767,7 +767,7 @@ optprob.hess(H2, x0) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0, 0.0] - @test optprob.cons_j([5.0, 3.0])≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 + @test Array(optprob.cons_j([5.0, 3.0]))≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 @test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] cons = (x, p) -> [x[1]^2 + x[2]^2] From b9479d0621fcaf46fce84d03cec0c12107691430 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 16 Jul 2024 14:26:02 -0400 Subject: [PATCH 09/33] Use global sparsity detector --- src/OptimizationDISparseExt.jl | 14 +++++++------- test/adtests.jl | 2 -- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index d11e98d..7b9439d 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -10,18 +10,18 @@ using SparseConnectivityTracer, SparseMatrixColorings function generate_sparse_adtype(adtype) if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? 
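+            # TracerSparsityDetector detects a global (input-independent) sparsity pattern, so it can be
+            # reused safely at every iterate, unlike the TracerLocalSparsityDetector it replaces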
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) end elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerLocalSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) end elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) @@ -224,7 +224,7 @@ end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) 
-> first(f.f(θ, p, args...)) - + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing diff --git a/test/adtests.jl b/test/adtests.jl index 51c1dcb..711f65c 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -366,8 +366,6 @@ optprob2.cons_h(sH3, [5.0, 3.0]) [2.8767727327346804 0.2836621681849162; 0.2836621681849162 -6.622738308376736e-9] ] rtol = 1e-4 -using SparseDiffTools - optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoSparseFiniteDiff(), cons = con2_c) From 517c0610df27d00fd0283a64553cde6d1fdb9603 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 16 Jul 2024 14:31:36 -0400 Subject: [PATCH 10/33] format --- ext/OptimizationEnzymeExt.jl | 2 +- ext/OptimizationFiniteDiffExt.jl | 2 +- ext/OptimizationForwardDiffExt.jl | 2 +- ext/OptimizationMTKExt.jl | 11 +-- ext/OptimizationReverseDiffExt.jl | 2 +- ext/OptimizationZygoteExt.jl | 2 +- src/OptimizationDIExt.jl | 23 +++++-- src/OptimizationDISparseExt.jl | 108 +++++++++++++++++++++--------- test/adtests.jl | 14 ++-- 9 files changed, 112 insertions(+), 54 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index c0f6e1a..54e1140 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -535,4 +535,4 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cons_hess_prototype = f.cons_hess_prototype) end -end \ No newline at end of file +end diff --git a/ext/OptimizationFiniteDiffExt.jl b/ext/OptimizationFiniteDiffExt.jl index b0a95a6..ed95f2a 100644 --- a/ext/OptimizationFiniteDiffExt.jl +++ b/ext/OptimizationFiniteDiffExt.jl @@ -2,4 +2,4 @@ module OptimizationFiniteDiffExt using DifferentiationInterface, FiniteDiff -end \ No newline at end of file +end diff --git a/ext/OptimizationForwardDiffExt.jl b/ext/OptimizationForwardDiffExt.jl index 0da99fb..0ff3e5f 100644 --- a/ext/OptimizationForwardDiffExt.jl +++ b/ext/OptimizationForwardDiffExt.jl @@ -2,4 +2,4 @@ module OptimizationForwardDiffExt using DifferentiationInterface, ForwardDiff -end \ No newline at end of file +end diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 08ebef9..7bdda9a 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -52,7 +52,8 @@ function OptimizationBase.instantiate_function( observed = f.observed) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0) p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p @@ -98,7 +99,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca observed = f.observed) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p, +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p, num_cons = 0) p = isnothing(p) ? SciMLBase.NullParameters() : p @@ -143,7 +145,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, observed = f.observed) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::AutoSymbolics, num_cons = 0) p = isnothing(cache.p) ? 
SciMLBase.NullParameters() : cache.p @@ -189,4 +192,4 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca observed = f.observed) end -end \ No newline at end of file +end diff --git a/ext/OptimizationReverseDiffExt.jl b/ext/OptimizationReverseDiffExt.jl index 23ec82d..11e57cf 100644 --- a/ext/OptimizationReverseDiffExt.jl +++ b/ext/OptimizationReverseDiffExt.jl @@ -2,4 +2,4 @@ module OptimizationReverseDiffExt using DifferentiationInterface, ReverseDiff -end \ No newline at end of file +end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index b72bc71..211db2d 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -341,4 +341,4 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cons_expr = f.cons_expr) end -end \ No newline at end of file +end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 66bd88d..2ac030b 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -3,11 +3,15 @@ import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface -import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, - gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian +import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, + prepare_jacobian, + gradient!, hessian!, hvp!, jacobian!, gradient, hessian, + hvp, jacobian using ADTypes, SciMLBase -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, + p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode @@ -100,7 +104,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) @@ -195,8 +201,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca lag_h, f.lag_hess_prototype) end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, + p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) 
-> first(f.f(θ, p, args...)) if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode @@ -291,7 +298,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 7b9439d..2f2adda 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -3,45 +3,86 @@ import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface -import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, - gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian +import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, + prepare_jacobian, + gradient!, hessian!, hvp!, jacobian!, gradient, hessian, + hvp, jacobian using ADTypes using SparseConnectivityTracer, SparseMatrixColorings function generate_sparse_adtype(adtype) - if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) - if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) + if adtype.sparsity_detector isa ADTypes.NoSparsityDetector && + adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + if !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) #make zygote? 
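# A minimal stand-alone sketch of the defaulting performed in this branch, assuming
# ADTypes, SparseConnectivityTracer and SparseMatrixColorings are loaded; `ad` is an
# illustrative variable name, not part of the surrounding function:
ad = AutoSparse(AutoForwardDiff())          # user supplied no detector / coloring
ad = AutoSparse(ad.dense_ad;
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())
# `ad` now carries the tracer-based sparsity detection and greedy coloring that the
# sparse gradient/Jacobian/Hessian preparation relies on.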
+ elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) end - elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) - if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) + elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector && + !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + if !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) end - elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm - adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) - if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) && + adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm + adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + if !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + 
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) end else - if !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode - soadtype = AutoSparse(DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) + if !(adtype.dense_ad isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype isa SciMLBase.NoAD) && + ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) end end - return adtype,soadtype + return adtype, soadtype end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, + p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) adtype, soadtype = generate_sparse_adtype(adtype) @@ -130,7 +171,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) @@ -221,8 +264,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca lag_h, f.lag_hess_prototype) end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, + p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) 
-> first(f.f(θ, p, args...)) adtype, soadtype = generate_sparse_adtype(adtype) @@ -313,7 +357,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p _f = (θ, args...) -> first(f.f(θ, p, args...)) diff --git a/test/adtests.jl b/test/adtests.jl index 711f65c..d0fb1ec 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -361,10 +361,10 @@ optprob2.hess(sH, [5.0, 3.0]) optprob2.cons_j(sJ, [5.0, 3.0]) @test all(isapprox(sJ, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) optprob2.cons_h(sH3, [5.0, 3.0]) -@test Array.(sH3) ≈ [ +@test Array.(sH3)≈[ [2.0 0.0; 0.0 2.0], [2.8767727327346804 0.2836621681849162; 0.2836621681849162 -6.622738308376736e-9] -] rtol = 1e-4 +] rtol=1e-4 optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoSparseFiniteDiff(), @@ -385,7 +385,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 +@test J≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -418,7 +418,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 +@test J≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -451,7 +451,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 +@test J≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -475,7 +475,7 @@ optprob.cons(res, [1.0, 2.0]) @test res ≈ [5.0, 0.682941969615793] J = Array{Float64}(undef, 2, 2) optprob.cons_j(J, [5.0, 3.0]) -@test J ≈ [10.0 6.0; -0.149013 -0.958924] rtol=1e-3 +@test J≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-3 H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] @@ -684,7 +684,7 @@ optprob.hess(H2, x0) OptimizationBase.AutoSparseFiniteDiff(), nothing, 1) - @test optprob.grad(x0) ≈ G1 rtol=1e-4 + @test optprob.grad(x0)≈G1 rtol=1e-4 @test Array(optprob.hess(x0)) ≈ H1 @test optprob.cons(x0) == [0.0] From 7beebae9beb2b77f4bf7fe5617016b8f512d0bff Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 17 Jul 2024 18:36:55 -0400 Subject: [PATCH 11/33] handle reinitcache dispatches better --- Project.toml | 6 +++--- ext/OptimizationEnzymeExt.jl | 4 ++-- src/OptimizationDIExt.jl | 8 ++++++++ src/OptimizationDISparseExt.jl | 8 ++++---- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index c0ffd7d..7e7f5ef 100644 --- a/Project.toml +++ b/Project.toml @@ -30,11 +30,11 @@ ReverseDiff = 
"37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -OptimizationForwardDiffExt = "ForwardDiff" -OptimizationFiniteDiffExt = "FiniteDiff" -OptimizationReverseDiffExt = "ReverseDiff" OptimizationEnzymeExt = "Enzyme" +OptimizationFiniteDiffExt = "FiniteDiff" +OptimizationForwardDiffExt = "ForwardDiff" OptimizationMTKExt = "ModelingToolkit" +OptimizationReverseDiffExt = "ReverseDiff" OptimizationZygoteExt = "Zygote" [compat] diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 54e1140..4439c8b 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -136,7 +136,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end Enzyme.make_zero!(y) Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1] + BatchDuplicated(θ, seeds), Const(p), Const.(args)...) for i in 1:length(θ) if J isa Vector J[i] = Jaccache[i][1] @@ -257,7 +257,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end Enzyme.make_zero!(y) Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1] + BatchDuplicated(θ, seeds), Const(p), Const.(args)...) for i in 1:length(θ) if J isa Vector J[i] = Jaccache[i][1] diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 2ac030b..c820355 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -18,6 +18,8 @@ function OptimizationBase.instantiate_function( soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + else + soadtype = adtype end if f.grad === nothing @@ -115,6 +117,8 @@ function OptimizationBase.instantiate_function( soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + else + soadtype = adtype end if f.grad === nothing @@ -210,6 +214,8 @@ function OptimizationBase.instantiate_function( soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + else + soadtype = adtype end if f.grad === nothing @@ -309,6 +315,8 @@ function OptimizationBase.instantiate_function( soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? 
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) + else + soadtype = adtype end if f.grad === nothing diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 2f2adda..9775a84 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -181,9 +181,9 @@ function OptimizationBase.instantiate_function( adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(res, θ) - gradient!(_f, res, adtype, θ, extras_grad) + gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) @@ -367,9 +367,9 @@ function OptimizationBase.instantiate_function( adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(θ) - gradient(_f, adtype, θ, extras_grad) + gradient(_f, adtype.dense_ad, θ, extras_grad) end else grad = (θ, args...) -> f.grad(θ, p, args...) From 27cd0d466fcd139671278369c6f404038c2003e9 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 19 Jul 2024 13:23:34 -0400 Subject: [PATCH 12/33] try to get downstream running --- .github/workflows/Downstream.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index cfd1eb2..80a7f26 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -17,7 +17,7 @@ jobs: julia-version: [1] os: [ubuntu-latest] package: - - {user: SciML, repo: Optimization.jl, group: Optimization} + - {user: SciML, repo: Optimization.jl, group: All} steps: - uses: actions/checkout@v4 From 35bd067fe66fd94c90ab98db5e7fde8524421399 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 19 Jul 2024 14:02:44 -0400 Subject: [PATCH 13/33] hvpextras preparation with random v --- src/OptimizationDIExt.jl | 20 ++++---------------- src/OptimizationDISparseExt.jl | 20 ++++---------------- 2 files changed, 8 insertions(+), 32 deletions(-) diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index c820355..01fda62 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -43,11 +43,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (H, θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp!(_f, H, soadtype, θ, v, extras_hvp) end else @@ -142,11 +139,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (H, θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp!(_f, H, soadtype, θ, v, extras_hvp) end else @@ -239,11 +233,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (θ, v, args...) 
- if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp(_f, soadtype, θ, v, extras_hvp) end else @@ -340,11 +331,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp(_f, soadtype, θ, v, extras_hvp) end else diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 9775a84..7acbb89 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -108,11 +108,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (H, θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp!(_f, H, soadtype, θ, v, extras_hvp) end else @@ -201,11 +198,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (H, θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp!(_f, H, soadtype, θ, v, extras_hvp) end else @@ -292,11 +286,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp(_f, soadtype, θ, v, extras_hvp) end else @@ -387,11 +378,8 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = nothing + extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) hv = function (θ, v, args...) - if extras_hvp === nothing - global extras_hvp = prepare_hvp(_f, soadtype, x, v) - end hvp(_f, soadtype, θ, v, extras_hvp) end else From c1a5e1f498df63de1ef875e4d9ab320003aad672 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 19 Jul 2024 15:40:45 -0400 Subject: [PATCH 14/33] fix sparse adtype passed to hvp --- src/OptimizationBase.jl | 2 +- src/OptimizationDIExt.jl | 8 ++--- src/OptimizationDISparseExt.jl | 52 +++++++++++++++++++++---------- src/function.jl | 57 ++++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 21 deletions(-) diff --git a/src/OptimizationBase.jl b/src/OptimizationBase.jl index ee9486a..f4cc0be 100644 --- a/src/OptimizationBase.jl +++ b/src/OptimizationBase.jl @@ -31,9 +31,9 @@ Base.length(::NullData) = 0 include("adtypes.jl") include("cache.jl") -include("function.jl") include("OptimizationDIExt.jl") include("OptimizationDISparseExt.jl") +include("function.jl") export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 01fda62..e185d75 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -9,7 +9,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, hvp, jacobian using ADTypes, SciMLBase -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) 
-> first(f.f(θ, p, args...)) @@ -103,7 +103,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 @@ -199,7 +199,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) @@ -295,7 +295,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 7acbb89..375a903 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -15,7 +15,12 @@ function generate_sparse_adtype(adtype) adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) - if !(adtype.dense_ad isa SciMLBase.NoAD) && + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), @@ -32,7 +37,12 @@ function generate_sparse_adtype(adtype) !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm) adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) - if !(adtype.dense_ad isa SciMLBase.NoAD) && + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = TracerSparsityDetector(), + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), @@ -49,7 +59,12 @@ function generate_sparse_adtype(adtype) adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) - if !(adtype.dense_ad isa SciMLBase.NoAD) && + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = GreedyColoringAlgorithm()) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), @@ -63,7 +78,12 @@ function generate_sparse_adtype(adtype) coloring_algorithm = GreedyColoringAlgorithm()) end else - if !(adtype.dense_ad isa SciMLBase.NoAD) 
&& + if adtype.dense_ad isa ADTypes.AutoFiniteDiff + soadtype = AutoSparse( + DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad), + sparsity_detector = adtype.sparsity_detector, + coloring_algorithm = adtype.coloring_algorithm) + elseif !(adtype.dense_ad isa SciMLBase.NoAD) && ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), @@ -80,7 +100,7 @@ function generate_sparse_adtype(adtype) return adtype, soadtype end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) @@ -108,9 +128,9 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) hv = function (H, θ, v, args...) - hvp!(_f, H, soadtype, θ, v, extras_hvp) + hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end else hv = f.hv @@ -168,7 +188,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 @@ -198,9 +218,9 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) hv = function (H, θ, v, args...) - hvp!(_f, H, soadtype, θ, v, extras_hvp) + hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end else hv = f.hv @@ -258,7 +278,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) _f = (θ, args...) -> first(f.f(θ, p, args...)) @@ -286,9 +306,9 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) hv = function (θ, v, args...) - hvp(_f, soadtype, θ, v, extras_hvp) + hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) end else hv = f.hv @@ -348,7 +368,7 @@ function OptimizationBase.instantiate_function( lag_h, f.lag_hess_prototype) end -function OptimizationBase.instantiate_function( +function instantiate_function( f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 @@ -378,9 +398,9 @@ function OptimizationBase.instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) hv = function (θ, v, args...) - hvp(_f, soadtype, θ, v, extras_hvp) + hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) end else hv = f.hv diff --git a/src/function.jl b/src/function.jl index 53ab75e..257d680 100644 --- a/src/function.jl +++ b/src/function.jl @@ -43,6 +43,63 @@ function that is not defined, an error is thrown. For more information on the use of automatic differentiation, see the documentation of the `AbstractADType` types. 
""" +function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD, + p, num_cons = 0) + grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) + hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(x), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(x), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : + [convert.(eltype(x), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, + hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) +end + +function instantiate_function(f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD, + num_cons = 0) + grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...) + hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...) + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(cache.u0), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(cache.u0), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? 
nothing : + [convert.(eltype(cache.u0), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, + hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) +end function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType, p, num_cons = 0) From bfb28b6aecfe1361bc0e9f633eb17951853c507e Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 22 Jul 2024 20:04:46 -0400 Subject: [PATCH 15/33] Move to no closures and no args --- ext/OptimizationEnzymeExt.jl | 378 ++++++++------------------------- ext/OptimizationZygoteExt.jl | 189 +++-------------- src/OptimizationDIExt.jl | 263 +++++------------------ src/OptimizationDISparseExt.jl | 232 ++++---------------- src/function.jl | 3 +- 5 files changed, 213 insertions(+), 852 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 4439c8b..278a2c5 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -8,8 +8,8 @@ import OptimizationBase.ADTypes: AutoEnzyme using Enzyme using Core: Vararg -@inline function firstapply(f::F, θ, p, args...) where {F} - res = f(θ, p, args...) +@inline function firstapply(f::F, θ, p) where {F} + res = f(θ, p) if isa(res, AbstractFloat) res else @@ -23,46 +23,46 @@ function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where {N} Active, Const(f), Enzyme.Duplicated(θ, bθ), - Const(p), - Const.(args)...), + Const(p) + ), return nothing end -function hv_f2_alloc(x, f, p, args...) +function hv_f2_alloc(x, f, p) dx = Enzyme.make_zero(x) Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply, Active, f, Enzyme.Duplicated(x, dx), - Const(p), - Const.(args)...) + Const(p) + ) return dx end function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothing}, num_cons::Int, i::Int, args::Vararg{Any, N}) where {N} res = zeros(eltype(x), num_cons) - fcons(res, x, p, args...) + fcons(res, x, p) return res[i] end function cons_f2(x, dx, fcons, p, num_cons, i, args::Vararg{Any, N}) where {N} Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), - Const(fcons), Const(p), Const(num_cons), Const(i), Const.(args)...) + Const(fcons), Const(p), Const(num_cons), Const(i)) return nothing end function inner_cons_oop( x::Vector{T}, fcons::Function, p::Union{SciMLBase.NullParameters, Nothing}, i::Int, args::Vararg{Any, N}) where {T, N} - return fcons(x, p, args...)[i] + return fcons(x, p)[i] end function cons_f2_oop(x, dx, fcons, p, i, args::Vararg{Any, N}) where {N} Enzyme.autodiff_deferred( Enzyme.Reverse, inner_cons_oop, Active, Enzyme.Duplicated(x, dx), - Const(fcons), Const(p), Const(i), Const.(args)...) + Const(fcons), Const(p), Const(i)) return nothing end @@ -70,24 +70,22 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0) if f.grad === nothing - grad = let - function (res, θ, args...) - Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f.f), - Enzyme.Duplicated(θ, res), - Const(p), - Const.(args)...) 
- end + function grad(res, θ) + Enzyme.make_zero!(res) + Enzyme.autodiff(Enzyme.Reverse, + Const(firstapply), + Active, + Const(f.f), + Enzyme.Duplicated(θ, res), + Const(p) + ) end else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + grad = (G, θ) -> f.grad(G, θ, p) end if f.hess === nothing - function hess(res, θ, args...) + function hess(res, θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) bθ = zeros(eltype(θ), length(θ)) @@ -98,23 +96,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), - Const(p), - Const.(args)...) + Const(p) + ) for i in eachindex(θ) res[i, :] .= vdbθ[i] end end else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + hess = (H, θ) -> f.hess(H, θ, p) end if f.hv === nothing - hv = function (H, θ, v, args...) + function hv(H, θ, v) H .= Enzyme.autodiff( Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), - Const(_f), Const(f.f), Const(p), - Const.(args)...)[1] + Const(_f), Const(f.f), Const(p) + )[1] end else hv = f.hv @@ -123,20 +121,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if f.cons === nothing cons = nothing else - cons = (res, θ, args...) -> f.cons(res, θ, p, args...) + cons = (res, θ) -> f.cons(res, θ, p) end if cons !== nothing && f.cons_j === nothing seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) y = zeros(eltype(x), num_cons) - cons_j = function (J, θ, args...) - for i in 1:num_cons + function cons_j(J, θ) + for i in 1:length(θ) Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p), Const.(args)...) + BatchDuplicated(θ, seeds), Const(p)) for i in 1:length(θ) if J isa Vector J[i] = Jaccache[i][1] @@ -146,11 +144,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end end else - cons_j = (J, θ, args...) -> f.cons_j(J, θ, p, args...) + cons_j = (J, θ) -> f.cons_j(J, θ, p) end if cons !== nothing && f.cons_h === nothing - cons_h = function (res, θ, args...) + function cons_h(res, θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) bθ = zeros(eltype(θ), length(θ)) vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) @@ -166,9 +164,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Const(f.cons), Const(p), Const(num_cons), - Const(i), - Const.(args)... - ) + Const(i)) for j in eachindex(θ) res[i][j, :] .= vdbθ[j] @@ -176,14 +172,25 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end end else - cons_h = (res, θ, args...) -> f.cons_h(res, θ, p, args...) 
+ cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype) + cons_hess_prototype = f.cons_hess_prototype, + lag_h = lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, @@ -191,120 +198,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, adtype::AutoEnzyme, num_cons = 0) p = cache.p + x = cache.u0 - if f.grad === nothing - function grad(res, θ, args...) - Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f.f), - Enzyme.Duplicated(θ, res), - Const(p), - Const.(args)...) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - function hess(res, θ, args...) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - - Enzyme.autodiff(Enzyme.Forward, - inner_grad, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), - Const(f.f), - Const(p), - Const.(args)...) - - for i in eachindex(θ) - res[i, :] .= vdbθ[i] - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - H .= Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), - Const(f.f), Const(p), - Const.(args)...)[1] - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ, args...) -> f.cons(res, θ, p, args...) - end - - if cons !== nothing && f.cons_j === nothing - seeds = Tuple((Array(r) - for r in eachrow(I(length(cache.u0)) * one(eltype(cache.u0))))) - Jaccache = Tuple(zeros(eltype(cache.u0), num_cons) for i in 1:length(cache.u0)) - y = zeros(eltype(cache.u0), num_cons) - cons_j = function (J, θ, args...) - for i in 1:num_cons - Enzyme.make_zero!(Jaccache[i]) - end - Enzyme.make_zero!(y) - Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p), Const.(args)...) - for i in 1:length(θ) - if J isa Vector - J[i] = Jaccache[i][1] - else - copyto!(@view(J[:, i]), Jaccache[i]) - end - end - end - else - cons_j = (J, θ, args...) -> f.cons_j(J, θ, p, args...) - end - - if cons !== nothing && f.cons_h === nothing - cons_h = function (res, θ, args...) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - for i in 1:num_cons - bθ .= zero(eltype(bθ)) - for el in vdbθ - Enzyme.make_zero!(el) - end - Enzyme.autodiff(Enzyme.Forward, - cons_f2, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), - Const(f.cons), - Const(p), - Const(num_cons), - Const(i), - Const.(args)... - ) - - for j in eachindex(θ) - res[i][j, :] .= vdbθ[j] - end - end - end - else - cons_h = (res, θ, args...) -> f.cons_h(res, θ, p, args...) 
- end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, @@ -312,25 +208,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x num_cons = 0) if f.grad === nothing res = zeros(eltype(x), size(x)) - grad = let res = res - function (θ, args...) - Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f.f), - Enzyme.Duplicated(θ, res), - Const(p), - Const.(args)...) - return res - end + function grad(θ) + Enzyme.make_zero!(res) + Enzyme.autodiff(Enzyme.Reverse, + Const(firstapply), + Active, + Const(f.f), + Enzyme.Duplicated(θ, res), + Const(p) + ) + return res end else - grad = (θ, args...) -> f.grad(θ, p, args...) + grad = (θ) -> f.grad(θ, p) end if f.hess === nothing - function hess(θ, args...) + function hess(θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) bθ = zeros(eltype(θ), length(θ)) @@ -341,22 +235,22 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.f), - Const(p), - Const.(args)...) + Const(p) + ) return reduce( vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)]) end else - hess = (θ, args...) -> f.hess(θ, p, args...) + hess = (θ) -> f.hess(θ, p) end if f.hv === nothing - hv = function (θ, v, args...) + function hv(θ, v) Enzyme.autodiff( Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), - Const(_f), Const(f.f), Const(p), - Const.(args)...)[1] + Const(_f), Const(f.f), Const(p) + )[1] end else hv = f.hv @@ -365,14 +259,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if f.cons === nothing cons = nothing else - cons_oop = (θ, args...) -> f.cons(θ, p, args...) + cons_oop = (θ) -> f.cons(θ, p) end if f.cons !== nothing && f.cons_j === nothing seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) - cons_j = function (θ, args...) + function cons_j(θ) J = Enzyme.autodiff( - Enzyme.Forward, f.cons, BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1] + Enzyme.Forward, f.cons, BatchDuplicated(θ, seeds), Const(p))[1] if num_cons == 1 return reduce(vcat, J) else @@ -384,7 +278,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x end if f.cons !== nothing && f.cons_h === nothing - cons_h = function (θ, args...) + function cons_h(θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) bθ = zeros(eltype(θ), length(θ)) vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) @@ -400,9 +294,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x Enzyme.BatchDuplicated(bθ, vdbθ), Const(f.cons), Const(p), - Const(i), - Const.(args)... - ) + Const(i)) for j in eachindex(θ) res[i][j, :] = vdbθ[j] end @@ -410,14 +302,25 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return res end else - cons_h = (θ, args...) -> f.cons_h(θ, p, args...) 
+ cons_h = (θ) -> f.cons_h(θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons_oop, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype) + cons_hess_prototype = f.cons_hess_prototype, + lag_h = lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, @@ -425,114 +328,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, adtype::AutoEnzyme, num_cons = 0) p = cache.p + x = cache.u0 - if f.grad === nothing - res = zeros(eltype(x), size(x)) - grad = let res = res - function (θ, args...) - Enzyme.make_zero!(res) - Enzyme.autodiff(Enzyme.Reverse, - Const(firstapply), - Active, - Const(f.f), - Enzyme.Duplicated(θ, res), - Const(p), - Const.(args)...) - return res - end - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - function hess(θ, args...) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - - Enzyme.autodiff(Enzyme.Forward, - inner_grad, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), - Const(f.f), - Const(p), - Const.(args)...) - - return reduce( - vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)]) - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (θ, v, args...) - Enzyme.autodiff( - Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), - Const(_f), Const(f.f), Const(p), - Const.(args)...)[1] - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons_oop = (θ, args...) -> f.cons(θ, p, args...) - end - - if f.cons !== nothing && f.cons_j === nothing - J = Tuple(zeros(eltype(cache.u0), length(cache.u0)) for i in 1:num_cons) - cons_j = function (θ, args...) - for i in 1:num_cons - Enzyme.make_zero!(J[i]) - end - Enzyme.autodiff( - Enzyme.Forward, f.cons, BatchDuplicated(θ, J), Const(p), Const.(args)...) - return reduce(vcat, reshape.(J, Ref(1), Ref(length(θ)))) - end - else - cons_j = (θ, args...) -> f.cons_j(θ, p, args...) - end - - if f.cons !== nothing && f.cons_h === nothing - cons_h = function (θ, args...) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - res = [zeros(eltype(x), length(θ), length(θ)) for i in 1:num_cons] - for i in 1:num_cons - Enzyme.make_zero!(bθ) - for el in vdbθ - Enzyme.make_zero!(el) - end - Enzyme.autodiff(Enzyme.Forward, - cons_f2_oop, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), - Const(f.cons), - Const(p), - Const(i), - Const.(args)... 
- ) - for j in eachindex(θ) - res[i][j, :] = vdbθ[j] - end - end - return res - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons_oop, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 211db2d..204608c 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -8,9 +8,12 @@ using Zygote, Zygote.ForwardDiff function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoZygote, p, num_cons = 0) - _f = (θ, args...) -> f(θ, p, args...)[1] + function _f(θ, args...) + return f(θ, p, args...)[1] + end + if f.grad === nothing - grad = function (res, θ, args...) + function grad(res, θ, args...) val = Zygote.gradient(x -> _f(x, args...), θ)[1] if val === nothing res .= zero(eltype(θ)) @@ -23,7 +26,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.hess === nothing - hess = function (res, θ, args...) + function hess(res, θ, args...) res .= ForwardDiff.jacobian(θ) do θ Zygote.gradient(x -> _f(x, args...), θ)[1] end @@ -33,7 +36,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.hv === nothing - hv = function (H, θ, v, args...) + function hv(H, θ, v, args...) _θ = ForwardDiff.Dual.(θ, v) res = similar(_θ) grad(res, _θ, args...) @@ -46,12 +49,19 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, if f.cons === nothing cons = nothing else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = Zygote.Buffer(x, num_cons); cons(_res, x); copy(_res)) + function cons(res, θ, args...) + f.cons(res, θ, p, args...) + end + + function cons_oop(x, args...) + _res = Zygote.Buffer(x, num_cons) + cons(_res, x, args...) + copy(_res) + end end if cons !== nothing && f.cons_j === nothing - cons_j = function (J, θ) + function cons_j(J, θ) J .= first(Zygote.jacobian(cons_oop, θ)) end else @@ -62,7 +72,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] cons_h = function (res, θ) for i in 1:num_cons - res[i] .= Zygote.hessian(fncs[i], θ) + res[i] .= ForwardDiff.jacobian(θ) do θ + Zygote.gradient(fncs[i], θ)[1] + end end end else @@ -90,83 +102,10 @@ end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::AutoZygote, num_cons = 0) - _f = (θ, args...) -> f(θ, cache.p, args...)[1] - if f.grad === nothing - grad = function (res, θ, args...) - val = Zygote.gradient(x -> _f(x, args...), θ)[1] - if val === nothing - res .= zero(eltype(θ)) - else - res .= val - end - end - else - grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) - end - - if f.hess === nothing - hess = function (res, θ, args...) - res .= ForwardDiff.jacobian(θ) do θ - Zygote.gradient(x -> _f(x, args...), θ)[1] - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) 
- H .= getindex.(ForwardDiff.partials.(res), 1) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, cache.p) - cons_oop = (x) -> (_res = Zygote.Buffer(x, num_cons); cons(_res, x); copy(_res)) - end - - if cons !== nothing && f.cons_j === nothing - cons_j = function (J, θ) - J .= first(Zygote.jacobian(cons_oop, θ)) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - res[i] .= Zygote.hessian(fncs[i], θ) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) - end + x = cache.u0 + p = cache.p - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) + return instantiate_function(f, x, adtype, p, num_cons) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, @@ -257,88 +196,10 @@ end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, adtype::AutoZygote, num_cons = 0) - _f = (θ, args...) -> f(θ, cache.p, args...)[1] p = cache.p + x = cache.u0 - if f.grad === nothing - grad = function (θ, args...) - val = Zygote.gradient(x -> _f(x, args...), θ)[1] - if val === nothing - return zero(eltype(θ)) - else - return val - end - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - hess = function (θ, args...) - return ForwardDiff.jacobian(θ) do θ - Zygote.gradient(x -> _f(x, args...), θ)[1] - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = grad(_θ, args...) 
- return getindex.(ForwardDiff.partials.(res), 1) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - cons_j = function (θ) - if num_cons > 1 - return first(Zygote.jacobian(cons_oop, θ)) - else - return first(Zygote.jacobian(cons_oop, θ))[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - cons_h = function (θ) - return map(1:num_cons) do i - Zygote.hessian(fncs[i], θ) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) + return instantiate_function(f, x, adtype, p, num_cons) end end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index e185d75..027cb95 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -9,11 +9,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, hvp, jacobian using ADTypes, SciMLBase -function instantiate_function( - f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, - p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - +function generate_adtype(adtype) if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode @@ -21,6 +17,17 @@ function instantiate_function( else soadtype = adtype end + return adtype, soadtype +end + +function instantiate_function( + f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, + p = SciMLBase.NullParameters(), num_cons = 0) + function _f(θ) + return f(θ, p)[1] + end + + adtype, soadtype = generate_adtype(adtype) if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) @@ -28,23 +35,23 @@ function instantiate_function( gradient!(_f, res, adtype, θ, extras_grad) end else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + grad = (G, θ) -> f.grad(G, θ, p) end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) - function hess(res, θ, args...) + function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + hess = (H, θ) -> f.hess(H, θ, p) end if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) - hv = function (H, θ, v, args...) 
+ hv = function (H, θ, v) hvp!(_f, H, soadtype, θ, v, extras_hvp) end else @@ -54,15 +61,22 @@ function instantiate_function( if f.cons === nothing cons = nothing else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + function cons(res, θ) + return f.cons(res, θ, p) + end + + function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (J, θ) + function cons_j(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) @@ -92,6 +106,7 @@ function instantiate_function( else lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = hess_sparsity, @@ -100,7 +115,11 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function instantiate_function( @@ -108,133 +127,42 @@ function instantiate_function( adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 p = cache.p - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode - soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) - else - soadtype = adtype - end - if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) - function grad(res, θ) - gradient!(_f, res, adtype, θ, extras_grad) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) - function hess(res, θ, args...) - hessian!(_f, res, soadtype, θ, extras_hess) - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) - hv = function (H, θ, v, args...) 
- hvp!(_f, H, soadtype, θ, v, extras_hvp) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (J, θ) - jacobian!(cons_oop, J, adtype, θ, extras_jac) - if size(J, 1) == 1 - J = vec(J) - end - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - - function cons_h(H, θ) - for i in 1:num_cons - hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode - soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) - else - soadtype = adtype + function _f(θ) + return f(θ, p)[1] end + adtype, soadtype = generate_adtype(adtype) + if f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) function grad(θ) gradient(_f, adtype, θ, extras_grad) end else - grad = (θ, args...) -> f.grad(θ, p, args...) + grad = (θ) -> f.grad(θ, p) end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(θ, args...) + function hess(θ) hessian(_f, soadtype, θ, extras_hess) end else - hess = (θ, args...) -> f.hess(θ, p, args...) + hess = (θ) -> f.hess(θ, p) end if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) - hv = function (θ, v, args...) 
+ function hv(θ, v) hvp(_f, soadtype, θ, v, extras_hvp) end else @@ -244,16 +172,17 @@ function instantiate_function( if f.cons === nothing cons = nothing else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons + function cons(θ) + return f.cons(θ, p) + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) + extras_jac = prepare_jacobian(cons, adtype, x) cons_j = function (θ) - J = jacobian(cons_oop, adtype, θ, extras_jac) + J = jacobian(cons, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end @@ -266,7 +195,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + fncs = [(x) -> cons(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) function cons_h(θ) @@ -284,6 +213,7 @@ function instantiate_function( else lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = hess_sparsity, @@ -292,7 +222,11 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function instantiate_function( @@ -300,95 +234,6 @@ function instantiate_function( adtype::ADTypes.AbstractADType, num_cons = 0) x = cache.u0 p = cache.p - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode - soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote? - elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode - soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype) - else - soadtype = adtype - end - if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) - function grad(θ) - gradient(_f, adtype, θ, extras_grad) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(θ, args...) - hessian(_f, soadtype, θ, extras_hess) - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) - hv = function (θ, v, args...) 
- hvp(_f, soadtype, θ, v, extras_hvp) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (θ) - J = jacobian(cons_oop, adtype, θ, extras_jac) - if size(J, 1) == 1 - J = vec(J) - end - return J - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - - function cons_h(θ) - H = map(1:num_cons) do i - hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) - end - return H - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 375a903..b0510f5 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -21,7 +21,7 @@ function generate_sparse_adtype(adtype) sparsity_detector = TracerSparsityDetector(), coloring_algorithm = GreedyColoringAlgorithm()) elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), @@ -43,7 +43,7 @@ function generate_sparse_adtype(adtype) sparsity_detector = TracerSparsityDetector(), coloring_algorithm = adtype.coloring_algorithm) elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = TracerSparsityDetector(), @@ -65,7 +65,7 @@ function generate_sparse_adtype(adtype) sparsity_detector = adtype.sparsity_detector, coloring_algorithm = GreedyColoringAlgorithm()) elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, @@ -84,7 +84,7 @@ function generate_sparse_adtype(adtype) sparsity_detector = adtype.sparsity_detector, coloring_algorithm = adtype.coloring_algorithm) elseif !(adtype.dense_ad isa SciMLBase.NoAD) && - ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode + ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode soadtype = AutoSparse( DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()), sparsity_detector = adtype.sparsity_detector, @@ 
-103,7 +103,9 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) + function _f(θ) + return f(θ, p)[1] + end adtype, soadtype = generate_sparse_adtype(adtype) @@ -113,23 +115,23 @@ function instantiate_function( gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) + grad = (G, θ) -> f.grad(G, θ, p) end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(res, θ, args...) + function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) + hess = (H, θ) -> f.hess(H, θ, p) end if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) - hv = function (H, θ, v, args...) + hv = function (H, θ, v) hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end else @@ -139,15 +141,22 @@ function instantiate_function( if f.cons === nothing cons = nothing else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + function cons(res, θ) + f.cons(res, θ, p) + end + + function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (J, θ) + function cons_j(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) @@ -185,7 +194,11 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function instantiate_function( @@ -193,95 +206,16 @@ function instantiate_function( adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - adtype, soadtype = generate_sparse_adtype(adtype) - - if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype.dense_ad, x) - function grad(res, θ) - gradient!(_f, res, adtype.dense_ad, θ, extras_grad) - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) - function hess(res, θ, args...) - hessian!(_f, res, soadtype, θ, extras_hess) - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) - hv = function (H, θ, v, args...) 
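The sparse path assembles its backend from three pieces: a dense AD type, a sparsity detector, and a coloring algorithm. A small sketch of that assembly and of a sparse Hessian computed through it, assuming the defaults wired in here (SparseConnectivityTracer's `TracerSparsityDetector` and SparseMatrixColorings' `GreedyColoringAlgorithm`); the banded objective is illustrative only:

```julia
import DifferentiationInterface
import DifferentiationInterface: prepare_hessian, hessian
using ADTypes, SparseConnectivityTracer, SparseMatrixColorings
import ForwardDiff, ReverseDiff

soadtype = AutoSparse(
    DifferentiationInterface.SecondOrder(AutoForwardDiff(), AutoReverseDiff());
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm()
)

f(x) = sum(abs2, x[2:end] .- x[1:(end - 1)])    # tridiagonal Hessian structure
x = rand(6)
extras_hess = prepare_hessian(f, soadtype, x)   # pattern detection and coloring happen once, here
H = hessian(f, soadtype, x, extras_hess)        # sparse result with the detected structure
```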
- hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (res, θ) -> f.cons(res, θ, p) - cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (J, θ) - jacobian!(cons_oop, J, adtype, θ, extras_jac) - if size(J, 1) == 1 - J = vec(J) - end - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - - function cons_h(H, θ) - for i in 1:num_cons - hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) + function _f(θ) + return f(θ, p)[1] + end adtype, soadtype = generate_sparse_adtype(adtype) @@ -291,23 +225,23 @@ function instantiate_function( gradient(_f, adtype.dense_ad, θ, extras_grad) end else - grad = (θ, args...) -> f.grad(θ, p, args...) + grad = (θ) -> f.grad(θ, p) end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(θ, args...) + function hess(θ) hessian(_f, soadtype, θ, extras_hess) end else - hess = (θ, args...) -> f.hess(θ, p, args...) + hess = (θ) -> f.hess(θ, p) end if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) - hv = function (θ, v, args...) 
+ function hv(θ, v) hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) end else @@ -317,15 +251,16 @@ function instantiate_function( if f.cons === nothing cons = nothing else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons + function cons(θ) + f.cons(θ, p) + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (θ) + function cons_j(θ) J = jacobian(cons_oop, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) @@ -365,7 +300,11 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) end function instantiate_function( @@ -373,89 +312,6 @@ function instantiate_function( adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) x = cache.u0 p = cache.p - _f = (θ, args...) -> first(f.f(θ, p, args...)) - - adtype, soadtype = generate_sparse_adtype(adtype) - - if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype.dense_ad, x) - function grad(θ) - gradient(_f, adtype.dense_ad, θ, extras_grad) - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - hess_sparsity = f.hess_prototype - hess_colors = f.hess_colorvec - if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(θ, args...) - hessian(_f, soadtype, θ, extras_hess) - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) - hv = function (θ, v, args...) 
- hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - cons_jac_prototype = f.cons_jac_prototype - cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) - cons_j = function (θ) - J = jacobian(cons_oop, adtype, θ, extras_jac) - if size(J, 1) == 1 - J = vec(J) - end - return J - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - conshess_sparsity = f.cons_hess_prototype - conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - - function cons_h(θ) - H = map(1:num_cons) do i - hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) - end - return H - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_sparsity, - hess_colorvec = hess_colors, - cons_jac_prototype = cons_jac_prototype, - cons_jac_colorvec = cons_jac_colorvec, - cons_hess_prototype = conshess_sparsity, - cons_hess_colorvec = conshess_colors, - lag_h, f.lag_hess_prototype) + return instantiate_function(f, x, adtype, p, num_cons) end diff --git a/src/function.jl b/src/function.jl index 257d680..63e194f 100644 --- a/src/function.jl +++ b/src/function.jl @@ -72,7 +72,8 @@ function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD observed = f.observed) end -function instantiate_function(f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD, +function instantiate_function( + f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD, num_cons = 0) grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...) hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...) 
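Taken together, the dense path at this point in the series can be exercised end to end roughly as follows. The Rosenbrock objective, the single constraint, and the parameter values are illustrative, modeled on test/adtests.jl; ReverseDiff is imported because `generate_adtype` pairs a forward-mode backend with `AutoReverseDiff` for second-order work.

```julia
using OptimizationBase, SciMLBase, ADTypes
import ForwardDiff, ReverseDiff

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]; return nothing)

x0 = zeros(2)
p = [1.0, 100.0]

optf = OptimizationFunction(rosenbrock, AutoForwardDiff(); cons = cons)
optprob = OptimizationBase.instantiate_function(optf, x0, AutoForwardDiff(), p, 1)

G = zeros(2)
optprob.grad(G, x0)        # gradient filled in place via DifferentiationInterface
H = zeros(2, 2)
optprob.hess(H, x0)        # forward-over-reverse Hessian
J = zeros(1, 2)
optprob.cons_j(J, x0)      # constraint Jacobian
```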
From 591e87c4e7b18e9b947ab773b6b238852beb622f Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 22 Jul 2024 20:21:24 -0400 Subject: [PATCH 16/33] cons_oop oops --- src/OptimizationDISparseExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index b0510f5..0531c5d 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -259,9 +259,9 @@ function instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons_oop, adtype, x) + extras_jac = prepare_jacobian(cons, adtype, x) function cons_j(θ) - J = jacobian(cons_oop, adtype, θ, extras_jac) + J = jacobian(cons, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end @@ -274,7 +274,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + fncs = [(x) -> cons(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) function cons_h(θ) From 862be5856c566017c60a50bb5715078a71b53ada Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 22 Jul 2024 22:40:19 -0400 Subject: [PATCH 17/33] clarify import --- ext/OptimizationEnzymeExt.jl | 4 ++-- ext/OptimizationZygoteExt.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 278a2c5..8f906b2 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -200,7 +200,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, p = cache.p x = cache.u0 - return instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, @@ -330,7 +330,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, p = cache.p x = cache.u0 - return instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 204608c..4453beb 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -105,7 +105,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x = cache.u0 p = cache.p - return instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, @@ -199,7 +199,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, p = cache.p x = cache.u0 - return instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end end From eb742978003fc388dcabbf0c835aacd9001ec8fd Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 30 Jul 2024 23:22:07 -0400 Subject: [PATCH 18/33] generate a sparse lagrange hessian and some more things --- Project.toml | 9 ++-- ext/OptimizationEnzymeExt.jl | 63 +++++++++++++++++++++++----- src/OptimizationBase.jl | 2 + src/OptimizationDIExt.jl | 4 +- src/OptimizationDISparseExt.jl | 75 ++++++++++++++++++++++++++++++---- src/augmented_lagrangian.jl | 13 ++++++ 6 files 
changed, 142 insertions(+), 24 deletions(-) create mode 100644 src/augmented_lagrangian.jl diff --git a/Project.toml b/Project.toml index 0e35b4c..918aadd 100644 --- a/Project.toml +++ b/Project.toml @@ -6,24 +6,25 @@ version = "1.3.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" -SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 8f906b2..862804d 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -17,7 +17,7 @@ using Core: Vararg end end -function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where {N} +function inner_grad(θ, bθ, f, p) Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, @@ -28,6 +28,16 @@ function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where {N} return nothing end +function inner_grad_primal(θ, bθ, f, p) + Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal, + Const(firstapply), + Active, + Const(f), + Enzyme.Duplicated(θ, bθ), + Const(p) + )[2] +end + function hv_f2_alloc(x, f, p) dx = Enzyme.make_zero(x) Enzyme.autodiff_deferred(Enzyme.Reverse, @@ -41,13 +51,13 @@ function hv_f2_alloc(x, f, p) end function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothing}, - num_cons::Int, i::Int, args::Vararg{Any, N}) where {N} + num_cons::Int, i::Int) res = zeros(eltype(x), num_cons) fcons(res, x, p) return res[i] end -function cons_f2(x, dx, fcons, p, num_cons, i, args::Vararg{Any, N}) where {N} +function cons_f2(x, dx, fcons, p, num_cons, i) Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(num_cons), Const(i)) return nothing @@ -55,11 +65,11 @@ end function inner_cons_oop( x::Vector{T}, fcons::Function, p::Union{SciMLBase.NullParameters, Nothing}, - i::Int, args::Vararg{Any, N}) where {T, N} + i::Int) where {T} return fcons(x, p)[i] end -function cons_f2_oop(x, dx, fcons, p, i, args::Vararg{Any, N}) where {N} +function cons_f2_oop(x, dx, fcons, p, i) Enzyme.autodiff_deferred( Enzyme.Reverse, inner_cons_oop, Active, Enzyme.Duplicated(x, dx), Const(fcons), Const(p), Const(i)) @@ -84,17 +94,38 @@ function 
OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, grad = (G, θ) -> f.grad(G, θ, p) end + function fg!(res, θ) + Enzyme.make_zero!(res) + y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + Const(firstapply), + Active, + Const(f.f), + Enzyme.Duplicated(θ, res), + Const(p) + )[2] + return y + end + if f.hess === nothing - function hess(res, θ) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) + vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + bθ = zeros(eltype(x), length(x)) + + if f.hess_prototype === nothing + vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + else + #useless right now, looks like there is no way to tell Enzyme the sparsity pattern? + vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) + end - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) + function hess(res, θ) + Enzyme.make_zero!.(vdθ) + Enzyme.make_zero!(bθ) + Enzyme.make_zero!.(vdbθ) Enzyme.autodiff(Enzyme.Forward, inner_grad, Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), + Enzyme.BatchDuplicatedNoNeed(bθ, vdbθ), Const(f.f), Const(p) ) @@ -107,6 +138,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, hess = (H, θ) -> f.hess(H, θ, p) end + function fgh!(G, H, θ) + + end + if f.hv === nothing function hv(H, θ, v) H .= Enzyme.autodiff( @@ -147,6 +182,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, cons_j = (J, θ) -> f.cons_j(J, θ, p) end + if cons !== nothing && f.cons_vjp === nothing + function cons_vjp(res, θ, v) + + end + else + cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p) + end + if cons !== nothing && f.cons_h === nothing function cons_h(res, θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) diff --git a/src/OptimizationBase.jl b/src/OptimizationBase.jl index 52528bf..8a1c3e4 100644 --- a/src/OptimizationBase.jl +++ b/src/OptimizationBase.jl @@ -19,6 +19,8 @@ import SciMLBase: OptimizationProblem, MaxSense, MinSense, OptimizationStats export ObjSense, MaxSense, MinSense +using FastClosures + struct NullCallback end (x::NullCallback)(args...) 
= false const DEFAULT_CALLBACK = NullCallback() diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 027cb95..2f02994 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -50,7 +50,7 @@ function instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) hv = function (H, θ, v) hvp!(_f, H, soadtype, θ, v, extras_hvp) end @@ -161,7 +161,7 @@ function instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) function hv(θ, v) hvp(_f, soadtype, θ, v, extras_hvp) end diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 0531c5d..b1cb8da 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -125,12 +125,14 @@ function instantiate_function( function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end + hess_sparsity = extras_hess.sparsity + hess_colors = extras_hess.colors else hess = (H, θ) -> f.hess(H, θ, p) end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) hv = function (H, θ, v) hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end @@ -162,6 +164,8 @@ function instantiate_function( J = vec(J) end end + cons_jac_prototype = extras_jac.sparsity + cons_jac_colorvec = extras_jac.colors else cons_j = (J, θ) -> f.cons_j(J, θ, p) end @@ -169,20 +173,75 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - + fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = Vector{DifferentiationInterface.SparseHessianExtras}(undef, length(fncs)) + for ind in 1:num_cons + extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) + end + conshess_sparsity = [sum(sparse, cons)] + conshess_colors = getfield.(extras_cons_hess, Ref(:colors)) function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) + hessian!(fncs[i], H[i], soadtype, θ) end end else cons_h = (res, θ) -> f.cons_h(res, θ, p) end + function lagrangian(x, σ = one(eltype(x))) + θ = x[1:end-num_cons] + λ = x[end-num_cons+1:end] + return σ * _f(θ) + dot(λ, cons_oop(θ)) + end + if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_extras = prepare_hessian(lagrangian, soadtype, vcat(x, ones(eltype(x), num_cons))) + lag_hess_prototype = lag_extras.sparsity + + function lag_h(H::Matrix, θ, σ, λ) + @show size(H) + @show size(θ) + @show size(λ) + if σ == zero(eltype(θ)) + cons_h(H, θ) + H *= λ + else + hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) + end + end + + function lag_h(h, θ, σ, λ) + # @show h + sparseHproto = findnz(lag_extras.sparsity) + H = sparse(sparseHproto[1], sparseHproto[2], zeros(eltype(θ), length(sparseHproto[1]))) + if σ == zero(eltype(θ)) + cons_h(H, θ) + H *= λ + else + hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = σ * H[i, j] + end + end + k = 0 + for λi in λ + if Hi isa SparseMatrixCSC + rows, cols, _ = findnz(Hi) + for (i, j) in zip(rows, cols) + 
if i <= j + k += 1 + h[k] += λi * Hi[i, j] + end + end + end + end + end + end else lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) end @@ -195,7 +254,7 @@ function instantiate_function( cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, lag_h, - lag_hess_prototype = f.lag_hess_prototype, + lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) @@ -240,7 +299,7 @@ function instantiate_function( end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, rand(size(x))) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) function hv(θ, v) hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) end diff --git a/src/augmented_lagrangian.jl b/src/augmented_lagrangian.jl new file mode 100644 index 0000000..88bca00 --- /dev/null +++ b/src/augmented_lagrangian.jl @@ -0,0 +1,13 @@ +function generate_auglag(θ) + x = cache.f(θ, cache.p) + cons_tmp .= zero(eltype(θ)) + cache.f.cons(cons_tmp, θ) + cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] + cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] + opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + if cache.callback(opt_state, x...) + error("Optimization halted by callback.") + end + return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2) +end \ No newline at end of file From 6d4f6d79b0a0f93d6a8b91c2f87f2cbeb1279368 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 2 Aug 2024 10:49:10 -0400 Subject: [PATCH 19/33] hessian of lagrangian --- ext/OptimizationEnzymeExt.jl | 49 +++++++++++++++++++++++-- src/OptimizationDISparseExt.jl | 65 +++++++++++----------------------- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 862804d..c6e3a34 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -3,7 +3,7 @@ module OptimizationEnzymeExt import OptimizationBase, OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.SciMLBase -import OptimizationBase.LinearAlgebra: I +import OptimizationBase.LinearAlgebra: I, dot import OptimizationBase.ADTypes: AutoEnzyme using Enzyme using Core: Vararg @@ -76,6 +76,18 @@ function cons_f2_oop(x, dx, fcons, p, i) return nothing end +function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))::Float64 + res = zeros(eltype(x), length(λ)) + cons(res, x, p) + return σ * _f(x, p) + dot(λ, res) +end + +function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) + Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx), + Const(_f), Const(cons), Const(p), Const(λ), Const(σ)) + return nothing +end + function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0) @@ -219,7 +231,40 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + lag_bθ = zeros(eltype(x), length(x)) + + if f.hess_prototype === nothing + lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + else + #useless right now, looks like there is no way to tell Enzyme the sparsity pattern? 
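The new src/augmented_lagrangian.jl closes over `cache`, `cons_tmp`, `eq_inds`, `ineq_inds`, `λ`, `μ` and `ρ`, so it does not run on its own yet. A self-contained sketch of the value it assembles, with those quantities made explicit arguments; the names and the "≤ 0 is feasible" inequality convention are assumptions for illustration:

```julia
function auglag_value(θ, f, cons!, lcons, ucons, eq_inds, ineq_inds, λ, μ, ρ)
    c = zeros(eltype(θ), length(lcons))
    cons!(c, θ)
    eq = c[eq_inds] .- lcons[eq_inds]          # equality residuals
    ineq = c[ineq_inds] .- ucons[ineq_inds]    # inequality residuals (≤ 0 when feasible)
    return f(θ) +
           sum(λ .* eq .+ (ρ / 2) .* eq .^ 2) +
           (1 / (2ρ)) * sum(max.(0.0, μ .+ ρ .* ineq) .^ 2)
end
```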
+ lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) + end + + function lag_h(h, θ, σ, μ) + Enzyme.make_zero!.(lag_vdθ) + Enzyme.make_zero!(lag_bθ) + Enzyme.make_zero!.(lag_vdbθ) + + Enzyme.autodiff(Enzyme.Forward, + lag_grad, + Enzyme.BatchDuplicated(θ, lag_vdθ), + Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), + Const(lagrangian), + Const(f.f), + Const(f.cons), + Const(p), + Const(σ), + Const(μ) + ) + k = 0 + + for i in eachindex(θ) + vec_lagv = lag_vdbθ[i] + h[k+1:k+i] .= @view(vec_lagv[1:i]) + k += i + end + end else lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index b1cb8da..bf1ebb2 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -174,12 +174,12 @@ function instantiate_function( conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = Vector{DifferentiationInterface.SparseHessianExtras}(undef, length(fncs)) - for ind in 1:num_cons - extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) - end - conshess_sparsity = [sum(sparse, cons)] - conshess_colors = getfield.(extras_cons_hess, Ref(:colors)) + # extras_cons_hess = Vector(undef, length(fncs)) + # for ind in 1:num_cons + # extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) + # end + # conshess_sparsity = getfield.(extras_cons_hess, :sparsity) + # conshess_colors = getfield.(extras_cons_hess, :colors) function cons_h(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ) @@ -189,56 +189,33 @@ function instantiate_function( cons_h = (res, θ) -> f.cons_h(res, θ, p) end - function lagrangian(x, σ = one(eltype(x))) - θ = x[1:end-num_cons] - λ = x[end-num_cons+1:end] - return σ * _f(θ) + dot(λ, cons_oop(θ)) + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) end + lag_hess_prototype = f.lag_hess_prototype if f.lag_h === nothing - lag_extras = prepare_hessian(lagrangian, soadtype, vcat(x, ones(eltype(x), num_cons))) + lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = lag_extras.sparsity - - function lag_h(H::Matrix, θ, σ, λ) - @show size(H) - @show size(θ) - @show size(λ) + + function lag_h(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ else - hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) end end function lag_h(h, θ, σ, λ) - # @show h - sparseHproto = findnz(lag_extras.sparsity) - H = sparse(sparseHproto[1], sparseHproto[2], zeros(eltype(θ), length(sparseHproto[1]))) - if σ == zero(eltype(θ)) - cons_h(H, θ) - H *= λ - else - hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) - k = 0 - rows, cols, _ = findnz(H) - for (i, j) in zip(rows, cols) - if i <= j - k += 1 - h[k] = σ * H[i, j] - end - end - k = 0 - for λi in λ - if Hi isa SparseMatrixCSC - rows, cols, _ = findnz(Hi) - for (i, j) in zip(rows, cols) - if i <= j - k += 1 - h[k] += λi * Hi[i, j] - end - end - end + H = eltype(θ).(lag_hess_prototype) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = H[i, j] end end end From 9813fd45928b0cbd2916738fa3e9378ff5559514 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 4 Aug 2024 09:10:11 -0400 Subject: [PATCH 20/33] vjp and jvp and jacobian based on dimensions --- 
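The `lag_h` added in this commit differentiates the scalar Lagrangian σ f(x) + λᵀc(x) directly instead of summing per-constraint Hessians. A minimal sketch of that idea and of the triangle packing done with `findnz`; the backends and the toy problem are assumptions, and only the overall shape mirrors the diff:

```julia
import DifferentiationInterface
import DifferentiationInterface: prepare_hessian, hessian!
using ADTypes, SparseArrays
using LinearAlgebra: dot
import ForwardDiff, ReverseDiff

f(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
c(x) = [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]

# Scalar Lagrangian for fixed multipliers; its Hessian is what lag_h fills.
lagrangian(x, σ, λ) = σ * f(x) + dot(λ, c(x))

x = [0.5, 0.5]
σ, λ = 1.0, [1.0, 2.0]
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), AutoReverseDiff())
lag_fixed = z -> lagrangian(z, σ, λ)
extras = prepare_hessian(lag_fixed, soadtype, x)
H = zeros(2, 2)
hessian!(lag_fixed, H, soadtype, x, extras)

# Packing one triangle of the symmetric Hessian into a flat vector,
# the way lag_h does with findnz when the solver expects entries, not a matrix.
Hs = sparse(H)
h = Float64[]
for (i, j, hij) in zip(findnz(Hs)...)
    i <= j && push!(h, hij)
end
```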
ext/OptimizationEnzymeExt.jl | 118 +++++++++++++++++++++++++---------- test/adtests.jl | 3 +- 2 files changed, 87 insertions(+), 34 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index c6e3a34..902a2e7 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -24,7 +24,7 @@ function inner_grad(θ, bθ, f, p) Const(f), Enzyme.Duplicated(θ, bθ), Const(p) - ), + ) return nothing end @@ -89,8 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoEnzyme, p, - num_cons = 0) + adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false,) if f.grad === nothing function grad(res, θ) Enzyme.make_zero!(res) @@ -106,16 +105,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, grad = (G, θ) -> f.grad(G, θ, p) end - function fg!(res, θ) - Enzyme.make_zero!(res) - y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, - Const(firstapply), - Active, - Const(f.f), - Enzyme.Duplicated(θ, res), - Const(p) - )[2] - return y + if fg == true + function fg!(res, θ) + Enzyme.make_zero!(res) + y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + Const(firstapply), + Active, + Const(f.f), + Enzyme.Duplicated(θ, res), + Const(p) + )[2] + return y + end + else + fg! = nothing end if f.hess === nothing @@ -130,7 +133,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end function hess(res, θ) - Enzyme.make_zero!.(vdθ) Enzyme.make_zero!(bθ) Enzyme.make_zero!.(vdbθ) @@ -150,8 +152,25 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, hess = (H, θ) -> f.hess(H, θ, p) end - function fgh!(G, H, θ) - + if fgh == true + function fgh!(G, H, θ) + vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) + vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) + + Enzyme.autodiff(Enzyme.Forward, + inner_grad, + Enzyme.BatchDuplicated(θ, vdθ), + Enzyme.BatchDuplicatedNoNeed(G, vdbθ), + Const(f.f), + Const(p) + ) + + for i in eachindex(θ) + H[i, :] .= vdbθ[i] + end + end + else + fgh! 
= nothing end if f.hv === nothing @@ -175,13 +194,19 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) y = zeros(eltype(x), num_cons) + function cons_j(J, θ) for i in 1:length(θ) Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) - Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), - BatchDuplicated(θ, seeds), Const(p)) + if num_cons > length(θ) + Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), + BatchDuplicated(θ, seeds), Const(p)) + else + Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds), + BatchDuplicated(θ, Jaccache), Const(p)) + end for i in 1:length(θ) if J isa Vector J[i] = Jaccache[i][1] @@ -194,35 +219,63 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, cons_j = (J, θ) -> f.cons_j(J, θ, p) end - if cons !== nothing && f.cons_vjp === nothing - function cons_vjp(res, θ, v) - + if cons !== nothing && f.cons_vjp == true + cons_res = zeros(eltype(x), num_cons) + function cons_vjp!(res, θ, v) + Enzyme.make_zero!(res) + Enzyme.make_zero!(cons_res) + + Enzyme.autodiff(Enzyme.Reverse, + f.cons, + Const, + Duplicated(cons_res, v), + Duplicated(θ, res), + Const(p), + ) + end + else + cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p) + end + + if cons !== nothing && f.cons_jvp == true + cons_res = zeros(eltype(x), num_cons) + + function cons_jvp!(res, θ, v) + Enzyme.make_zero!(res) + Enzyme.make_zero!(cons_res) + + Enzyme.autodiff(Enzyme.Forward, + f.cons, + Duplicated(cons_res, res), + Duplicated(θ, v), + Const(p), + ) end else - cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p) + cons_vjp! = nothing end if cons !== nothing && f.cons_h === nothing + cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + cons_bθ = zeros(eltype(x), length(x)) + cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + function cons_h(res, θ) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) + Enzyme.make_zero!(cons_bθ) + Enzyme.make_zero!.(cons_vdbθ) + for i in 1:num_cons - bθ .= zero(eltype(bθ)) - for el in vdbθ - Enzyme.make_zero!(el) - end Enzyme.autodiff(Enzyme.Forward, cons_f2, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), + Enzyme.BatchDuplicated(θ, cons_vdθ), + Enzyme.BatchDuplicated(bθ, cons_vdbθ), Const(f.cons), Const(p), Const(num_cons), Const(i)) for j in eachindex(θ) - res[i][j, :] .= vdbθ[j] + res[i][j, :] .= cons_vdbθ[j] end end end @@ -242,7 +295,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end function lag_h(h, θ, σ, μ) - Enzyme.make_zero!.(lag_vdθ) Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) diff --git a/test/adtests.jl b/test/adtests.jl index d0fb1ec..c743bbf 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -26,7 +26,7 @@ H2 = Array{Float64}(undef, 2, 2) g!(G1, x0) h!(H1, x0) -cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]) +cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]; return nothing) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoModelingToolkit(), cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoModelingToolkit(), @@ -47,6 +47,7 @@ optprob.cons_h(H3, x0) function con2_c(res, x, p) res .= [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]] + return 
nothing end optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoModelingToolkit(), From f17c54c42d02a866c6a70c23e51a96543aa19f71 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 4 Aug 2024 16:24:01 -0400 Subject: [PATCH 21/33] jvp vjp with DI --- ext/OptimizationZygoteExt.jl | 202 +-------------------------------- src/OptimizationDIExt.jl | 87 ++++++++++++-- src/OptimizationDISparseExt.jl | 65 +++++++++-- 3 files changed, 133 insertions(+), 221 deletions(-) diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 4453beb..2cf6ac1 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -1,205 +1,5 @@ module OptimizationZygoteExt -import OptimizationBase -import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.ADTypes: AutoZygote -using Zygote, Zygote.ForwardDiff - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoZygote, p, - num_cons = 0) - function _f(θ, args...) - return f(θ, p, args...)[1] - end - - if f.grad === nothing - function grad(res, θ, args...) - val = Zygote.gradient(x -> _f(x, args...), θ)[1] - if val === nothing - res .= zero(eltype(θ)) - else - res .= val - end - end - else - grad = (G, θ, args...) -> f.grad(G, θ, p, args...) - end - - if f.hess === nothing - function hess(res, θ, args...) - res .= ForwardDiff.jacobian(θ) do θ - Zygote.gradient(x -> _f(x, args...), θ)[1] - end - end - else - hess = (H, θ, args...) -> f.hess(H, θ, p, args...) - end - - if f.hv === nothing - function hv(H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = similar(_θ) - grad(res, _θ, args...) - H .= getindex.(ForwardDiff.partials.(res), 1) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - function cons(res, θ, args...) - f.cons(res, θ, p, args...) - end - - function cons_oop(x, args...) - _res = Zygote.Buffer(x, num_cons) - cons(_res, x, args...) - copy(_res) - end - end - - if cons !== nothing && f.cons_j === nothing - function cons_j(J, θ) - J .= first(Zygote.jacobian(cons_oop, θ)) - end - else - cons_j = (J, θ) -> f.cons_j(J, θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - cons_h = function (res, θ) - for i in 1:num_cons - res[i] .= ForwardDiff.jacobian(θ) do θ - Zygote.gradient(fncs[i], θ)[1] - end - end - end - else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) - end - - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, - cache::OptimizationBase.ReInitCache, - adtype::AutoZygote, num_cons = 0) - x = cache.u0 - p = cache.p - - return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoZygote, p, - num_cons = 0) - _f = (θ, args...) -> f(θ, p, args...)[1] - if f.grad === nothing - grad = function (θ, args...) 
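For orientation, this is what the new `cons_vjp!`/`cons_jvp!` callbacks and the dimension-based Jacobian mode switch compute, written against a plain ForwardDiff Jacobian; the constraint function is illustrative, modeled on `con2_c` in the tests:

```julia
import ForwardDiff

c(x) = [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
x = [1.0, 2.0]
J = ForwardDiff.jacobian(c, x)   # full m×n Jacobian, m = 2 constraints, n = 2 inputs

v = [1.0, 0.0]                   # seed in constraint space
w = [0.0, 1.0]                   # seed in parameter space
vjp = J' * v                     # what cons_vjp! returns without materialising J
jvp = J * w                      # what cons_jvp! returns without materialising J

# Mode choice for the full Jacobian: forward mode costs roughly one sweep per input,
# reverse mode one sweep per output, hence forward when num_cons > length(x).
```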
- val = Zygote.gradient(x -> _f(x, args...), θ)[1] - if val === nothing - return zero(eltype(θ)) - else - return val - end - end - else - grad = (θ, args...) -> f.grad(θ, p, args...) - end - - if f.hess === nothing - hess = function (θ, args...) - return ForwardDiff.jacobian(θ) do θ - return Zygote.gradient(x -> _f(x, args...), θ)[1] - end - end - else - hess = (θ, args...) -> f.hess(θ, p, args...) - end - - if f.hv === nothing - hv = function (H, θ, v, args...) - _θ = ForwardDiff.Dual.(θ, v) - res = grad(_θ, args...) - return getindex.(ForwardDiff.partials.(res), 1) - end - else - hv = f.hv - end - - if f.cons === nothing - cons = nothing - else - cons = (θ) -> f.cons(θ, p) - cons_oop = cons - end - - if cons !== nothing && f.cons_j === nothing - cons_j = function (θ) - if num_cons > 1 - return first(Zygote.jacobian(cons_oop, θ)) - else - return first(Zygote.jacobian(cons_oop, θ))[1, :] - end - end - else - cons_j = (θ) -> f.cons_j(θ, p) - end - - if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] - cons_h = function (θ) - return map(1:num_cons) do i - Zygote.hessian(fncs[i], θ) - end - end - else - cons_h = (θ) -> f.cons_h(θ, p) - end - - if f.lag_h === nothing - lag_h = nothing # Consider implementing this - else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - end - - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = f.hess_prototype, - cons_jac_prototype = f.cons_jac_prototype, - cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, - lag_hess_prototype = f.lag_hess_prototype, - sys = f.sys, - expr = f.expr, - cons_expr = f.cons_expr) -end - -function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, - cache::OptimizationBase.ReInitCache, - adtype::AutoZygote, num_cons = 0) - p = cache.p - x = cache.u0 - - return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) -end +using DifferentiationInterface, Zygote end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 2f02994..a975ccc 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -22,7 +22,8 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, - p = SciMLBase.NullParameters(), num_cons = 0) + p = SciMLBase.NullParameters(), num_cons = 0; + fg = false, fgh = false, cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] end @@ -38,6 +39,13 @@ function instantiate_function( grad = (G, θ) -> f.grad(G, θ, p) end + if fg == true + function fg!(res, θ) + (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) + return y + end + end + hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing @@ -49,6 +57,13 @@ function instantiate_function( hess = (H, θ) -> f.hess(H, θ, p) end + if fgh == true + function fgh!(G, H, θ) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + return y + end + end + if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) hv = function (H, θ, v) @@ -65,10 +80,18 @@ function instantiate_function( return f.cons(res, θ, p) end - function cons_oop(x) - _res = zeros(eltype(x), num_cons) - cons(_res, x) - return _res + if adtype isa AutoZygote && Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote") in keys(Base.loaded_modules) + function cons_oop(x) + _res = Zygote.Buffer(x, num_cons) + cons(_res, x) + copy(_res) + end + else + 
function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end end end @@ -86,6 +109,24 @@ function instantiate_function( cons_j = (J, θ) -> f.cons_j(J, θ, p) end + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons_oop, adtype, x) + function cons_vjp!(J, θ, v) + pullback!(cons_oop, J, adtype, θ, v, extras_pullback) + end + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons_oop, adtype, x) + function cons_jvp!(J, θ, v) + pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward) + end + else + cons_jvp! = nothing + end + conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing @@ -101,14 +142,44 @@ function instantiate_function( cons_h = (res, θ) -> f.cons_h(res, θ, p) end + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) + end + + lag_hess_prototype = f.lag_hess_prototype + if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_extras = prepare_hessian(lagrangian, soadtype, x) + lag_hess_prototype = zeros(Bool, length(x), length(x)) + + function lag_h(H::AbstractMatrix, θ, σ, λ) + if σ == zero(eltype(θ)) + cons_h(H, θ) + H *= λ + else + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + end + end + + function lag_h(h, θ, σ, λ) + H = eltype(θ).(lag_hess_prototype) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = H[i, j] + end + end + end else lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) end return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + cons = cons, cons_j = cons_j, cons_h = cons_h, + cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, cons_jac_prototype = cons_jac_prototype, @@ -116,7 +187,7 @@ function instantiate_function( cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, lag_h, - lag_hess_prototype = f.lag_hess_prototype, + lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index bf1ebb2..1b1335d 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -102,7 +102,8 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, - p = SciMLBase.NullParameters(), num_cons = 0) + p = SciMLBase.NullParameters(), num_cons = 0; + cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] end @@ -118,6 +119,13 @@ function instantiate_function( grad = (G, θ) -> f.grad(G, θ, p) end + if fg == true + function fg!(res, θ) + (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) + return y + end + end + hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing @@ -131,6 +139,13 @@ function instantiate_function( hess = (H, θ) -> f.hess(H, θ, p) end + if fgh == true + function fgh!(G, H, θ) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + return y + end + end + if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) hv = function (H, θ, v) @@ -147,10 +162,18 @@ function 
instantiate_function( f.cons(res, θ, p) end - function cons_oop(x) - _res = zeros(eltype(x), num_cons) - cons(_res, x) - return _res + if adtype.dense_ad isa AutoZygote && Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote") in keys(Base.loaded_modules) + function cons_oop(x) + _res = Zygote.Buffer(x, num_cons) + cons(_res, x) + copy(_res) + end + else + function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end end end @@ -170,19 +193,37 @@ function instantiate_function( cons_j = (J, θ) -> f.cons_j(J, θ, p) end + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons_oop, adtype, x) + function cons_vjp!(J, θ, v) + pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback) + end + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons_oop, adtype, x) + function cons_jvp!(J, θ, v) + pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward) + end + else + cons_jvp! = nothing + end + conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] - # extras_cons_hess = Vector(undef, length(fncs)) - # for ind in 1:num_cons - # extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) - # end - # conshess_sparsity = getfield.(extras_cons_hess, :sparsity) - # conshess_colors = getfield.(extras_cons_hess, :colors) + extras_cons_hess = Vector(undef, length(fncs)) + for ind in 1:num_cons + extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) + end + conshess_sparsity = getfield.(extras_cons_hess, :sparsity) + conshess_colors = getfield.(extras_cons_hess, :colors) function cons_h(H, θ) for i in 1:num_cons - hessian!(fncs[i], H[i], soadtype, θ) + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end else From e79d4cb9d40e0dd41c5a3d8682fbb52a35e878df Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 5 Aug 2024 18:54:48 -0400 Subject: [PATCH 22/33] Inplace all passing --- ext/OptimizationEnzymeExt.jl | 38 +++-- ext/OptimizationZygoteExt.jl | 296 ++++++++++++++++++++++++++++++++- src/OptimizationDIExt.jl | 30 ++-- src/OptimizationDISparseExt.jl | 31 ++-- test/adtests.jl | 77 +++++---- 5 files changed, 380 insertions(+), 92 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 902a2e7..37dd98b 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -191,27 +191,40 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if cons !== nothing && f.cons_j === nothing - seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) - Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + if num_cons > length(x) + seeds = Enzyme.onehot(x) + Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + else + seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) + Jaccache = Tuple(zero(x) for i in 1:num_cons) + end + y = zeros(eltype(x), num_cons) function cons_j(J, θ) - for i in 1:length(θ) + for i in eachindex(Jaccache) Enzyme.make_zero!(Jaccache[i]) end Enzyme.make_zero!(y) if num_cons > length(θ) Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) + for i in eachindex(θ) + if J isa Vector + J[i] = Jaccache[i][1] + else + copyto!(@view(J[:, i]), Jaccache[i]) + end + end else Enzyme.autodiff(Enzyme.Reverse, 
f.cons, BatchDuplicated(y, seeds), BatchDuplicated(θ, Jaccache), Const(p)) - end - for i in 1:length(θ) - if J isa Vector - J[i] = Jaccache[i][1] - else - copyto!(@view(J[:, i]), Jaccache[i]) + for i in 1:num_cons + if J isa Vector + J .= Jaccache[1] + else + copyto!(@view(J[i, :]), Jaccache[i]) + end end end end @@ -261,14 +274,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) function cons_h(res, θ) - Enzyme.make_zero!(cons_bθ) - Enzyme.make_zero!.(cons_vdbθ) - for i in 1:num_cons + Enzyme.make_zero!(cons_bθ) + Enzyme.make_zero!.(cons_vdbθ) Enzyme.autodiff(Enzyme.Forward, cons_f2, Enzyme.BatchDuplicated(θ, cons_vdθ), - Enzyme.BatchDuplicated(bθ, cons_vdbθ), + Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), Const(p), Const(num_cons), diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 2cf6ac1..0f78c30 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -1,5 +1,299 @@ module OptimizationZygoteExt -using DifferentiationInterface, Zygote +using OptimizationBase +import OptimizationBase.ArrayInterface +import OptimizationBase.SciMLBase: OptimizationFunction +import OptimizationBase.LinearAlgebra: I +import DifferentiationInterface +import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, + prepare_jacobian, + gradient!, hessian!, hvp!, jacobian!, gradient, hessian, + hvp, jacobian +using ADTypes, SciMLBase +import Zygote + + +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote, + p = SciMLBase.NullParameters(), num_cons = 0; + fg = false, fgh = false, cons_vjp = false, cons_jvp = false) + function _f(θ) + return f(θ, p)[1] + end + + adtype, soadtype = OptimizationBase.generate_adtype(adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(res, θ) + gradient!(_f, res, adtype, θ, extras_grad) + end + else + grad = (G, θ) -> f.grad(G, θ, p) + end + + if fg == true + function fg!(res, θ) + (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) + return y + end + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) + function hess(res, θ) + hessian!(_f, res, soadtype, θ, extras_hess) + end + else + hess = (H, θ) -> f.hess(H, θ, p) + end + + if fgh == true + function fgh!(G, H, θ) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + return y + end + end + + if f.hv === nothing + extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) + hv = function (H, θ, v) + hvp!(_f, H, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + function cons(res, θ) + return f.cons(res, θ, p) + end + + function cons_oop(x) + _res = Zygote.Buffer(x, num_cons) + cons(_res, x) + return copy(_res) + end + + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = prepare_jacobian(cons_oop, adtype, x) + function cons_j(J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons_oop, adtype, x) + function cons_vjp!(J, θ, v) + pullback!(cons_oop, J, adtype, θ, v, 
extras_pullback) + end + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons_oop, adtype, x) + function cons_jvp!(J, θ, v) + pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward) + end + else + cons_jvp! = nothing + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) + end + + lag_hess_prototype = f.lag_hess_prototype + + if f.lag_h === nothing + lag_extras = prepare_hessian(lagrangian, soadtype, x) + lag_hess_prototype = zeros(Bool, length(x), length(x)) + + function lag_h(H::AbstractMatrix, θ, σ, λ) + if σ == zero(eltype(θ)) + cons_h(H, θ) + H *= λ + else + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + end + end + + function lag_h(h, θ, σ, λ) + H = eltype(θ).(lag_hess_prototype) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = H[i, j] + end + end + end + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, + lag_hess_prototype = lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) +end + +function OptimizationBase.instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoZygote, num_cons = 0) + x = cache.u0 + p = cache.p + + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) +end + +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, x, adtype::ADTypes.AutoZygote, + p = SciMLBase.NullParameters(), num_cons = 0) + function _f(θ) + return f(θ, p)[1] + end + + adtype, soadtype = OptimizationBase.generate_adtype(adtype) + + if f.grad === nothing + extras_grad = prepare_gradient(_f, adtype, x) + function grad(θ) + gradient(_f, adtype, θ, extras_grad) + end + else + grad = (θ) -> f.grad(θ, p) + end + + hess_sparsity = f.hess_prototype + hess_colors = f.hess_colorvec + if f.hess === nothing + extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + function hess(θ) + hessian(_f, soadtype, θ, extras_hess) + end + else + hess = (θ) -> f.hess(θ, p) + end + + if f.hv === nothing + extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) + function hv(θ, v) + hvp(_f, soadtype, θ, v, extras_hvp) + end + else + hv = f.hv + end + + if f.cons === nothing + cons = nothing + else + function cons(θ) + return f.cons(θ, p) + end + end + + cons_jac_prototype = f.cons_jac_prototype + cons_jac_colorvec = f.cons_jac_colorvec + if cons !== nothing && f.cons_j === nothing + extras_jac = 
prepare_jacobian(cons, adtype, x) + cons_j = function (θ) + J = jacobian(cons, adtype, θ, extras_jac) + if size(J, 1) == 1 + J = vec(J) + end + return J + end + else + cons_j = (θ) -> f.cons_j(θ, p) + end + + conshess_sparsity = f.cons_hess_prototype + conshess_colors = f.cons_hess_colorvec + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons(x)[i] for i in 1:num_cons] + extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) + + function cons_h(θ) + H = map(1:num_cons) do i + hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) + end + return H + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + end + + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + hess_prototype = hess_sparsity, + hess_colorvec = hess_colors, + cons_jac_prototype = cons_jac_prototype, + cons_jac_colorvec = cons_jac_colorvec, + cons_hess_prototype = conshess_sparsity, + cons_hess_colorvec = conshess_colors, + lag_h, + lag_hess_prototype = f.lag_hess_prototype, + sys = f.sys, + expr = f.expr, + cons_expr = f.cons_expr) +end + +function OptimizationBase.instantiate_function( + f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoZygote, num_cons = 0) + x = cache.u0 + p = cache.p + + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) +end + end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index a975ccc..dbacfaa 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -39,7 +39,7 @@ function instantiate_function( grad = (G, θ) -> f.grad(G, θ, p) end - if fg == true + if fg == true && f.fg === nothing function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) return y @@ -57,7 +57,7 @@ function instantiate_function( hess = (H, θ) -> f.hess(H, θ, p) end - if fgh == true + if fgh == true && f.fgh !== nothing function fgh!(G, H, θ) (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) return y @@ -80,18 +80,14 @@ function instantiate_function( return f.cons(res, θ, p) end - if adtype isa AutoZygote && Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote") in keys(Base.loaded_modules) - function cons_oop(x) - _res = Zygote.Buffer(x, num_cons) - cons(_res, x) - copy(_res) - end - else - function cons_oop(x) - _res = zeros(eltype(x), num_cons) - cons(_res, x) - return _res - end + function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) end end @@ -142,13 +138,9 @@ function instantiate_function( cons_h = (res, θ) -> f.cons_h(res, θ, p) end - function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) - return σ * _f(x) + dot(λ, cons_oop(x)) - end - lag_hess_prototype = f.lag_hess_prototype - if f.lag_h === nothing + if cons !== nothing && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = zeros(Bool, length(x), length(x)) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 1b1335d..2599914 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -103,6 +103,7 @@ end function instantiate_function( f::OptimizationFunction{true}, x, 
adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0; + fg = false, fgh = false, cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] @@ -119,7 +120,7 @@ function instantiate_function( grad = (G, θ) -> f.grad(G, θ, p) end - if fg == true + if fg == true && f.fg !== nothing function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) return y @@ -139,7 +140,7 @@ function instantiate_function( hess = (H, θ) -> f.hess(H, θ, p) end - if fgh == true + if fgh == true && f.fgh !== nothing function fgh!(G, H, θ) (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) return y @@ -162,18 +163,14 @@ function instantiate_function( f.cons(res, θ, p) end - if adtype.dense_ad isa AutoZygote && Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote") in keys(Base.loaded_modules) - function cons_oop(x) - _res = Zygote.Buffer(x, num_cons) - cons(_res, x) - copy(_res) - end - else - function cons_oop(x) - _res = zeros(eltype(x), num_cons) - cons(_res, x) - return _res - end + function cons_oop(x) + _res = zeros(eltype(x), num_cons) + cons(_res, x) + return _res + end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) end end @@ -230,12 +227,8 @@ function instantiate_function( cons_h = (res, θ) -> f.cons_h(res, θ, p) end - function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) - return σ * _f(x) + dot(λ, cons_oop(x)) - end - lag_hess_prototype = f.lag_hess_prototype - if f.lag_h === nothing + if cons !== nothing && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = lag_extras.sparsity diff --git a/test/adtests.jl b/test/adtests.jl index c743bbf..196ff34 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -72,46 +72,43 @@ optprob.cons_h(H3, x0) G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) -if VERSION >= v"1.9" - optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = cons) - optprob = OptimizationBase.instantiate_function( - optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 1) - optprob.grad(G2, x0) - @test G1 == G2 - optprob.hess(H2, x0) - @test H1 == H2 - res = Array{Float64}(undef, 1) - optprob.cons(res, x0) - @test res == [0.0] - J = Array{Float64}(undef, 2) - optprob.cons_j(J, [5.0, 3.0]) - @test J == [10.0, 6.0] - H3 = [Array{Float64}(undef, 2, 2)] - optprob.cons_h(H3, x0) - @test H3 == [[2.0 0.0; 0.0 2.0]] - - G2 = Array{Float64}(undef, 2) - H2 = Array{Float64}(undef, 2, 2) - - optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = con2_c) - optprob = OptimizationBase.instantiate_function( - optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 2) - optprob.grad(G2, x0) - @test G1 == G2 - optprob.hess(H2, x0) - @test H1 == H2 - res = Array{Float64}(undef, 2) - optprob.cons(res, x0) - @test res == [0.0, 0.0] - J = Array{Float64}(undef, 2, 2) - optprob.cons_j(J, [5.0, 3.0]) - @test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) - H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] - optprob.cons_h(H3, x0) - @test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] -end +optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = cons) +optprob = OptimizationBase.instantiate_function( + optf, x0, OptimizationBase.AutoEnzyme(), + nothing, 1) +optprob.grad(G2, x0) +@test G1 == G2 +optprob.hess(H2, x0) +@test H1 == H2 +res = Array{Float64}(undef, 1) 
+optprob.cons(res, x0) +@test res == [0.0] +J = Array{Float64}(undef, 2) +optprob.cons_j(J, [5.0, 3.0]) +@test J == [10.0, 6.0] +H3 = [Array{Float64}(undef, 2, 2)] +optprob.cons_h(H3, x0) +@test H3 == [[2.0 0.0; 0.0 2.0]] +G2 = Array{Float64}(undef, 2) +H2 = Array{Float64}(undef, 2, 2) +optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = con2_c) +optprob = OptimizationBase.instantiate_function( + optf, x0, OptimizationBase.AutoEnzyme(), + nothing, 2) +optprob.grad(G2, x0) +@test G1 == G2 +optprob.hess(H2, x0) +@test H1 == H2 +res = Array{Float64}(undef, 2) +optprob.cons(res, x0) +@test res == [0.0, 0.0] +J = Array{Float64}(undef, 2, 2) +optprob.cons_j(J, [5.0, 3.0]) +@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] +optprob.cons_h(H3, x0) +@test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] + G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) From fe1b709c7cf12c762b2aa711c472506d3f86c347 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 5 Aug 2024 22:01:55 -0400 Subject: [PATCH 23/33] Oop improvements --- ext/OptimizationEnzymeExt.jl | 198 ++++++++++++++++++++++++++++------- ext/OptimizationZygoteExt.jl | 153 +++++++++++++++++++-------- src/OptimizationDIExt.jl | 74 +++++++++++-- src/augmented_lagrangian.jl | 2 +- test/adtests.jl | 1 - 5 files changed, 338 insertions(+), 90 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 37dd98b..08242be 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -89,7 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false,) + adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false) if f.grad === nothing function grad(res, θ) Enzyme.make_zero!(res) @@ -198,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) Jaccache = Tuple(zero(x) for i in 1:num_cons) end - + y = zeros(eltype(x), num_cons) function cons_j(J, θ) @@ -243,7 +243,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Const, Duplicated(cons_res, v), Duplicated(θ, res), - Const(p), + Const(p) ) end else @@ -261,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, f.cons, Duplicated(cons_res, res), Duplicated(θ, v), - Const(p), + Const(p) ) end else @@ -325,7 +325,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, for i in eachindex(θ) vec_lagv = lag_vdbθ[i] - h[k+1:k+i] .= @view(vec_lagv[1:i]) + h[(k + 1):(k + i)] .= @view(vec_lagv[1:i]) k += i end end @@ -356,8 +356,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoEnzyme, p, - num_cons = 0) + adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false) if f.grad === nothing res = zeros(eltype(x), size(x)) function grad(θ) @@ -375,12 +374,31 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x grad = (θ) -> f.grad(θ, p) end + if fg == true + res_fg = zeros(eltype(x), size(x)) + function fg!(θ) + Enzyme.make_zero!(res_fg) + y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, + Const(firstapply), + Active, + Const(f.f), + 
Enzyme.Duplicated(θ, res_fg), + Const(p) + )[2] + return y, res + end + else + fg! = nothing + end + if f.hess === nothing - function hess(θ) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) + vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + bθ = zeros(eltype(x), length(x)) + vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) + function hess(θ) + Enzyme.make_zero!(bθ) + Enzyme.make_zero!.(vdbθ) Enzyme.autodiff(Enzyme.Forward, inner_grad, @@ -397,9 +415,37 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x hess = (θ) -> f.hess(θ, p) end + if fgh == true + vdθ_fgh = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + vdbθ_fgh = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + G_fgh = zeros(eltype(x), length(x)) + H_fgh = zeros(eltype(x), length(x), length(x)) + + function fgh!(θ) + Enzyme.make_zero!(G_fgh) + Enzyme.make_zero!(H_fgh) + Enzyme.make_zero!.(vdbθ_fgh) + + Enzyme.autodiff(Enzyme.Forward, + inner_grad, + Enzyme.BatchDuplicated(θ, vdθ_fgh), + Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh), + Const(f.f), + Const(p) + ) + + for i in eachindex(θ) + H_fgh[i, :] .= vdbθ_fgh[i] + end + return G_fgh, H_fgh + end + else + fgh! = nothing + end + if f.hv === nothing function hv(θ, v) - Enzyme.autodiff( + return Enzyme.autodiff( Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(_f), Const(f.f), Const(p) )[1] @@ -411,60 +457,136 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x if f.cons === nothing cons = nothing else - cons_oop = (θ) -> f.cons(θ, p) + function cons(θ) + return f.cons(θ, p) + end end - if f.cons !== nothing && f.cons_j === nothing - seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + if cons !== nothing && f.cons_j === nothing + seeds = Enzyme.onehot(x) + Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + function cons_j(θ) - J = Enzyme.autodiff( - Enzyme.Forward, f.cons, BatchDuplicated(θ, seeds), Const(p))[1] - if num_cons == 1 - return reduce(vcat, J) + for i in eachindex(Jaccache) + Enzyme.make_zero!(Jaccache[i]) + end + y, Jaccache = Enzyme.autodiff(Enzyme.Forward, f.cons, Duplicated, + BatchDuplicated(θ, seeds), Const(p)) + if size(y, 1) == 1 + return reduce(vcat, Jaccache) else - return reduce(hcat, J) + return reduce(hcat, Jaccache) end end else cons_j = (θ) -> f.cons_j(θ, p) end - if f.cons !== nothing && f.cons_h === nothing + if cons !== nothing && f.cons_vjp == true + res_vjp = zeros(eltype(x), size(x)) + cons_vjp_res = zeros(eltype(x), num_cons) + + function cons_vjp(θ, v) + Enzyme.make_zero!(res_vjp) + Enzyme.make_zero!(cons_vjp_res) + + Enzyme.autodiff(Enzyme.Reverse, + f.cons, + Const, + Duplicated(cons_vjp_res, v), + Duplicated(θ, res_vjp), + Const(p) + ) + return res_vjp + end + else + cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p) + end + + if cons !== nothing && f.cons_jvp == true + res_jvp = zeros(eltype(x), size(x)) + cons_jvp_res = zeros(eltype(x), num_cons) + + function cons_jvp(θ, v) + Enzyme.make_zero!(res_jvp) + Enzyme.make_zero!(cons_jvp_res) + + Enzyme.autodiff(Enzyme.Forward, + f.cons, + Duplicated(cons_jvp_res, res_jvp), + Duplicated(θ, v), + Const(p) + ) + return res_jvp + end + else + cons_jvp = nothing + end + + if cons !== nothing && f.cons_h === nothing + cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + 
cons_bθ = zeros(eltype(x), length(x)) + cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + function cons_h(θ) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) - bθ = zeros(eltype(θ), length(θ)) - vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) - res = [zeros(eltype(x), length(θ), length(θ)) for i in 1:num_cons] - for i in 1:num_cons - Enzyme.make_zero!(bθ) - for el in vdbθ - Enzyme.make_zero!(el) - end + return map(1:num_cons) do i + Enzyme.make_zero!(cons_bθ) + Enzyme.make_zero!.(cons_vdbθ) Enzyme.autodiff(Enzyme.Forward, cons_f2_oop, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), + Enzyme.BatchDuplicated(θ, cons_vdθ), + Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ), Const(f.cons), Const(p), Const(i)) - for j in eachindex(θ) - res[i][j, :] = vdbθ[j] - end + + return reduce(hcat, cons_vdbθ) end - return res end else cons_h = (θ) -> f.cons_h(θ, p) end if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + lag_bθ = zeros(eltype(x), length(x)) + if f.hess_prototype === nothing + lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + else + lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) + end + + function lag_h(θ, σ, μ) + Enzyme.make_zero!(lag_bθ) + Enzyme.make_zero!.(lag_vdbθ) + + Enzyme.autodiff(Enzyme.Forward, + lag_grad, + Enzyme.BatchDuplicated(θ, lag_vdθ), + Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), + Const(lagrangian), + Const(f.f), + Const(f.cons), + Const(p), + Const(σ), + Const(μ) + ) + + k = 0 + + for i in eachindex(θ) + vec_lagv = lag_vdbθ[i] + res[(k + 1):(k + i), :] .= @view(vec_lagv[1:i]) + k += i + end + return res + end else lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons_oop, cons_j = cons_j, cons_h = cons_h, + cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, cons_hess_prototype = f.cons_hess_prototype, diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 0f78c30..ab43ea5 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -1,6 +1,7 @@ module OptimizationZygoteExt using OptimizationBase +using OptimizationBase.FastClosures import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I @@ -12,7 +13,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, using ADTypes, SciMLBase import Zygote - function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote, p = SciMLBase.NullParameters(), num_cons = 0; @@ -78,7 +78,6 @@ function OptimizationBase.instantiate_function( cons(_res, x) return copy(_res) end - end cons_jac_prototype = f.cons_jac_prototype @@ -164,7 +163,7 @@ function OptimizationBase.instantiate_function( end return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + cons = cons, cons_j = cons_j, cons_h = cons_h, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -188,39 +187,57 @@ function OptimizationBase.instantiate_function( return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end -function OptimizationBase.instantiate_function( - 
f::OptimizationFunction{false}, x, adtype::ADTypes.AutoZygote, - p = SciMLBase.NullParameters(), num_cons = 0) +function instantiate_function( + f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote}, + p = SciMLBase.NullParameters(), num_cons = 0; + fg = false, fgh = false, + cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] end - adtype, soadtype = OptimizationBase.generate_adtype(adtype) + adtype, soadtype = generate_sparse_adtype(adtype) if f.grad === nothing - extras_grad = prepare_gradient(_f, adtype, x) - function grad(θ) - gradient(_f, adtype, θ, extras_grad) + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) + function grad(res, θ) + gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end else - grad = (θ) -> f.grad(θ, p) + grad = (G, θ) -> f.grad(G, θ, p) + end + + if fg == true && f.fg !== nothing + function fg!(res, θ) + (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) + return y + end end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better - function hess(θ) - hessian(_f, soadtype, θ, extras_hess) + function hess(res, θ) + hessian!(_f, res, soadtype, θ, extras_hess) end + hess_sparsity = extras_hess.sparsity + hess_colors = extras_hess.colors else - hess = (θ) -> f.hess(θ, p) + hess = (H, θ) -> f.hess(H, θ, p) + end + + if fgh == true && f.fgh !== nothing + function fgh!(G, H, θ) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + return y + end end if f.hv === nothing - extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) - function hv(θ, v) - hvp(_f, soadtype, θ, v, extras_hvp) + extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) + hv = function (H, θ, v) + hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end else hv = f.hv @@ -229,48 +246,103 @@ function OptimizationBase.instantiate_function( if f.cons === nothing cons = nothing else - function cons(θ) - return f.cons(θ, p) + function cons(res, θ) + f.cons(res, θ, p) + end + + function cons_oop(x) + _res = Zygote.Buffer(x, num_cons) + cons(_res, x) + return copy(_res) + end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && f.cons_j === nothing - extras_jac = prepare_jacobian(cons, adtype, x) - cons_j = function (θ) - J = jacobian(cons, adtype, θ, extras_jac) + extras_jac = prepare_jacobian(cons_oop, adtype, x) + function cons_j(J, θ) + jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end - return J end + cons_jac_prototype = extras_jac.sparsity + cons_jac_colorvec = extras_jac.colors else - cons_j = (θ) -> f.cons_j(θ, p) + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons_oop, adtype, x) + function cons_vjp!(J, θ, v) + pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback) + end + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons_oop, adtype, x) + function cons_jvp!(J, θ, v) + pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward) + end + else + cons_jvp! 
= nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing - fncs = [(x) -> cons(x)[i] for i in 1:num_cons] - extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - - function cons_h(θ) - H = map(1:num_cons) do i - hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) + fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] + extras_cons_hess = Vector(undef, length(fncs)) + for ind in 1:num_cons + extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) + end + conshess_sparsity = getfield.(extras_cons_hess, :sparsity) + conshess_colors = getfield.(extras_cons_hess, :colors) + function cons_h(H, θ) + for i in 1:num_cons + hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end - return H end else cons_h = (res, θ) -> f.cons_h(res, θ, p) end - if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_hess_prototype = f.lag_hess_prototype + if cons !== nothing && f.lag_h === nothing + lag_extras = prepare_hessian(lagrangian, soadtype, x) + lag_hess_prototype = lag_extras.sparsity + + function lag_h(H::AbstractMatrix, θ, σ, λ) + if σ == zero(eltype(θ)) + cons_h(H, θ) + H *= λ + else + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + end + end + + function lag_h(h, θ, σ, λ) + H = eltype(θ).(lag_hess_prototype) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = H[i, j] + end + end + end else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = hess_sparsity, @@ -280,20 +352,19 @@ function OptimizationBase.instantiate_function( cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, lag_h, - lag_hess_prototype = f.lag_hess_prototype, + lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) end -function OptimizationBase.instantiate_function( - f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache, - adtype::ADTypes.AutoZygote, num_cons = 0) +function instantiate_function( + f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, + adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0) x = cache.u0 p = cache.p - return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) + return instantiate_function(f, x, adtype, p, num_cons) end - end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index dbacfaa..0ab34a2 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -170,7 +170,7 @@ function instantiate_function( end return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + cons = cons, cons_j = cons_j, cons_h = cons_h, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -196,7 +196,8 @@ end function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, - p = SciMLBase.NullParameters(), num_cons = 0) + p = SciMLBase.NullParameters(), num_cons = 0; + fg = false, fgh = false, cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] end @@ -212,10 +213,18 @@ function instantiate_function( grad = (θ) -> f.grad(θ, p) end + if fg == true && f.fg === nothing + 
function fg!(θ) + res = zeros(eltype(x), size(x)) + (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) + return y, res + end + end + hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if f.hess === nothing - extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better + extras_hess = prepare_hessian(_f, soadtype, x) function hess(θ) hessian(_f, soadtype, θ, extras_hess) end @@ -223,6 +232,15 @@ function instantiate_function( hess = (θ) -> f.hess(θ, p) end + if fgh == true && f.fgh !== nothing + function fgh!(θ) + G = zeros(eltype(x), size(x)) + H = zeros(eltype(x), size(x, 1), size(x, 1)) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + return y, G, H + end + end + if f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) function hv(θ, v) @@ -238,6 +256,10 @@ function instantiate_function( function cons(θ) return f.cons(θ, p) end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons(x)) + end end cons_jac_prototype = f.cons_jac_prototype @@ -255,6 +277,24 @@ function instantiate_function( cons_j = (θ) -> f.cons_j(θ, p) end + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons, adtype, x) + function cons_vjp!(θ, v) + return pullback(cons, adtype, θ, v, extras_pullback) + end + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons, adtype, x) + function cons_jvp!(θ, v) + return pushforward(cons, adtype, θ, v, extras_pushforward) + end + else + cons_jvp! = nothing + end + conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing @@ -268,25 +308,41 @@ function instantiate_function( return H end else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h = (θ) -> f.cons_h(θ, p) end - if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_hess_prototype = f.lag_hess_prototype + + if cons !== nothing && f.lag_h === nothing + lag_extras = prepare_hessian(lagrangian, soadtype, x) + lag_hess_prototype = zeros(Bool, length(x), length(x)) + + function lag_h(θ, σ, λ) + if σ == zero(eltype(θ)) + H = cons_h(θ) + for i in 1:num_cons + H[i] *= λ[i] + end + return H + else + return hessian(x -> lagrangian(x, σ, λ), soadtype, θ, lag_extras) + end + end else lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_h = cons_h, + cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, cons_jac_prototype = cons_jac_prototype, cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, - lag_hess_prototype = f.lag_hess_prototype, + lag_h = lag_h, + lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) diff --git a/src/augmented_lagrangian.jl b/src/augmented_lagrangian.jl index 88bca00..8790900 100644 --- a/src/augmented_lagrangian.jl +++ b/src/augmented_lagrangian.jl @@ -10,4 +10,4 @@ function generate_auglag(θ) end return x[1] + sum(@. 
λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2) -end \ No newline at end of file +end diff --git a/test/adtests.jl b/test/adtests.jl index 196ff34..f93814a 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -109,7 +109,6 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) @test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] - G2 = Array{Float64}(undef, 2) H2 = Array{Float64}(undef, 2, 2) From 142f8acc53c2df6d54e8e9aa0de85b99574b4b63 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 7 Aug 2024 10:04:23 -0400 Subject: [PATCH 24/33] boolean switches for oracles --- ext/OptimizationZygoteExt.jl | 22 +++++++++++----------- src/OptimizationDIExt.jl | 5 +++-- src/OptimizationDISparseExt.jl | 15 ++++++++------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index ab43ea5..15aff37 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -1,10 +1,10 @@ module OptimizationZygoteExt -using OptimizationBase +using OptimizationBase, SparseArrays using OptimizationBase.FastClosures import OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction -import OptimizationBase.LinearAlgebra: I +import OptimizationBase.LinearAlgebra: I, dot import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, @@ -187,16 +187,16 @@ function OptimizationBase.instantiate_function( return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end -function instantiate_function( +function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote}, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, + fg = false, fgh = false, conshess = false, cons_vjp = false, cons_jvp = false) function _f(θ) - return f(θ, p)[1] + return f.f(θ, p)[1] end - adtype, soadtype = generate_sparse_adtype(adtype) + adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype) if f.grad === nothing extras_grad = prepare_gradient(_f, adtype.dense_ad, x) @@ -252,7 +252,7 @@ function instantiate_function( function cons_oop(x) _res = Zygote.Buffer(x, num_cons) - cons(_res, x) + f.cons(_res, x, p) return copy(_res) end @@ -297,7 +297,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && f.cons_h === nothing && conshess == true fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = Vector(undef, length(fncs)) for ind in 1:num_cons @@ -330,7 +330,7 @@ function instantiate_function( function lag_h(h, θ, σ, λ) H = eltype(θ).(lag_hess_prototype) - hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + hessian!((x) -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) k = 0 rows, cols, _ = findnz(H) for (i, j) in zip(rows, cols) @@ -358,13 +358,13 @@ function instantiate_function( cons_expr = f.cons_expr) end -function instantiate_function( +function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0) x = cache.u0 p = cache.p - return instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) end end diff --git 
a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 0ab34a2..f767b83 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -23,7 +23,8 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, cons_vjp = false, cons_jvp = false) + fg = false, fgh = false, conshess = false, + cons_vjp = false, cons_jvp = false) function _f(θ) return f(θ, p)[1] end @@ -125,7 +126,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && f.cons_h === nothing && conshess == true fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 2599914..189efd9 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -103,14 +103,15 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, + objhess = false, + fg = false, fgh = false, conshess = false, cons_vjp = false, cons_jvp = false) function _f(θ) - return f(θ, p)[1] + return f.f(θ, p)[1] end adtype, soadtype = generate_sparse_adtype(adtype) - + @show adtype if f.grad === nothing extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(res, θ) @@ -129,7 +130,7 @@ function instantiate_function( hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing + if f.hess === nothing && objhess == true extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) @@ -163,9 +164,9 @@ function instantiate_function( f.cons(res, θ, p) end - function cons_oop(x) + function cons_oop(x, p=p) _res = zeros(eltype(x), num_cons) - cons(_res, x) + f.cons(_res, x, p) return _res end @@ -210,7 +211,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && f.cons_h === nothing && conshess == true fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = Vector(undef, length(fncs)) for ind in 1:num_cons From 8fb865d1c73a271972672c9c15e046b392b83aa6 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 15 Aug 2024 23:36:38 -0400 Subject: [PATCH 25/33] bool switch everywhere --- ext/OptimizationEnzymeExt.jl | 164 +++++++++++++++++++---------- ext/OptimizationMTKExt.jl | 35 ++++--- src/OptimizationDIExt.jl | 161 +++++++++++++++++++---------- src/OptimizationDISparseExt.jl | 182 +++++++++++++++++++++++++-------- 4 files changed, 377 insertions(+), 165 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 08242be..d68dca6 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -89,8 +89,12 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, end function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, - adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false) - if f.grad === nothing + adtype::AutoEnzyme, p, num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = 
false, cons_h = false, + lag_h = false) + + if g == true && f.grad === nothing function grad(res, θ) Enzyme.make_zero!(res) Enzyme.autodiff(Enzyme.Reverse, @@ -101,11 +105,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Const(p) ) end - else + elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + else + grad = nothing end - if fg == true + if fg == true && f.fg === nothing function fg!(res, θ) Enzyme.make_zero!(res) y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, @@ -117,11 +123,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, )[2] return y end + elseif fg == true + fg! = (res, θ) -> f.fg(res, θ, p) else fg! = nothing end - if f.hess === nothing + if h == true && f.hess === nothing vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) bθ = zeros(eltype(x), length(x)) @@ -148,11 +156,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, res[i, :] .= vdbθ[i] end end - else + elseif h == true hess = (H, θ) -> f.hess(H, θ, p) + else + hess = nothing end - if fgh == true + if fgh == true && f.fgh === nothing function fgh!(G, H, θ) vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ))))) vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ)) @@ -169,19 +179,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, H[i, :] .= vdbθ[i] end end + elseif fgh == true + fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) else fgh! = nothing end - if f.hv === nothing - function hv(H, θ, v) + if hv == true && f.hv === nothing + function hv!(H, θ, v) H .= Enzyme.autodiff( Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(_f), Const(f.f), Const(p) )[1] end + elseif hv == true + hv! = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = f.hv + hv! = nothing end if f.cons === nothing @@ -190,7 +204,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, cons = (res, θ) -> f.cons(res, θ, p) end - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing if num_cons > length(x) seeds = Enzyme.onehot(x) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) @@ -201,7 +215,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, y = zeros(eltype(x), num_cons) - function cons_j(J, θ) + function cons_j!(J, θ) for i in eachindex(Jaccache) Enzyme.make_zero!(Jaccache[i]) end @@ -228,11 +242,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end end end + elseif cons_j == true && cons !== nothing + cons_j! = (J, θ) -> f.cons_j(J, θ, p) else - cons_j = (J, θ) -> f.cons_j(J, θ, p) + cons_j! = nothing end - if cons !== nothing && f.cons_vjp == true + if cons !== nothing && cons_vjp == true && f.cons_vjp == true cons_res = zeros(eltype(x), num_cons) function cons_vjp!(res, θ, v) Enzyme.make_zero!(res) @@ -246,11 +262,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Const(p) ) end + elseif cons_vjp == true && cons !== nothing + cons_vjp! = (Jv, θ, σ) -> f.cons_vjp(Jv, θ, σ, p) else - cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p) + cons_vjp! 
= nothing end - if cons !== nothing && f.cons_jvp == true + if cons !== nothing && cons_jvp == true && f.cons_jvp == true cons_res = zeros(eltype(x), num_cons) function cons_jvp!(res, θ, v) @@ -264,16 +282,18 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Const(p) ) end + elseif cons_jvp == true && cons !== nothing + cons_jvp! = (Jv, θ, v) -> f.cons_jvp(Jv, θ, v, p) else - cons_vjp! = nothing + cons_jvp! = nothing end - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && cons_h == true && f.cons_h === nothing cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) cons_bθ = zeros(eltype(x), length(x)) cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) - function cons_h(res, θ) + function cons_h!(res, θ) for i in 1:num_cons Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) @@ -291,11 +311,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end end end + elseif cons !== nothing && cons_h == true + cons_h! = (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h! = nothing end - if f.lag_h === nothing + if lag_h == true && f.lag_h === nothing && cons !== nothing lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) lag_bθ = zeros(eltype(x), length(x)) @@ -306,7 +328,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) end - function lag_h(h, θ, σ, μ) + function lag_h!(h, θ, σ, μ) Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) @@ -329,16 +351,22 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, k += i end end + elseif lag_h == true && cons !== nothing + lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, fgh = fgh!, + hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, + cons_jvp = cons_jvp!, cons_vjp = cons_vjp!, + cons_h = cons_h!, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, + lag_h = lag_h!, lag_hess_prototype = f.lag_hess_prototype, sys = f.sys, expr = f.expr, @@ -356,8 +384,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false) - if f.grad === nothing + adtype::AutoEnzyme, p, num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) + + if g == true && f.grad === nothing res = zeros(eltype(x), size(x)) function grad(θ) Enzyme.make_zero!(res) @@ -370,11 +402,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x ) return res end - else + elseif fg == true grad = (θ) -> f.grad(θ, p) + else + grad = nothing end - if fg == true + if fg == true && f.fg === nothing res_fg = zeros(eltype(x), size(x)) function fg!(θ) Enzyme.make_zero!(res_fg) @@ -387,11 +421,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x )[2] return y, res end + elseif fg == true + fg! 
= (θ) -> f.fg(θ, p) else fg! = nothing end - if f.hess === nothing + if h == true && f.hess === nothing vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) bθ = zeros(eltype(x), length(x)) vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) @@ -411,11 +447,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return reduce( vcat, [reshape(vdbθ[i], (1, length(vdbθ[i]))) for i in eachindex(θ)]) end - else + elseif h == true hess = (θ) -> f.hess(θ, p) + else + hess = nothing end - if fgh == true + if fgh == true && f.fgh === nothing vdθ_fgh = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) vdbθ_fgh = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) G_fgh = zeros(eltype(x), length(x)) @@ -439,19 +477,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x end return G_fgh, H_fgh end + elseif fgh == true + fgh! = (θ) -> f.fgh(θ, p) else fgh! = nothing end - if f.hv === nothing - function hv(θ, v) + if hv == true && f.hv === nothing + function hv!(θ, v) return Enzyme.autodiff( Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v), Const(_f), Const(f.f), Const(p) )[1] end + elseif hv == true + hv! = (θ, v) -> f.hv(θ, v, p) else - hv = f.hv + hv! = f.hv end if f.cons === nothing @@ -462,11 +504,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x end end - if cons !== nothing && f.cons_j === nothing + if cons_j == true && cons !== nothing && f.cons_j === nothing seeds = Enzyme.onehot(x) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) - function cons_j(θ) + function cons_j!(θ) for i in eachindex(Jaccache) Enzyme.make_zero!(Jaccache[i]) end @@ -478,15 +520,17 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return reduce(hcat, Jaccache) end end + elseif cons_j == true && cons !== nothing + cons_j! = (θ) -> f.cons_j(θ, p) else - cons_j = (θ) -> f.cons_j(θ, p) + cons_j! = nothing end - if cons !== nothing && f.cons_vjp == true + if cons_vjp == true && cons !== nothing && f.cons_vjp == true res_vjp = zeros(eltype(x), size(x)) cons_vjp_res = zeros(eltype(x), num_cons) - function cons_vjp(θ, v) + function cons_vjp!(θ, v) Enzyme.make_zero!(res_vjp) Enzyme.make_zero!(cons_vjp_res) @@ -499,15 +543,17 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x ) return res_vjp end + elseif cons_vjp == true && cons !== nothing + cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p) else - cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p) + cons_vjp! = nothing end - if cons !== nothing && f.cons_jvp == true + if cons_jvp == true && cons !== nothing && f.cons_jvp == true res_jvp = zeros(eltype(x), size(x)) cons_jvp_res = zeros(eltype(x), num_cons) - function cons_jvp(θ, v) + function cons_jvp!(θ, v) Enzyme.make_zero!(res_jvp) Enzyme.make_zero!(cons_jvp_res) @@ -519,16 +565,18 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x ) return res_jvp end + elseif cons_jvp == true && cons !== nothing + cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p) else - cons_jvp = nothing + cons_jvp! 
= nothing end - if cons !== nothing && f.cons_h === nothing + if cons_h == true && cons !== nothing && f.cons_h === nothing cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) cons_bθ = zeros(eltype(x), length(x)) cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) - function cons_h(θ) + function cons_h!(θ) return map(1:num_cons) do i Enzyme.make_zero!(cons_bθ) Enzyme.make_zero!.(cons_vdbθ) @@ -543,11 +591,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return reduce(hcat, cons_vdbθ) end end + elseif cons_h == true && cons !== nothing + cons_h! = (θ) -> f.cons_h(θ, p) else - cons_h = (θ) -> f.cons_h(θ, p) + cons_h! = nothing end - if f.lag_h === nothing + if lag_h == true && f.lag_h === nothing && cons !== nothing lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) lag_bθ = zeros(eltype(x), length(x)) if f.hess_prototype === nothing @@ -556,7 +606,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) end - function lag_h(θ, σ, μ) + function lag_h!(θ, σ, μ) Enzyme.make_zero!(lag_bθ) Enzyme.make_zero!.(lag_vdbθ) @@ -581,16 +631,22 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x end return res end + elseif lag_h == true && cons !== nothing + lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{false}(f.f, adtype; grad = grad, + fg = fg!, fgh = fgh!, + hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, + cons_jvp = cons_jvp!, cons_vjp = cons_vjp!, + cons_h = cons_h!, hess_prototype = f.hess_prototype, cons_jac_prototype = f.cons_jac_prototype, cons_hess_prototype = f.cons_hess_prototype, - lag_h = lag_h, + lag_h = lag_h!, lag_hess_prototype = f.lag_hess_prototype, sys = f.sys, expr = f.expr, diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 7bdda9a..843c722 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -8,7 +8,10 @@ using ModelingToolkit function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics}, p, - num_cons = 0) + num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) p = isnothing(p) ? SciMLBase.NullParameters() : p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p; @@ -17,8 +20,8 @@ function OptimizationBase.instantiate_function( ucons = fill(0.0, num_cons)))) #sys = ModelingToolkit.structural_simplify(sys) - f = OptimizationProblem(sys, x, p, grad = true, hess = true, - sparse = true, cons_j = true, cons_h = true, + f = OptimizationProblem(sys, x, p, grad = g, hess = h, + sparse = true, cons_j = cons_j, cons_h = cons_h, cons_sparse = true).f grad = (G, θ, args...) -> f.grad(G, θ, p, args...) 
@@ -54,7 +57,10 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0) + adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0, + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0, @@ -64,8 +70,8 @@ function OptimizationBase.instantiate_function( ucons = fill(0.0, num_cons)))) #sys = ModelingToolkit.structural_simplify(sys) - f = OptimizationProblem(sys, cache.u0, cache.p, grad = true, hess = true, - sparse = true, cons_j = true, cons_h = true, + f = OptimizationProblem(sys, cache.u0, cache.p, grad = g, hess = h, + sparse = true, cons_j = cons_j, cons_h = cons_h, cons_sparse = true).f grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) @@ -101,7 +107,9 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p, - num_cons = 0) + num_cons = 0, g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) p = isnothing(p) ? SciMLBase.NullParameters() : p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p; @@ -110,8 +118,8 @@ function OptimizationBase.instantiate_function( ucons = fill(0.0, num_cons)))) #sys = ModelingToolkit.structural_simplify(sys) - f = OptimizationProblem(sys, x, p, grad = true, hess = true, - sparse = false, cons_j = true, cons_h = true, + f = OptimizationProblem(sys, x, p, grad = g, hess = h, + sparse = false, cons_j = cons_j, cons_h = cons_h, cons_sparse = false).f grad = (G, θ, args...) -> f.grad(G, θ, p, args...) @@ -147,7 +155,10 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::AutoSymbolics, num_cons = 0) + adtype::AutoSymbolics, num_cons = 0, + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0, @@ -157,8 +168,8 @@ function OptimizationBase.instantiate_function( ucons = fill(0.0, num_cons)))) #sys = ModelingToolkit.structural_simplify(sys) - f = OptimizationProblem(sys, cache.u0, cache.p, grad = true, hess = true, - sparse = false, cons_j = true, cons_h = true, + f = OptimizationProblem(sys, cache.u0, cache.p, grad = g, hess = h, + sparse = false, cons_j = cons_j, cons_h = cons_h, cons_sparse = false).f grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) 
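Note on the boolean oracle switches introduced in this series: the keyword arguments g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h and lag_h all default to false, so callers now opt in to each oracle explicitly and anything left at false (that the user did not supply on the OptimizationFunction) comes back as nothing. The following is a minimal, hypothetical sketch of the intended call pattern, mirroring the AutoEnzyme usage in test/adtests.jl above; the objective, constraint and evaluation points are illustrative only and are not part of this patch.

    using OptimizationBase, Enzyme   # Enzyme must be loaded so the AutoEnzyme extension is active

    rosenbrock(x, p = nothing) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
    cons(res, x, p) = (res[1] = x[1]^2 + x[2]^2; nothing)

    optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = cons)
    # Each oracle is built only when its switch is true and no user-provided version exists;
    # switches left at the default false leave that field as `nothing`.
    optprob = OptimizationBase.instantiate_function(
        optf, [0.0, 0.0], OptimizationBase.AutoEnzyme(), nothing, 1;
        g = true, cons_j = true)

    G = zeros(2)
    optprob.grad(G, [1.0, 1.0])      # in-place gradient oracle
    J = zeros(2)
    optprob.cons_j(J, [5.0, 3.0])    # Jacobian of the single constraint, written in place

Under these assumptions optprob.hess, optprob.hv, optprob.cons_h and so on remain nothing, which is the point of the switches: solvers request only the oracles they will actually call.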
diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index f767b83..aac5062 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -23,21 +23,24 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, conshess = false, - cons_vjp = false, cons_jvp = false) + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) function _f(θ) return f(θ, p)[1] end adtype, soadtype = generate_adtype(adtype) - if f.grad === nothing + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) function grad(res, θ) gradient!(_f, res, adtype, θ, extras_grad) end - else + elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + else + grad = nothing end if fg == true && f.fg === nothing @@ -45,33 +48,45 @@ function instantiate_function( (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) return y end + elseif fg == true + fg! = (G, θ) -> f.fg(G, θ, p) + else + fg! = nothing end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing + if h == true && f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end - else + elseif h == true hess = (H, θ) -> f.hess(H, θ, p) + else + hess = nothing end if fgh == true && f.fgh !== nothing function fgh!(G, H, θ) - (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, soadtype, θ, extras_hess) return y end + elseif fgh == true + fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) + else + fgh! = nothing end - if f.hv === nothing + if hv == true && f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) - hv = function (H, θ, v) + function hv!(H, θ, v) hvp!(_f, H, soadtype, θ, v, extras_hvp) end + elseif hv == true + hv = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = f.hv + hv = nothing end if f.cons === nothing @@ -94,58 +109,66 @@ function instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - function cons_j(J, θ) + function cons_j!(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end end + elseif cons_j == true && cons !== nothing + cons_j! = (J, θ) -> f.cons_j(J, θ, p) else - cons_j = (J, θ) -> f.cons_j(J, θ, p) + cons_j! = nothing end - if f.cons_vjp === nothing && cons_vjp == true + if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing extras_pullback = prepare_pullback(cons_oop, adtype, x) function cons_vjp!(J, θ, v) pullback!(cons_oop, J, adtype, θ, v, extras_pullback) end + elseif cons_vjp == true && cons !== nothing + cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p) else cons_vjp! = nothing end - if f.cons_jvp === nothing && cons_jvp == true + if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing extras_pushforward = prepare_pushforward(cons_oop, adtype, x) function cons_jvp!(J, θ, v) pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward) end + elseif cons_jvp == true && cons !== nothing + cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p) else cons_jvp! 
= nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing && conshess == true + if cons !== nothing && f.cons_h === nothing && cons_h == true fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - function cons_h(H, θ) + function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end + elseif cons_h == true && cons !== nothing + cons_h! = (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h! = nothing end lag_hess_prototype = f.lag_hess_prototype - if cons !== nothing && f.lag_h === nothing + if cons !== nothing && lag_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = zeros(Bool, length(x), length(x)) - function lag_h(H::AbstractMatrix, θ, σ, λ) + function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ @@ -154,7 +177,7 @@ function instantiate_function( end end - function lag_h(h, θ, σ, λ) + function lag_h!(h, θ, σ, λ) H = eltype(θ).(lag_hess_prototype) hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) k = 0 @@ -166,12 +189,14 @@ function instantiate_function( end end end + elseif lag_h == true && cons !== nothing + lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -179,7 +204,7 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, + lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, @@ -188,67 +213,85 @@ end function instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::ADTypes.AbstractADType, num_cons = 0) + adtype::ADTypes.AbstractADType, num_cons = 0, + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) x = cache.u0 p = cache.p - return instantiate_function(f, x, adtype, p, num_cons) + return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv, + fg = fg, fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp, + cons_h = cons_h, lag_h = lag_h) end function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, cons_vjp = false, cons_jvp = false) + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) function _f(θ) return f(θ, p)[1] end adtype, soadtype = generate_adtype(adtype) - if f.grad === nothing + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) function grad(θ) gradient(_f, adtype, θ, extras_grad) end - else + elseif g == true grad = (θ) -> f.grad(θ, p) + else + grad = nothing end if fg == true && f.fg === nothing function fg!(θ) - res = zeros(eltype(x), size(x)) - (y, 
_) = value_and_gradient!(_f, res, adtype, θ, extras_grad) + (y, res) = value_and_gradient(_f, adtype, θ, extras_grad) return y, res end + elseif fg == true + fg! = (θ) -> f.fg(θ, p) + else + fg! = nothing end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing + if h == true && f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) function hess(θ) hessian(_f, soadtype, θ, extras_hess) end - else + elseif h == true hess = (θ) -> f.hess(θ, p) + else + hess = nothing end if fgh == true && f.fgh !== nothing function fgh!(θ) - G = zeros(eltype(x), size(x)) - H = zeros(eltype(x), size(x, 1), size(x, 1)) - (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) + (y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess) return y, G, H end + elseif fgh == true + fgh! = (θ) -> f.fgh(θ, p) + else + fgh! = nothing end - if f.hv === nothing + if hv == true && f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) - function hv(θ, v) + function hv!(θ, v) hvp(_f, soadtype, θ, v, extras_hvp) end + elseif hv == true + hv! = (θ, v) -> f.hv(θ, v, p) else - hv = f.hv + hv! = nothing end if f.cons === nothing @@ -265,60 +308,68 @@ function instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons, adtype, x) - cons_j = function (θ) + function cons_j!(θ) J = jacobian(cons, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end return J end + elseif cons_j == true && cons !== nothing + cons_j! = (θ) -> f.cons_j(θ, p) else - cons_j = (θ) -> f.cons_j(θ, p) + cons_j! = nothing end - if f.cons_vjp === nothing && cons_vjp == true + if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing extras_pullback = prepare_pullback(cons, adtype, x) function cons_vjp!(θ, v) return pullback(cons, adtype, θ, v, extras_pullback) end + elseif cons_vjp == true && cons !== nothing + cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p) else cons_vjp! = nothing end - if f.cons_jvp === nothing && cons_jvp == true + if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing extras_pushforward = prepare_pushforward(cons, adtype, x) function cons_jvp!(θ, v) return pushforward(cons, adtype, θ, v, extras_pushforward) end + elseif cons_jvp == true && cons !== nothing + cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p) else cons_jvp! = nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && cons_h == true && f.cons_h === nothing fncs = [(x) -> cons(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - function cons_h(θ) + function cons_h!(θ) H = map(1:num_cons) do i hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end + elseif cons_h == true && cons !== nothing + cons_h! = (θ) -> f.cons_h(θ, p) else - cons_h = (θ) -> f.cons_h(θ, p) + cons_h! 
= nothing end lag_hess_prototype = f.lag_hess_prototype - if cons !== nothing && f.lag_h === nothing + if cons !== nothing && lag_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = zeros(Bool, length(x), length(x)) - function lag_h(θ, σ, λ) + function lag_h!(θ, σ, λ) if σ == zero(eltype(θ)) H = cons_h(θ) for i in 1:num_cons @@ -329,12 +380,14 @@ function instantiate_function( return hessian(x -> lagrangian(x, σ, λ), soadtype, θ, lag_extras) end end + elseif lag_h == true && cons !== nothing + lag_h! = (θ, σ, λ) -> f.lag_h(θ, σ, λ, p) else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -342,7 +395,7 @@ function instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h = lag_h, + lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 189efd9..3c22b59 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -103,22 +103,24 @@ end function instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, p = SciMLBase.NullParameters(), num_cons = 0; - objhess = false, - fg = false, fgh = false, conshess = false, - cons_vjp = false, cons_jvp = false) + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) function _f(θ) return f.f(θ, p)[1] end adtype, soadtype = generate_sparse_adtype(adtype) - @show adtype - if f.grad === nothing + + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(res, θ) gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end - else + elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + else + grad = nothing end if fg == true && f.fg !== nothing @@ -126,19 +128,25 @@ function instantiate_function( (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) return y end + elseif fg == true + fg! = (G, θ) -> f.fg(G, θ, p) + else + fg! = nothing end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing && objhess == true + if f.hess === nothing && h == true extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end hess_sparsity = extras_hess.sparsity hess_colors = extras_hess.colors - else + elseif h == true hess = (H, θ) -> f.hess(H, θ, p) + else + hess = nothing end if fgh == true && f.fgh !== nothing @@ -146,15 +154,21 @@ function instantiate_function( (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) return y end + elseif fgh == true + fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) + else + fgh! 
= nothing end - if f.hv === nothing + if hv == true && f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) - hv = function (H, θ, v) + function hv!(H, θ, v) hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end + elseif hv == true + hv = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = f.hv + hv = nothing end if f.cons === nothing @@ -177,9 +191,9 @@ function instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - function cons_j(J, θ) + function cons_j!(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) @@ -187,8 +201,10 @@ function instantiate_function( end cons_jac_prototype = extras_jac.sparsity cons_jac_colorvec = extras_jac.colors + elseif cons_j === true && cons !== nothing + cons_j! = (J, θ) -> f.cons_j(J, θ, p) else - cons_j = (J, θ) -> f.cons_j(J, θ, p) + cons_j! = nothing end if f.cons_vjp === nothing && cons_vjp == true @@ -196,6 +212,8 @@ function instantiate_function( function cons_vjp!(J, θ, v) pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback) end + elseif cons_vjp === true && cons !== nothing + cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p) else cons_vjp! = nothing end @@ -205,6 +223,8 @@ function instantiate_function( function cons_jvp!(J, θ, v) pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward) end + elseif cons_jvp === true && cons !== nothing + cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p) else cons_jvp! = nothing end @@ -219,21 +239,25 @@ function instantiate_function( end conshess_sparsity = getfield.(extras_cons_hess, :sparsity) conshess_colors = getfield.(extras_cons_hess, :colors) - function cons_h(H, θ) + function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end + elseif cons_h == true && cons !== nothing + cons_h! = (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h! = nothing end lag_hess_prototype = f.lag_hess_prototype - if cons !== nothing && f.lag_h === nothing + lag_hess_colors = f.lag_hess_colorvec + if cons !== nothing && lag_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = lag_extras.sparsity + lag_hess_colors = lag_extras.colors - function lag_h(H::AbstractMatrix, θ, σ, λ) + function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ @@ -242,7 +266,7 @@ function instantiate_function( end end - function lag_h(h, θ, σ, λ) + function lag_h!(h, θ, σ, λ) H = eltype(θ).(lag_hess_prototype) hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) k = 0 @@ -254,19 +278,23 @@ function instantiate_function( end end end + elseif lag_h == true + lag_h! = (H, θ, σ, λ) -> f.lag_h(H, θ, σ, λ, p) else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! 
= nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, + cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, cons_jac_prototype = cons_jac_prototype, cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, + lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, + lag_hess_colorvec = lag_hess_colors, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) @@ -274,47 +302,83 @@ end function instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0) + adtype::ADTypes.AutoSparse{<:AbstractADType}, num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) x = cache.u0 p = cache.p - return instantiate_function(f, x, adtype, p, num_cons) + return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv, fg = fg, + fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp, cons_h = cons_h, + lag_h = lag_h) end function instantiate_function( f::OptimizationFunction{false}, x, adtype::ADTypes.AutoSparse{<:AbstractADType}, - p = SciMLBase.NullParameters(), num_cons = 0) + p = SciMLBase.NullParameters(), num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) function _f(θ) return f(θ, p)[1] end adtype, soadtype = generate_sparse_adtype(adtype) - if f.grad === nothing + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(θ) gradient(_f, adtype.dense_ad, θ, extras_grad) end - else + elseif g == true grad = (θ) -> f.grad(θ, p) + else + grad = nothing + end + + if fg == true && f.fg !== nothing + function fg!(θ) + (y, G) = value_and_gradient(_f, adtype.dense_ad, θ, extras_grad) + return y, G + end + elseif fg == true + fg! = (θ) -> f.fg(θ, p) + else + fg! = nothing + end + + if fgh == true && f.fgh !== nothing + function fgh!(θ) + (y, G, H) = value_derivative_and_second_derivative(_f, soadtype, θ, extras_hess) + return y, G, H + end + elseif fgh == true + fgh! = (θ) -> f.fgh(θ, p) + else + fgh! = nothing end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing + if h == true && f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better function hess(θ) hessian(_f, soadtype, θ, extras_hess) end - else + elseif h == true hess = (θ) -> f.hess(θ, p) + else + hess = nothing end - if f.hv === nothing + if hv == true && f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) - function hv(θ, v) + function hv!(θ, v) hvp(_f, soadtype.dense_ad, θ, v, extras_hvp) end + elseif hv == true + hv! 
= (θ, v) -> f.hv(θ, v, p) else hv = f.hv end @@ -325,54 +389,82 @@ function instantiate_function( function cons(θ) f.cons(θ, p) end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons(x)) + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons, adtype, x) - function cons_j(θ) + function cons_j!(θ) J = jacobian(cons, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end return J end - else + cons_jac_prototype = extras_jac.sparsity + cons_jac_colorvec = extras_jac.colors + elseif cons_j === true && cons !== nothing cons_j = (θ) -> f.cons_j(θ, p) + else + cons_j = nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && cons_h == true && f.cons_h === nothing fncs = [(x) -> cons(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - function cons_h(θ) + function cons_h!(θ) H = map(1:num_cons) do i hessian(fncs[i], soadtype, θ, extras_cons_hess[i]) end return H end + conshess_sparsity = getfield.(extras_cons_hess, :sparsity) + conshess_colors = getfield.(extras_cons_hess, :colors) + elseif cons_h == true && cons !== nothing + cons_h! = (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h! = nothing end - if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_hess_prototype = f.lag_hess_prototype + lag_hess_colors = f.lag_hess_colorvec + if cons !== nothing && lag_h == true && f.lag_h === nothing + lag_extras = prepare_hessian(lagrangian, soadtype, x) + function lag_h!(θ, σ, λ) + if σ == zero(eltype(θ)) + return λ * cons_h!(θ) + else + hess = hessian(x -> lagrangian(x, σ, λ), soadtype, θ, lag_extras) + return hess + end + end + lag_hess_prototype = lag_extras.sparsity + lag_hess_colors = lag_extras.colors + elseif lag_h == true && cons !== nothing + lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) else - lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) + lag_h! 
= nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, + cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, cons_jac_prototype = cons_jac_prototype, cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, - lag_hess_prototype = f.lag_hess_prototype, + lag_h = lag_h!, + lag_hess_prototype = lag_hess_prototype, + lag_hess_colorvec = lag_hess_colors, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) From 2bad38468fd9a5e17ac50cb4af5677f2b9a47237 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 16 Aug 2024 02:06:59 -0400 Subject: [PATCH 26/33] more of the switches, and in tests too --- ext/OptimizationEnzymeExt.jl | 10 +- ext/OptimizationMTKExt.jl | 2 +- ext/OptimizationZygoteExt.jl | 168 ++++++++++++++++++++++----------- src/OptimizationDIExt.jl | 9 +- src/OptimizationDISparseExt.jl | 36 +++++-- src/function.jl | 6 +- test/adtests.jl | 94 +++++++++--------- 7 files changed, 202 insertions(+), 123 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index d68dca6..ea6dafe 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -93,7 +93,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) - if g == true && f.grad === nothing function grad(res, θ) Enzyme.make_zero!(res) @@ -351,7 +350,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, k += i end end - elseif lag_h == true && cons !== nothing + elseif lag_h == true && cons !== nothing lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) else lag_h! = nothing @@ -384,11 +383,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, end function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, - adtype::AutoEnzyme, p, num_cons = 0; + adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) - if g == true && f.grad === nothing res = zeros(eltype(x), size(x)) function grad(θ) @@ -637,10 +635,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x lag_h! = nothing end - return OptimizationFunction{false}(f.f, adtype; grad = grad, + return OptimizationFunction{false}(f.f, adtype; grad = grad, fg = fg!, fgh = fgh!, hess = hess, hv = hv!, - cons = cons, cons_j = cons_j!, + cons = cons, cons_j = cons_j!, cons_jvp = cons_jvp!, cons_vjp = cons_vjp!, cons_h = cons_h!, hess_prototype = f.hess_prototype, diff --git a/ext/OptimizationMTKExt.jl b/ext/OptimizationMTKExt.jl index 843c722..ff1dce2 100644 --- a/ext/OptimizationMTKExt.jl +++ b/ext/OptimizationMTKExt.jl @@ -169,7 +169,7 @@ function OptimizationBase.instantiate_function( num_cons)))) #sys = ModelingToolkit.structural_simplify(sys) f = OptimizationProblem(sys, cache.u0, cache.p, grad = g, hess = h, - sparse = false, cons_j = cons_j, cons_h = cons_h, + sparse = false, cons_j = cons_j, cons_h = cons_h, cons_sparse = false).f grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) 
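The out-of-place OptimizationFunction{false} paths rewritten above return their results instead of writing into preallocated buffers. A rough sketch of the calling convention this implies, using placeholder inputs and assuming an out-of-place setup like the {false} blocks in test/adtests.jl; it is illustrative, not part of the diff.

using OptimizationBase
# assumes OptimizationFunction is in scope as in test/adtests.jl (it comes from SciMLBase)

rosenbrock(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2  # placeholder objective
x0 = zeros(2)

optf_oop = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoForwardDiff())
oop = OptimizationBase.instantiate_function(
    optf_oop, x0, OptimizationBase.AutoForwardDiff(), nothing;
    g = true, h = true)

g_val = oop.grad(x0)  # returns the gradient vector
H_val = oop.hess(x0)  # returns the Hessian matrix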
diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 15aff37..a4721ae 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -16,54 +16,70 @@ import Zygote function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, cons_vjp = false, cons_jvp = false) + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, + lag_h = false) function _f(θ) return f(θ, p)[1] end adtype, soadtype = OptimizationBase.generate_adtype(adtype) - if f.grad === nothing + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype, x) function grad(res, θ) gradient!(_f, res, adtype, θ, extras_grad) end - else + elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + else + grad = nothing end - if fg == true + if fg == true && f.fg === nothing function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) return y end + elseif fg == true + fg! = (G, θ) -> f.fg(G, θ, p) + else + fg! = nothing end hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec - if f.hess === nothing + if h == true && f.hess === nothing extras_hess = prepare_hessian(_f, soadtype, x) function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end - else + elseif h == true hess = (H, θ) -> f.hess(H, θ, p) + else + hess = nothing end - if fgh == true + if fgh == true && f.fgh === nothing function fgh!(G, H, θ) (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) return y end + elseif fgh == true + fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) + else + fgh! = nothing end - if f.hv === nothing + if hv == true && f.hv === nothing extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x))) - hv = function (H, θ, v) + function hv!(H, θ, v) hvp!(_f, H, soadtype, θ, v, extras_hvp) end + elseif hv == true + hv! = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = f.hv + hv! = nothing end if f.cons === nothing @@ -78,66 +94,74 @@ function OptimizationBase.instantiate_function( cons(_res, x) return copy(_res) end + + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) + end end cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - function cons_j(J, θ) + function cons_j!(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) end end + elseif cons !== nothing && cons_j == true + cons_j! = (J, θ) -> f.cons_j(J, θ, p) else - cons_j = (J, θ) -> f.cons_j(J, θ, p) + cons_j! = nothing end - if f.cons_vjp === nothing && cons_vjp == true + if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing extras_pullback = prepare_pullback(cons_oop, adtype, x) function cons_vjp!(J, θ, v) pullback!(cons_oop, J, adtype, θ, v, extras_pullback) end + elseif cons_vjp == true + cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p) else cons_vjp! = nothing end - if f.cons_jvp === nothing && cons_jvp == true + if cons !== nothing && f.cons_jvp === nothing && cons_jvp == true extras_pushforward = prepare_pushforward(cons_oop, adtype, x) function cons_jvp!(J, θ, v) pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward) end + elseif cons_jvp == true + cons_jvp! 
= (J, θ, v) -> f.cons_jvp(J, θ, v, p) else cons_jvp! = nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing + if cons !== nothing && cons_h == true && f.cons_h === nothing fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x)) - function cons_h(H, θ) + function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end + elseif cons !== nothing && cons_h == true + cons_h! = (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) - end - - function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) - return σ * _f(x) + dot(λ, cons_oop(x)) + cons_h! = nothing end lag_hess_prototype = f.lag_hess_prototype - if f.lag_h === nothing + if f.lag_h === nothing && cons !== nothing && lag_h == true lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = zeros(Bool, length(x), length(x)) - function lag_h(H::AbstractMatrix, θ, σ, λ) + function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ @@ -146,7 +170,7 @@ function OptimizationBase.instantiate_function( end end - function lag_h(h, θ, σ, λ) + function lag_h!(h, θ, σ, λ) H = eltype(θ).(lag_hess_prototype) hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) k = 0 @@ -158,12 +182,14 @@ function OptimizationBase.instantiate_function( end end end + elseif cons !== nothing && lag_h == true + lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, @@ -171,7 +197,7 @@ function OptimizationBase.instantiate_function( cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, + lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, @@ -180,31 +206,36 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::ADTypes.AutoZygote, num_cons = 0) + adtype::ADTypes.AutoZygote, num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false) x = cache.u0 p = cache.p - return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function( + f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h) end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote}, p = SciMLBase.NullParameters(), num_cons = 0; - fg = false, fgh = false, conshess = false, - cons_vjp = false, cons_jvp = false) + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false) function _f(θ) return f.f(θ, p)[1] end adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype) - if f.grad === nothing + if g == true && f.grad === nothing extras_grad = prepare_gradient(_f, adtype.dense_ad, x) function grad(res, θ) 
gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end - else + elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + else + grad = nothing end if fg == true && f.fg !== nothing @@ -212,6 +243,10 @@ function OptimizationBase.instantiate_function( (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) return y end + elseif fg == true + fg! = (G, θ) -> f.fg(G, θ, p) + else + fg! = nothing end hess_sparsity = f.hess_prototype @@ -223,8 +258,10 @@ function OptimizationBase.instantiate_function( end hess_sparsity = extras_hess.sparsity hess_colors = extras_hess.colors - else + elseif h == true hess = (H, θ) -> f.hess(H, θ, p) + else + hess = nothing end if fgh == true && f.fgh !== nothing @@ -232,15 +269,21 @@ function OptimizationBase.instantiate_function( (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess) return y end + elseif fgh == true + fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) + else + fgh! = nothing end - if f.hv === nothing + if hv == true && f.hv !== nothing extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x))) - hv = function (H, θ, v) + function hv!(H, θ, v) hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end + elseif hv == true + hv! = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = f.hv + hv! = nothing end if f.cons === nothing @@ -263,9 +306,9 @@ function OptimizationBase.instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec - if cons !== nothing && f.cons_j === nothing + if cons !== nothing && cons_j == true && f.cons_j === nothing extras_jac = prepare_jacobian(cons_oop, adtype, x) - function cons_j(J, θ) + function cons_j!(J, θ) jacobian!(cons_oop, J, adtype, θ, extras_jac) if size(J, 1) == 1 J = vec(J) @@ -273,31 +316,37 @@ function OptimizationBase.instantiate_function( end cons_jac_prototype = extras_jac.sparsity cons_jac_colorvec = extras_jac.colors + elseif cons !== nothing && cons_j == true + cons_j! = (J, θ) -> f.cons_j(J, θ, p) else - cons_j = (J, θ) -> f.cons_j(J, θ, p) + cons_j! = nothing end - if f.cons_vjp === nothing && cons_vjp == true + if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing extras_pullback = prepare_pullback(cons_oop, adtype, x) function cons_vjp!(J, θ, v) pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback) end + elseif cons_vjp == true + cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p) else cons_vjp! = nothing end - if f.cons_jvp === nothing && cons_jvp == true + if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing extras_pushforward = prepare_pushforward(cons_oop, adtype, x) function cons_jvp!(J, θ, v) pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward) end + elseif cons_jvp == true + cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p) else cons_jvp! = nothing end conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing && conshess == true + if cons !== nothing && f.cons_h === nothing && cons_h == true fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = Vector(undef, length(fncs)) for ind in 1:num_cons @@ -305,21 +354,23 @@ function OptimizationBase.instantiate_function( end conshess_sparsity = getfield.(extras_cons_hess, :sparsity) conshess_colors = getfield.(extras_cons_hess, :colors) - function cons_h(H, θ) + function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) end end + elseif cons_h == true + cons_h! 
= (res, θ) -> f.cons_h(res, θ, p) else - cons_h = (res, θ) -> f.cons_h(res, θ, p) + cons_h! = nothing end lag_hess_prototype = f.lag_hess_prototype - if cons !== nothing && f.lag_h === nothing + if cons !== nothing && cons_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = lag_extras.sparsity - function lag_h(H::AbstractMatrix, θ, σ, λ) + function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ @@ -328,7 +379,7 @@ function OptimizationBase.instantiate_function( end end - function lag_h(h, θ, σ, λ) + function lag_h!(h, θ, σ, λ) H = eltype(θ).(lag_hess_prototype) hessian!((x) -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) k = 0 @@ -340,18 +391,20 @@ function OptimizationBase.instantiate_function( end end end + elseif cons !== nothing && cons_h == true + lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) else - lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + cons = cons, cons_j = cons_j!, cons_h = cons_h!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, cons_jac_prototype = cons_jac_prototype, cons_jac_colorvec = cons_jac_colorvec, cons_hess_prototype = conshess_sparsity, cons_hess_colorvec = conshess_colors, - lag_h, + lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, sys = f.sys, expr = f.expr, @@ -360,11 +413,14 @@ end function OptimizationBase.instantiate_function( f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache, - adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0) + adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0; + g = false, h = false, hv = false, fg = false, fgh = false, + cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false) x = cache.u0 p = cache.p - return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons) + return OptimizationBase.instantiate_function( + f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h) end end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index aac5062..de5024e 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -69,7 +69,8 @@ function instantiate_function( if fgh == true && f.fgh !== nothing function fgh!(G, H, θ) - (y, _, _) = value_derivative_and_second_derivative!(_f, G, H, soadtype, θ, extras_hess) + (y, _, _) = value_derivative_and_second_derivative!( + _f, G, H, soadtype, θ, extras_hess) return y end elseif fgh == true @@ -84,9 +85,9 @@ function instantiate_function( hvp!(_f, H, soadtype, θ, v, extras_hvp) end elseif hv == true - hv = (H, θ, v) -> f.hv(H, θ, v, p) + hv! = (H, θ, v) -> f.hv(H, θ, v, p) else - hv = nothing + hv! = nothing end if f.cons === nothing @@ -189,7 +190,7 @@ function instantiate_function( end end end - elseif lag_h == true && cons !== nothing + elseif lag_h == true && cons !== nothing lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) else lag_h! = nothing diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 3c22b59..08c6dea 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -166,9 +166,9 @@ function instantiate_function( hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp) end elseif hv == true - hv = (H, θ, v) -> f.hv(H, θ, v, p) + hv! 
= (H, θ, v) -> f.hv(H, θ, v, p) else - hv = nothing + hv! = nothing end if f.cons === nothing @@ -178,7 +178,7 @@ function instantiate_function( f.cons(res, θ, p) end - function cons_oop(x, p=p) + function cons_oop(x, p = p) _res = zeros(eltype(x), num_cons) f.cons(_res, x, p) return _res @@ -231,7 +231,7 @@ function instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec - if cons !== nothing && f.cons_h === nothing && conshess == true + if cons !== nothing && f.cons_h === nothing && cons_h == true fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] extras_cons_hess = Vector(undef, length(fncs)) for ind in 1:num_cons @@ -380,7 +380,7 @@ function instantiate_function( elseif hv == true hv! = (θ, v) -> f.hv(θ, v, p) else - hv = f.hv + hv! = nothing end if f.cons === nothing @@ -409,9 +409,31 @@ function instantiate_function( cons_jac_prototype = extras_jac.sparsity cons_jac_colorvec = extras_jac.colors elseif cons_j === true && cons !== nothing - cons_j = (θ) -> f.cons_j(θ, p) + cons_j! = (θ) -> f.cons_j(θ, p) else - cons_j = nothing + cons_j! = nothing + end + + if f.cons_vjp === nothing && cons_vjp == true + extras_pullback = prepare_pullback(cons, adtype, x) + function cons_vjp!(θ, v) + pullback(cons, adtype, θ, v, extras_pullback) + end + elseif cons_vjp === true && cons !== nothing + cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p) + else + cons_vjp! = nothing + end + + if f.cons_jvp === nothing && cons_jvp == true + extras_pushforward = prepare_pushforward(cons, adtype, x) + function cons_jvp!(θ, v) + pushforward(cons, adtype, θ, v, extras_pushforward) + end + elseif cons_jvp === true && cons !== nothing + cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p) + else + cons_jvp! = nothing end conshess_sparsity = f.cons_hess_prototype diff --git a/src/function.jl b/src/function.jl index 63e194f..7592e1f 100644 --- a/src/function.jl +++ b/src/function.jl @@ -44,7 +44,7 @@ For more information on the use of automatic differentiation, see the documentation of the `AbstractADType` types. """ function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD, - p, num_cons = 0) + p, num_cons = 0, kwargs...) grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) @@ -74,7 +74,7 @@ end function instantiate_function( f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD, - num_cons = 0) + num_cons = 0, kwargs...) grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...) hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...) hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) @@ -103,7 +103,7 @@ function instantiate_function( end function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType, - p, num_cons = 0) + p, num_cons = 0, kwargs...) adtypestr = string(adtype) _strtind = findfirst('.', adtypestr) strtind = isnothing(_strtind) ? 
5 : _strtind + 5 diff --git a/test/adtests.jl b/test/adtests.jl index f93814a..78f640e 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -30,7 +30,7 @@ cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]; return nothing) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoModelingToolkit(), cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoModelingToolkit(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -54,7 +54,7 @@ optf = OptimizationFunction(rosenbrock, cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoModelingToolkit(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -75,7 +75,7 @@ H2 = Array{Float64}(undef, 2, 2) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -94,7 +94,7 @@ H2 = Array{Float64}(undef, 2, 2) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoEnzyme(), cons = con2_c) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -115,7 +115,7 @@ H2 = Array{Float64}(undef, 2, 2) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoReverseDiff(), cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -136,7 +136,7 @@ H2 = Array{Float64}(undef, 2, 2) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoReverseDiff(), cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(compile = true), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -156,7 +156,7 @@ H2 = Array{Float64}(undef, 2, 2) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoZygote(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -175,7 +175,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoModelingToolkit(tru cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoModelingToolkit(true, true), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) using SparseArrays sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4)) @test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2] @@ -197,7 +197,7 @@ optprob.cons_h(sH3, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoForwardDiff(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -207,7 +207,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote()) optprob = 
OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoZygote(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -216,7 +216,7 @@ optprob.hess(H2, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoReverseDiff()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1 == G2 optprob.hess(H2, x0) @@ -226,7 +226,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoTracker()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoTracker(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1 == G2 @test_broken optprob.hess(H2, x0) @@ -236,7 +236,7 @@ prob = OptimizationProblem(optf, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff()) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -247,7 +247,7 @@ cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -277,7 +277,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons cons_jac_colorvec = cons_jac_colors) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.cons_jac_prototype == sparse([1.0 1.0]) # make sure it's still using it @test optprob.cons_jac_colorvec == 1:2 J = zeros(1, 2) @@ -290,7 +290,7 @@ end optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = con2_c) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -314,7 +314,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons cons_jac_colorvec = cons_jac_colors) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.cons_jac_prototype == sparse([1.0 1.0; 1.0 1.0]) # make sure it's still using it @test optprob.cons_jac_colorvec == 1:2 J = Array{Float64}(undef, 2, 2) @@ -333,7 +333,7 @@ optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff(), hess cons_jac_prototype = copy(sJ)) optprob1 = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoForwardDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob1.hess_prototype == sparse([0.0 0.0; 0.0 0.0]) # make sure it's still using it optprob1.hess(sH, [5.0, 3.0]) @test all(isapprox(sH, [28802.0 -2000.0; -2000.0 200.0]; rtol = 1e-3)) @@ -352,7 +352,9 @@ optf = OptimizationFunction(rosenbrock, SciMLBase.NoAD(), grad = grad, hess = he cons = con2_c, cons_j = cons_j, cons_h = cons_h, hess_prototype = sH, cons_jac_prototype = sJ, cons_hess_prototype = sH3) -optprob2 = 
OptimizationBase.instantiate_function(optf, x0, SciMLBase.NoAD(), nothing, 2) +optprob2 = OptimizationBase.instantiate_function( + optf, x0, SciMLBase.NoAD(), nothing, 2, g = true, + h = true, cons_j = true, cons_h = true) optprob2.hess(sH, [5.0, 3.0]) @test all(isapprox(sH, [28802.0 -2000.0; -2000.0 200.0]; rtol = 1e-3)) optprob2.cons_j(sJ, [5.0, 3.0]) @@ -368,7 +370,7 @@ optf = OptimizationFunction(rosenbrock, cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseFiniteDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) G2 = Array{Float64}(undef, 2) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-4 @@ -390,7 +392,7 @@ optprob.cons_h(H3, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoSparseFiniteDiff()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseFiniteDiff(), - nothing) + nothing, g = true, h = true, cons_j = true, cons_h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -401,7 +403,7 @@ optf = OptimizationFunction(rosenbrock, cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseForwardDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) G2 = Array{Float64}(undef, 2) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-4 @@ -423,7 +425,7 @@ optprob.cons_h(H3, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoSparseForwardDiff()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseForwardDiff(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -434,7 +436,7 @@ optf = OptimizationFunction(rosenbrock, cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(true), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) G2 = Array{Float64}(undef, 2) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-4 @@ -458,7 +460,7 @@ optf = OptimizationFunction(rosenbrock, cons = con2_c) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) G2 = Array{Float64}(undef, 2) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-4 @@ -480,7 +482,7 @@ optprob.cons_h(H3, x0) optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoSparseReverseDiff()) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(), - nothing) + nothing, g = true, h = true) optprob.grad(G2, x0) @test G1≈G2 rtol=1e-6 optprob.hess(H2, x0) @@ -493,7 +495,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -510,7 +512,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoEnzyme(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -524,7 +526,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0)≈G1 rtol=1e-6 @test optprob.hess(x0)≈H1 rtol=1e-6 @@ -541,7 
+543,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoFiniteDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0)≈G1 rtol=1e-6 @test optprob.hess(x0)≈H1 rtol=1e-6 @@ -555,7 +557,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoForwardDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -572,7 +574,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoForwardDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -586,7 +588,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -603,7 +605,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -617,7 +619,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(true), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -634,7 +636,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoReverseDiff(true), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -648,7 +650,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseForwardDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @@ -665,7 +667,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseForwardDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @@ -679,7 +681,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseFiniteDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0)≈G1 rtol=1e-4 @test Array(optprob.hess(x0)) ≈ H1 @@ -696,7 +698,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseForwardDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @@ -710,7 +712,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -727,7 +729,7 @@ optprob.hess(H2, x0) cons = cons) optprob = 
OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @@ -741,7 +743,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(true), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -757,7 +759,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(true), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 @@ -771,7 +773,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoZygote(), - nothing, 1) + nothing, 1, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test optprob.hess(x0) == H1 @@ -787,7 +789,7 @@ optprob.hess(H2, x0) cons = cons) optprob = OptimizationBase.instantiate_function( optf, x0, OptimizationBase.AutoZygote(), - nothing, 2) + nothing, 2, g = true, h = true, cons_j = true, cons_h = true) @test optprob.grad(x0) == G1 @test Array(optprob.hess(x0)) ≈ H1 From 2792e2b6dc2a1a87be55b2bd3c93064c80fb4359 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 16 Aug 2024 11:27:57 -0400 Subject: [PATCH 27/33] update SpraseMatrixColoring v0.4 --- ext/OptimizationZygoteExt.jl | 17 ++++++++++------- src/OptimizationDISparseExt.jl | 32 ++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index a4721ae..fb24866 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -256,8 +256,8 @@ function OptimizationBase.instantiate_function( function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end - hess_sparsity = extras_hess.sparsity - hess_colors = extras_hess.colors + hess_sparsity = extras_hess.coloring_result.S + hess_colors = extras_hess.coloring_result.color elseif h == true hess = (H, θ) -> f.hess(H, θ, p) else @@ -314,8 +314,8 @@ function OptimizationBase.instantiate_function( J = vec(J) end end - cons_jac_prototype = extras_jac.sparsity - cons_jac_colorvec = extras_jac.colors + cons_jac_prototype = extras_jac.coloring_result.S + cons_jac_colorvec = extras_jac.coloring_result.color elseif cons !== nothing && cons_j == true cons_j! 
= (J, θ) -> f.cons_j(J, θ, p) else @@ -352,8 +352,9 @@ function OptimizationBase.instantiate_function( for ind in 1:num_cons extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) end - conshess_sparsity = getfield.(extras_cons_hess, :sparsity) - conshess_colors = getfield.(extras_cons_hess, :colors) + colores = getfield.(extras_cons_hess, :coloring_result) + conshess_sparsity = getfield.(colores, :S) + conshess_colors = getfield.(colores, :color) function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) @@ -368,7 +369,8 @@ function OptimizationBase.instantiate_function( lag_hess_prototype = f.lag_hess_prototype if cons !== nothing && cons_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) - lag_hess_prototype = lag_extras.sparsity + lag_hess_prototype = lag_extras.coloring_result.S + lag_hess_colors = lag_extras.coloring_result.color function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) @@ -406,6 +408,7 @@ function OptimizationBase.instantiate_function( cons_hess_colorvec = conshess_colors, lag_h = lag_h!, lag_hess_prototype = lag_hess_prototype, + lag_hess_colorvec = lag_hess_colors, sys = f.sys, expr = f.expr, cons_expr = f.cons_expr) diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 08c6dea..4c129e7 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -141,8 +141,8 @@ function instantiate_function( function hess(res, θ) hessian!(_f, res, soadtype, θ, extras_hess) end - hess_sparsity = extras_hess.sparsity - hess_colors = extras_hess.colors + hess_sparsity = extras_hess.coloring_result.S + hess_colors = extras_hess.coloring_result.color elseif h == true hess = (H, θ) -> f.hess(H, θ, p) else @@ -199,8 +199,8 @@ function instantiate_function( J = vec(J) end end - cons_jac_prototype = extras_jac.sparsity - cons_jac_colorvec = extras_jac.colors + cons_jac_prototype = extras_jac.coloring_result.S + cons_jac_colorvec = extras_jac.coloring_result.color elseif cons_j === true && cons !== nothing cons_j! 
= (J, θ) -> f.cons_j(J, θ, p) else @@ -237,8 +237,9 @@ function instantiate_function( for ind in 1:num_cons extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) end - conshess_sparsity = getfield.(extras_cons_hess, :sparsity) - conshess_colors = getfield.(extras_cons_hess, :colors) + colores = getfield.(extras_cons_hess, :coloring_result) + conshess_sparsity = getfield.(colores, :S) + conshess_colors = getfield.(colores, :color) function cons_h!(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i]) @@ -254,8 +255,8 @@ function instantiate_function( lag_hess_colors = f.lag_hess_colorvec if cons !== nothing && lag_h == true && f.lag_h === nothing lag_extras = prepare_hessian(lagrangian, soadtype, x) - lag_hess_prototype = lag_extras.sparsity - lag_hess_colors = lag_extras.colors + lag_hess_prototype = lag_extras.coloring_result.S + lag_hess_colors = lag_extras.coloring_result.color function lag_h!(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) @@ -366,6 +367,8 @@ function instantiate_function( function hess(θ) hessian(_f, soadtype, θ, extras_hess) end + hess_sparsity = extras_hess.coloring_result.S + hess_colors = extras_hess.coloring_result.color elseif h == true hess = (θ) -> f.hess(θ, p) else @@ -406,8 +409,8 @@ function instantiate_function( end return J end - cons_jac_prototype = extras_jac.sparsity - cons_jac_colorvec = extras_jac.colors + cons_jac_prototype = extras_jac.coloring_result.S + cons_jac_colorvec = extras_jac.coloring_result.color elseif cons_j === true && cons !== nothing cons_j! = (θ) -> f.cons_j(θ, p) else @@ -448,8 +451,9 @@ function instantiate_function( end return H end - conshess_sparsity = getfield.(extras_cons_hess, :sparsity) - conshess_colors = getfield.(extras_cons_hess, :colors) + colores = getfield.(extras_cons_hess, :coloring_result) + conshess_sparsity = getfield.(colores, :S) + conshess_colors = getfield.(colores, :color) elseif cons_h == true && cons !== nothing cons_h! = (res, θ) -> f.cons_h(res, θ, p) else @@ -468,8 +472,8 @@ function instantiate_function( return hess end end - lag_hess_prototype = lag_extras.sparsity - lag_hess_colors = lag_extras.colors + lag_hess_prototype = lag_extras.coloring_result.S + lag_hess_colors = lag_extras.coloring_result.color elseif lag_h == true && cons !== nothing lag_h! 
= (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) else From c16195416d7d9618bd1b359ca2f925347eeb35d7 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 21 Aug 2024 21:22:05 -0400 Subject: [PATCH 28/33] Hack p to serve as data arg and implement stochastic gradient oracle --- ext/OptimizationEnzymeExt.jl | 4 +- ext/OptimizationZygoteExt.jl | 55 +++++++++++++++++++++++-- src/OptimizationDIExt.jl | 59 ++++++++++++++++++++++++--- src/OptimizationDISparseExt.jl | 53 +++++++++++++++++++++--- test/Project.toml | 1 + test/adtests.jl | 74 ++++++++++++++++++++++++++++++++++ 6 files changed, 230 insertions(+), 16 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index ea6dafe..40ce0ea 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -94,7 +94,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) if g == true && f.grad === nothing - function grad(res, θ) + function grad(res, θ, p = p) Enzyme.make_zero!(res) Enzyme.autodiff(Enzyme.Reverse, Const(firstapply), @@ -111,7 +111,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if fg == true && f.fg === nothing - function fg!(res, θ) + function fg!(res, θ, p = p) Enzyme.make_zero!(res) y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, Const(firstapply), diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index fb24866..fbccd43 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -7,7 +7,7 @@ import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I, dot import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, - prepare_jacobian, + prepare_jacobian, value_and_gradient!, value_derivative_and_second_derivative!, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes, SciMLBase @@ -19,8 +19,9 @@ function OptimizationBase.instantiate_function( g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + global _p = p function _f(θ) - return f(θ, p)[1] + return f(θ, _p)[1] end adtype, soadtype = OptimizationBase.generate_adtype(adtype) @@ -30,19 +31,41 @@ function OptimizationBase.instantiate_function( function grad(res, θ) gradient!(_f, res, adtype, θ, extras_grad) end + if p !== SciMLBase.NullParameters() && p !== nothing + function grad(res, θ, p) + global _p = p + gradient!(_f, res, adtype, θ) + end + end elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + grad = (G, θ, p) -> f.grad(G, θ, p) + end else grad = nothing end if fg == true && f.fg === nothing + if g == false + extras_grad = prepare_gradient(_f, adtype, x) + end function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) return y end + if p !== SciMLBase.NullParameters() && p !== nothing + function fg!(res, θ, p) + global _p = p + (y, _) = value_and_gradient!(_f, res, adtype, θ) + return y + end + end elseif fg == true fg! = (G, θ) -> f.fg(G, θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + fg! = (G, θ, p) -> f.fg(G, θ, p) + end else fg! = nothing end @@ -188,7 +211,8 @@ function OptimizationBase.instantiate_function( lag_h! 
= nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, @@ -232,19 +256,41 @@ function OptimizationBase.instantiate_function( function grad(res, θ) gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end + if p !== SciMLBase.NullParameters() && p !== nothing + function grad(res, θ, p) + global p = p + gradient!(_f, res, adtype.dense_ad, θ) + end + end elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + grad = (G, θ, p) -> f.grad(G, θ, p) + end else grad = nothing end if fg == true && f.fg !== nothing + if g == false + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) + end function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) return y end + if p !== SciMLBase.NullParameters() && p !== nothing + function fg!(res, θ, p) + global p = p + (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ) + return y + end + end elseif fg == true fg! = (G, θ) -> f.fg(G, θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + fg! = (G, θ, p) -> f.fg(G, θ, p) + end else fg! = nothing end @@ -398,7 +444,8 @@ function OptimizationBase.instantiate_function( else lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, hess_prototype = hess_sparsity, hess_colorvec = hess_colors, diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index de5024e..3dd0cdc 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -4,7 +4,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, - prepare_jacobian, + prepare_jacobian, value_and_gradient!, value_and_gradient, + value_derivative_and_second_derivative!, value_derivative_and_second_derivative, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes, SciMLBase @@ -26,8 +27,9 @@ function instantiate_function( g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + global _p = p function _f(θ) - return f(θ, p)[1] + return f(θ, _p)[1] end adtype, soadtype = generate_adtype(adtype) @@ -37,19 +39,41 @@ function instantiate_function( function grad(res, θ) gradient!(_f, res, adtype, θ, extras_grad) end + if p !== SciMLBase.NullParameters() && p !== nothing + function grad(res, θ, p) + global _p = p + gradient!(_f, res, adtype, θ) + end + end elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + grad = (G, θ, p) -> f.grad(G, θ, p) + end else grad = nothing end if fg == true && f.fg === nothing + if g == false + extras_grad = prepare_gradient(_f, adtype, x) + end function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad) return y end + if p !== SciMLBase.NullParameters() && p !== nothing + function fg!(res, θ, p) + global _p = p + (y, _) = value_and_gradient!(_f, res, adtype, θ) + return y + end + end elseif fg == true fg! 
= (G, θ) -> f.fg(G, θ, p) + if p !== SciMLBase.NullParameters() + fg! = (G, θ, p) -> f.fg(G, θ, p) + end else fg! = nothing end @@ -196,7 +220,8 @@ function instantiate_function( lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, @@ -232,8 +257,9 @@ function instantiate_function( g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + global _p = p function _f(θ) - return f(θ, p)[1] + return f(θ, _p)[1] end adtype, soadtype = generate_adtype(adtype) @@ -243,19 +269,41 @@ function instantiate_function( function grad(θ) gradient(_f, adtype, θ, extras_grad) end + if p !== SciMLBase.NullParameters() && p !== nothing + function grad(θ, p) + global _p = p + gradient(_f, adtype, θ) + end + end elseif g == true grad = (θ) -> f.grad(θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + grad = (θ, p) -> f.grad(θ, p) + end else grad = nothing end if fg == true && f.fg === nothing + if g == false + extras_grad = prepare_gradient(_f, adtype, x) + end function fg!(θ) (y, res) = value_and_gradient(_f, adtype, θ, extras_grad) return y, res end + if p !== SciMLBase.NullParameters() && p !== nothing + function fg!(θ, p) + global _p = p + (y, res) = value_and_gradient(_f, adtype, θ) + return y, res + end + end elseif fg == true fg! = (θ) -> f.fg(θ, p) + if p !== SciMLBase.NullParameters() && p !== nothing + fg! = (θ, p) -> f.fg(θ, p) + end else fg! = nothing end @@ -387,7 +435,8 @@ function instantiate_function( lag_h! 
= nothing end - return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{false}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 4c129e7..1d2bb02 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -4,7 +4,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, - prepare_jacobian, + prepare_jacobian, value_and_gradient!, value_derivative_and_second_derivative!, + value_and_gradient, value_derivative_and_second_derivative, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes @@ -106,8 +107,9 @@ function instantiate_function( g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + global _p = p function _f(θ) - return f.f(θ, p)[1] + return f.f(θ, _p)[1] end adtype, soadtype = generate_sparse_adtype(adtype) @@ -117,19 +119,41 @@ function instantiate_function( function grad(res, θ) gradient!(_f, res, adtype.dense_ad, θ, extras_grad) end + if p !== SciMLBase.NullParameters() + function grad(res, θ, p) + global _p = p + gradient!(_f, res, adtype.dense_ad, θ, extras_grad) + end + end elseif g == true grad = (G, θ) -> f.grad(G, θ, p) + if p !== SciMLBase.NullParameters() + grad = (G, θ, p) -> f.grad(G, θ, p) + end else grad = nothing end if fg == true && f.fg !== nothing + if g == false + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) + end function fg!(res, θ) (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) return y end + if p !== SciMLBase.NullParameters() + function fg!(res, θ, p) + global _p = p + (y, _) = value_and_gradient!(_f, res, adtype.dense_ad, θ, extras_grad) + return y + end + end elseif fg == true fg! = (G, θ) -> f.fg(G, θ, p) + if p !== SciMLBase.NullParameters() + fg! = (G, θ, p) -> f.fg(G, θ, p) + end else fg! = nothing end @@ -284,7 +308,8 @@ function instantiate_function( else lag_h! 
= nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, @@ -321,8 +346,9 @@ function instantiate_function( g = false, h = false, hv = false, fg = false, fgh = false, cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false, lag_h = false) + global _p = p function _f(θ) - return f(θ, p)[1] + return f(θ, _p)[1] end adtype, soadtype = generate_sparse_adtype(adtype) @@ -332,6 +358,12 @@ function instantiate_function( function grad(θ) gradient(_f, adtype.dense_ad, θ, extras_grad) end + if p !== SciMLBase.NullParameters() && p !== nothing + function grad(θ, p) + global _p = p + gradient(_f, adtype.dense_ad, θ, extras_grad) + end + end elseif g == true grad = (θ) -> f.grad(θ, p) else @@ -339,10 +371,20 @@ function instantiate_function( end if fg == true && f.fg !== nothing + if g == false + extras_grad = prepare_gradient(_f, adtype.dense_ad, x) + end function fg!(θ) (y, G) = value_and_gradient(_f, adtype.dense_ad, θ, extras_grad) return y, G end + if p !== SciMLBase.NullParameters() && p !== nothing + function fg!(θ, p) + global _p = p + (y, G) = value_and_gradient(_f, adtype.dense_ad, θ, extras_grad) + return y, G + end + end elseif fg == true fg! = (θ) -> f.fg(θ, p) else @@ -479,7 +521,8 @@ function instantiate_function( else lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv!, + return OptimizationFunction{true}(f.f, adtype; + grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, hess_prototype = hess_sparsity, diff --git a/test/Project.toml b/test/Project.toml index ef0ec4f..59cc04b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/test/adtests.jl b/test/adtests.jl index 78f640e..9e07bd6 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -797,3 +797,77 @@ optprob.hess(H2, x0) @test optprob.cons_j([5.0, 3.0])≈[10.0 6.0; -0.149013 -0.958924] rtol=1e-6 @test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] end + +using MLUtils + +@testset "Stochastic gradient" begin + x = rand(10000) + y = sin.(x) + data = MLUtils.DataLoader((x,y), batchsize = 100) + + function loss(coeffs, data) + ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] + return sum(abs2, ypred .- data[2]) + end + + optf = OptimizationFunction(loss, AutoForwardDiff()) + optf = OptimizationBase.instantiate_function(optf, rand(3), AutoForwardDiff(), iterate(data)[1], g = true, fg = true) + G0 = zeros(3) + optf.grad(G0, ones(3)) + stochgrads = [] + for (x,y) in data + G = zeros(3) + optf.grad(G, ones(3), (x,y)) + push!(stochgrads, copy(G)) + G1 = zeros(3) + optf.fg(G1, ones(3), (x,y)) + @test G ≈ G1 rtol=1e-6 + end + @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + + optf = OptimizationFunction(loss, AutoReverseDiff()) + optf 
= OptimizationBase.instantiate_function(optf, rand(3), AutoReverseDiff(), iterate(data)[1], g = true, fg = true) + G0 = zeros(3) + optf.grad(G0, ones(3)) + stochgrads = [] + for (x,y) in data + G = zeros(3) + optf.grad(G, ones(3), (x,y)) + push!(stochgrads, copy(G)) + G1 = zeros(3) + optf.fg(G1, ones(3), (x,y)) + @test G ≈ G1 rtol=1e-6 + end + @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + + optf = OptimizationFunction(loss, AutoZygote()) + optf = OptimizationBase.instantiate_function(optf, rand(3), AutoZygote(), iterate(data)[1], g = true, fg = true) + G0 = zeros(3) + optf.grad(G0, ones(3)) + stochgrads = [] + for (x,y) in data + G = zeros(3) + optf.grad(G, ones(3), (x,y)) + push!(stochgrads, copy(G)) + G1 = zeros(3) + optf.fg(G1, ones(3), (x,y)) + @test G ≈ G1 rtol=1e-6 + end + @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + + + optf = OptimizationFunction(loss, AutoEnzyme()) + optf = OptimizationBase.instantiate_function(optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true) + G0 = zeros(3) + @test_broken optf.grad(G0, ones(3)) + stochgrads = [] + # for (x,y) in data + # G = zeros(3) + # optf.grad(G, ones(3), (x,y)) + # push!(stochgrads, copy(G)) + # G1 = zeros(3) + # optf.fg(G1, ones(3), (x,y)) + # @test G ≈ G1 rtol=1e-6 + # end + # @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 +end \ No newline at end of file From 3e04244f268c49f8118fbc82ea4cd141ed12eed8 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 22 Aug 2024 08:44:42 -0400 Subject: [PATCH 29/33] use master project.toml --- Project.toml | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 918aadd..af66c85 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimizationBase" uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" authors = ["Vaibhav Dixit and contributors"] -version = "1.3.3" +version = "1.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -40,16 +40,21 @@ OptimizationZygoteExt = "Zygote" [compat] ADTypes = "1.5" ArrayInterface = "7.6" -DifferentiationInterface = "0.5.2" +DifferentiationInterface = "0.5" DocStringExtensions = "0.9" +Enzyme = "0.12.12" +FiniteDiff = "2.12" +ForwardDiff = "0.10.26" LinearAlgebra = "1.9, 1.10" ModelingToolkit = "9" +PDMats = "0.11" Reexport = "1.2" Requires = "1" +ReverseDiff = "1.14" SciMLBase = "2" -SymbolicAnalysis = "0.1, 0.2" +SymbolicAnalysis = "0.3" SymbolicIndexingInterface = "0.3" -Symbolics = "5.12" +Symbolics = "5.12, 6" Zygote = "0.6.67" julia = "1.10" @@ -57,4 +62,4 @@ julia = "1.10" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test"] \ No newline at end of file From 36963a4d9585fb17611eb7a3f2207a5c72ee5def Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 22 Aug 2024 09:46:31 -0400 Subject: [PATCH 30/33] format --- ext/OptimizationZygoteExt.jl | 7 ++--- src/OptimizationDIExt.jl | 7 ++--- src/OptimizationDISparseExt.jl | 7 ++--- test/adtests.jl | 49 ++++++++++++++++++---------------- 4 files changed, 38 insertions(+), 32 deletions(-) diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index fbccd43..75f5650 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -7,7 +7,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I, dot import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, - prepare_jacobian, value_and_gradient!, 
value_derivative_and_second_derivative!, + prepare_jacobian, value_and_gradient!, + value_derivative_and_second_derivative!, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes, SciMLBase @@ -211,7 +212,7 @@ function OptimizationBase.instantiate_function( lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; + return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, @@ -444,7 +445,7 @@ function OptimizationBase.instantiate_function( else lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; + return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, hess_prototype = hess_sparsity, diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 3dd0cdc..dd6e829 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -5,7 +5,8 @@ import OptimizationBase.LinearAlgebra: I import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, prepare_jacobian, value_and_gradient!, value_and_gradient, - value_derivative_and_second_derivative!, value_derivative_and_second_derivative, + value_derivative_and_second_derivative!, + value_derivative_and_second_derivative, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian using ADTypes, SciMLBase @@ -220,7 +221,7 @@ function instantiate_function( lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; + return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, @@ -435,7 +436,7 @@ function instantiate_function( lag_h! = nothing end - return OptimizationFunction{false}(f.f, adtype; + return OptimizationFunction{false}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index 1d2bb02..6018d21 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -4,7 +4,8 @@ import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.LinearAlgebra: I import DifferentiationInterface import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp, - prepare_jacobian, value_and_gradient!, value_derivative_and_second_derivative!, + prepare_jacobian, value_and_gradient!, + value_derivative_and_second_derivative!, value_and_gradient, value_derivative_and_second_derivative, gradient!, hessian!, hvp!, jacobian!, gradient, hessian, hvp, jacobian @@ -308,7 +309,7 @@ function instantiate_function( else lag_h! = nothing end - return OptimizationFunction{true}(f.f, adtype; + return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, @@ -521,7 +522,7 @@ function instantiate_function( else lag_h! 
= nothing end - return OptimizationFunction{true}(f.f, adtype; + return OptimizationFunction{true}(f.f, adtype; grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!, cons = cons, cons_j = cons_j!, cons_h = cons_h!, cons_vjp = cons_vjp!, cons_jvp = cons_jvp!, diff --git a/test/adtests.jl b/test/adtests.jl index 9e07bd6..30aca18 100644 --- a/test/adtests.jl +++ b/test/adtests.jl @@ -803,61 +803,64 @@ using MLUtils @testset "Stochastic gradient" begin x = rand(10000) y = sin.(x) - data = MLUtils.DataLoader((x,y), batchsize = 100) + data = MLUtils.DataLoader((x, y), batchsize = 100) function loss(coeffs, data) - ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] + ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] return sum(abs2, ypred .- data[2]) end optf = OptimizationFunction(loss, AutoForwardDiff()) - optf = OptimizationBase.instantiate_function(optf, rand(3), AutoForwardDiff(), iterate(data)[1], g = true, fg = true) + optf = OptimizationBase.instantiate_function( + optf, rand(3), AutoForwardDiff(), iterate(data)[1], g = true, fg = true) G0 = zeros(3) optf.grad(G0, ones(3)) stochgrads = [] - for (x,y) in data + for (x, y) in data G = zeros(3) - optf.grad(G, ones(3), (x,y)) + optf.grad(G, ones(3), (x, y)) push!(stochgrads, copy(G)) G1 = zeros(3) - optf.fg(G1, ones(3), (x,y)) - @test G ≈ G1 rtol=1e-6 + optf.fg(G1, ones(3), (x, y)) + @test G≈G1 rtol=1e-6 end - @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + @test G0≈sum(stochgrads) / length(stochgrads) rtol=1e-1 optf = OptimizationFunction(loss, AutoReverseDiff()) - optf = OptimizationBase.instantiate_function(optf, rand(3), AutoReverseDiff(), iterate(data)[1], g = true, fg = true) + optf = OptimizationBase.instantiate_function( + optf, rand(3), AutoReverseDiff(), iterate(data)[1], g = true, fg = true) G0 = zeros(3) optf.grad(G0, ones(3)) stochgrads = [] - for (x,y) in data + for (x, y) in data G = zeros(3) - optf.grad(G, ones(3), (x,y)) + optf.grad(G, ones(3), (x, y)) push!(stochgrads, copy(G)) G1 = zeros(3) - optf.fg(G1, ones(3), (x,y)) - @test G ≈ G1 rtol=1e-6 + optf.fg(G1, ones(3), (x, y)) + @test G≈G1 rtol=1e-6 end - @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + @test G0≈sum(stochgrads) / length(stochgrads) rtol=1e-1 optf = OptimizationFunction(loss, AutoZygote()) - optf = OptimizationBase.instantiate_function(optf, rand(3), AutoZygote(), iterate(data)[1], g = true, fg = true) + optf = OptimizationBase.instantiate_function( + optf, rand(3), AutoZygote(), iterate(data)[1], g = true, fg = true) G0 = zeros(3) optf.grad(G0, ones(3)) stochgrads = [] - for (x,y) in data + for (x, y) in data G = zeros(3) - optf.grad(G, ones(3), (x,y)) + optf.grad(G, ones(3), (x, y)) push!(stochgrads, copy(G)) G1 = zeros(3) - optf.fg(G1, ones(3), (x,y)) - @test G ≈ G1 rtol=1e-6 + optf.fg(G1, ones(3), (x, y)) + @test G≈G1 rtol=1e-6 end - @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 + @test G0≈sum(stochgrads) / length(stochgrads) rtol=1e-1 - optf = OptimizationFunction(loss, AutoEnzyme()) - optf = OptimizationBase.instantiate_function(optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true) + optf = OptimizationBase.instantiate_function( + optf, rand(3), AutoEnzyme(), iterate(data)[1], g = true, fg = true) G0 = zeros(3) @test_broken optf.grad(G0, ones(3)) stochgrads = [] @@ -870,4 +873,4 @@ using MLUtils # @test G ≈ G1 rtol=1e-6 # end # @test G0 ≈ sum(stochgrads)/length(stochgrads) rtol=1e-1 -end \ No newline at end of file +end From ec157c3a62173dfb6403797b56006128f9cc35e6 Mon Sep 17 
00:00:00 2001 From: Vaibhav Dixit Date: Thu, 22 Aug 2024 11:31:00 -0400 Subject: [PATCH 31/33] Update NoAD dispatch --- src/function.jl | 87 ++++++++++++++++++++----------------------------- 1 file changed, 36 insertions(+), 51 deletions(-) diff --git a/src/function.jl b/src/function.jl index 81d4f1a..3cd52f8 100644 --- a/src/function.jl +++ b/src/function.jl @@ -105,64 +105,49 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReIn observed = f.observed) end - function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD, - p, num_cons = 0) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) - cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) - cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) - cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) - hess_prototype = f.hess_prototype === nothing ? nothing : - convert.(eltype(x), f.hess_prototype) - cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : - convert.(eltype(x), f.cons_jac_prototype) - cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : - [convert.(eltype(x), f.cons_hess_prototype[i]) - for i in 1:num_cons] - expr = symbolify(f.expr) - cons_expr = symbolify.(f.cons_expr) - - return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, - hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_prototype, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = cons_hess_prototype, - expr = expr, cons_expr = cons_expr, - sys = f.sys, - observed = f.observed) + p, num_cons = 0, kwargs...) + grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) + fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...) + hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) + fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...) + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) + cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p) + cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) + lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(x), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(x), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? 
nothing : + [convert.(eltype(x), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); + grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + cons_vjp = cons_vjp, cons_jvp = cons_jvp, + lag_h = lag_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) end function instantiate_function( f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD, num_cons = 0, kwargs...) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) - cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p) - cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p) - cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p) - hess_prototype = f.hess_prototype === nothing ? nothing : - convert.(eltype(cache.u0), f.hess_prototype) - cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : - convert.(eltype(cache.u0), f.cons_jac_prototype) - cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : - [convert.(eltype(cache.u0), f.cons_hess_prototype[i]) - for i in 1:num_cons] - expr = symbolify(f.expr) - cons_expr = symbolify.(f.cons_expr) + x = cache.u0 + p = cache.p - return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess, - hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - hess_prototype = hess_prototype, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = cons_hess_prototype, - expr = expr, cons_expr = cons_expr, - sys = f.sys, - observed = f.observed) + return instantiate_function(f, x, SciMLBase.NoAD(), p, num_cons, kwargs...) end function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType, From ec3067b4d9537fd16158201829fe5d3488b6b50f Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 22 Aug 2024 13:28:34 -0400 Subject: [PATCH 32/33] Traits based boolean switching in instantiate_function call --- src/cache.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cache.jl b/src/cache.jl index 44fbff7..402c345 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -35,7 +35,10 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA kwargs...) reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p) num_cons = prob.ucons === nothing ? 
0 : length(prob.ucons) - f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons) + f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons, + g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt), + fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt), + cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) if (f.sys === nothing || f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && From 54162cfbd149c1ec465053861a22dee1fda7d7cb Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 23 Aug 2024 06:37:16 -0400 Subject: [PATCH 33/33] format --- src/cache.jl | 9 +++--- src/function.jl | 81 +++++++++++++++++++++++++------------------------ 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/cache.jl b/src/cache.jl index 402c345..6dd196a 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -35,10 +35,11 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA kwargs...) reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p) num_cons = prob.ucons === nothing ? 0 : length(prob.ucons) - f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons, - g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt), - fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt), - cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) + f = OptimizationBase.instantiate_function( + prob.f, reinit_cache, prob.f.adtype, num_cons, + g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt), + fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt), + cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt)) if (f.sys === nothing || f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && diff --git a/src/function.jl b/src/function.jl index 3cd52f8..1343900 100644 --- a/src/function.jl +++ b/src/function.jl @@ -46,7 +46,8 @@ documentation of the `AbstractADType` types. function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD, p, num_cons = 0) jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...) - hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess] + hess = f.hess === nothing ? nothing : + [(H, x, args...) -> h(H, x, p, args...) for h in f.hess] hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) cons_j = f.cons_j === nothing ? 
nothing : (res, x) -> f.cons_j(res, x, p) @@ -63,7 +64,8 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB expr = symbolify(f.expr) cons_expr = symbolify.(f.cons_expr) - return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess, + return MultiObjectiveOptimizationFunction{true}( + f.f, SciMLBase.NoAD(); jac = jac, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h, hess_prototype = hess_prototype, @@ -74,10 +76,12 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB observed = f.observed) end -function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD, +function instantiate_function( + f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD, num_cons = 0) jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...) - hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess] + hess = f.hess === nothing ? nothing : + [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess] hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p) cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p) @@ -94,7 +98,8 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReIn expr = symbolify(f.expr) cons_expr = symbolify.(f.cons_expr) - return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess, + return MultiObjectiveOptimizationFunction{true}( + f.f, SciMLBase.NoAD(); jac = jac, hess = hess, hv = hv, cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h, hess_prototype = hess_prototype, @@ -107,38 +112,38 @@ end function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD, p, num_cons = 0, kwargs...) - grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) - fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...) - hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) - fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...) - hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) - cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) - cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) - cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p) - cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p) - cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) - lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p) - hess_prototype = f.hess_prototype === nothing ? nothing : - convert.(eltype(x), f.hess_prototype) - cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : - convert.(eltype(x), f.cons_jac_prototype) - cons_hess_prototype = f.cons_hess_prototype === nothing ? 
nothing : - [convert.(eltype(x), f.cons_hess_prototype[i]) - for i in 1:num_cons] - expr = symbolify(f.expr) - cons_expr = symbolify.(f.cons_expr) - - return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); - grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv, - cons = cons, cons_j = cons_j, cons_h = cons_h, - cons_vjp = cons_vjp, cons_jvp = cons_jvp, - lag_h = lag_h, - hess_prototype = hess_prototype, - cons_jac_prototype = cons_jac_prototype, - cons_hess_prototype = cons_hess_prototype, - expr = expr, cons_expr = cons_expr, - sys = f.sys, - observed = f.observed) + grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) + fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...) + hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...) + fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...) + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) + cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p) + cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) + lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(x), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(x), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : + [convert.(eltype(x), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); + grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, + cons_vjp = cons_vjp, cons_jvp = cons_jvp, + lag_h = lag_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) end function instantiate_function( @@ -162,5 +167,3 @@ function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.Abstra adpkg = adtypestr[strtind:(open_brkt_ind - 1)] throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg.")) end - -