diff --git a/.github/workflows/SpellCheck.yml b/.github/workflows/SpellCheck.yml
index c6aa688..63a4a7e 100644
--- a/.github/workflows/SpellCheck.yml
+++ b/.github/workflows/SpellCheck.yml
@@ -10,4 +10,4 @@ jobs:
       - name: Checkout Actions Repository
         uses: actions/checkout@v4
       - name: Check spelling
-        uses: crate-ci/typos@v1.23.6
\ No newline at end of file
+        uses: crate-ci/typos@v1.24.1
\ No newline at end of file
diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl
index dd6e829..0b1490e 100644
--- a/src/OptimizationDIExt.jl
+++ b/src/OptimizationDIExt.jl
@@ -4,11 +4,12 @@ import OptimizationBase.SciMLBase: OptimizationFunction
 import OptimizationBase.LinearAlgebra: I
 import DifferentiationInterface
 import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
-                                 prepare_jacobian, value_and_gradient!, value_and_gradient,
+                                 prepare_jacobian, prepare_pullback, value_and_gradient!,
+                                 value_and_gradient,
                                  value_derivative_and_second_derivative!,
                                  value_derivative_and_second_derivative,
                                  gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
-                                 hvp, jacobian
+                                 hvp, jacobian, pullback!
 using ADTypes, SciMLBase
 
 function generate_adtype(adtype)
@@ -150,12 +151,13 @@ function instantiate_function(
     end
 
     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
-        extras_pullback = prepare_pullback(cons_oop, adtype, x)
-        function cons_vjp!(J, θ, v)
+        dy = similar(x, num_cons)  # cotangent seed sized to the constraint output
+        extras_pullback = prepare_pullback(cons_oop, adtype, x, dy)
+        function cons_vjp!(J, v, θ)
             pullback!(cons_oop, J, adtype, θ, v, extras_pullback)
         end
     elseif cons_vjp == true && cons !== nothing
-        cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
+        cons_vjp! = (J, v, θ) -> f.cons_vjp(J, v, θ, p)
     else
         cons_vjp! = nothing
     end
@@ -166,7 +168,7 @@ function instantiate_function(
             pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward)
         end
     elseif cons_jvp == true && cons !== nothing
-        cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
+        cons_jvp! = (J, v, θ) -> f.cons_jvp(J, v, θ, p)
     else
         cons_jvp! = nothing
     end
@@ -240,7 +242,7 @@ end
 
 function instantiate_function(
         f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AbstractADType, num_cons = 0,
+        adtype::ADTypes.AbstractADType, num_cons = 0;
         g = false, h = false, hv = false, fg = false, fgh = false,
         cons_j = false, cons_vjp = false, cons_jvp = false,
         cons_h = false, lag_h = false)
diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl
index 6018d21..26535aa 100644
--- a/src/OptimizationDISparseExt.jl
+++ b/src/OptimizationDISparseExt.jl
@@ -8,7 +8,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
                                  value_derivative_and_second_derivative!, value_and_gradient,
                                  value_derivative_and_second_derivative,
                                  gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
-                                 hvp, jacobian
+                                 hvp, jacobian, prepare_pullback
 using ADTypes
 using SparseConnectivityTracer, SparseMatrixColorings
 
@@ -233,12 +233,13 @@ function instantiate_function(
     end
 
     if f.cons_vjp === nothing && cons_vjp == true
-        extras_pullback = prepare_pullback(cons_oop, adtype, x)
-        function cons_vjp!(J, θ, v)
+        dy = similar(x, num_cons)  # cotangent seed sized to the constraint output
+        extras_pullback = prepare_pullback(cons_oop, adtype.dense_ad, x, dy)
+        function cons_vjp!(J, v, θ)
             pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
         end
     elseif cons_vjp === true && cons !== nothing
-        cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
+        cons_vjp! = (J, v, θ) -> f.cons_vjp(J, v, θ, p)
     else
         cons_vjp! = nothing
     end
@@ -249,7 +250,7 @@ function instantiate_function(
             pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
         end
     elseif cons_jvp === true && cons !== nothing
-        cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
+        cons_jvp! = (J, v, θ) -> f.cons_jvp(J, v, θ, p)
     else
         cons_jvp! = nothing
     end
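For reviewers, a minimal usage sketch of the callback convention these hunks introduce. This is not part of the diff: the constraint function, the sizes, and the `AutoForwardDiff` backend are illustrative assumptions, and it presumes the positional `prepare_pullback(f, backend, x, dy)` / `pullback!(f, dx, backend, x, dy, extras)` DifferentiationInterface API that the changed code calls.

```julia
# Sketch of the cons_vjp pattern from the diff, standalone outside OptimizationBase.
# Assumption: DifferentiationInterface's positional pullback API, as called above.
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff  # loads DI's ForwardDiff backend support

# Out-of-place constraints: 2 outputs from 3 inputs (stands in for cons_oop).
cons_oop(x) = [x[1]^2 + x[2]^2, x[3] - 1.0]

adtype = AutoForwardDiff()
x = zeros(3)
dy = similar(x, 2)  # cotangent seed: one entry per constraint, not per input
extras = prepare_pullback(cons_oop, adtype, x, dy)

# New user-facing argument order: output buffer J, cotangent v, then point θ.
cons_vjp!(J, v, θ) = pullback!(cons_oop, J, adtype, θ, v, extras)

J = similar(x)  # receives Jᵀv, one entry per input
v = ones(2)     # one entry per constraint
cons_vjp!(J, v, x)
```

Note that the inner `pullback!`/`pushforward!` calls keep DifferentiationInterface's own `(θ, v)` ordering; only the user-facing `cons_vjp`/`cons_jvp` callbacks switch to `(J, v, θ)`.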