This repository was archived by the owner on Aug 25, 2025. It is now read-only.
2 changes: 1 addition & 1 deletion .github/workflows/SpellCheck.yml
@@ -10,4 +10,4 @@ jobs:
- name: Checkout Actions Repository
uses: actions/checkout@v4
- name: Check spelling
-        uses: crate-ci/typos@v1.23.6
+        uses: crate-ci/typos@v1.24.1
16 changes: 9 additions & 7 deletions src/OptimizationDIExt.jl
@@ -4,11 +4,12 @@ import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.LinearAlgebra: I
import DifferentiationInterface
import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
-            prepare_jacobian, value_and_gradient!, value_and_gradient,
+            prepare_jacobian, prepare_pullback, value_and_gradient!,
+            value_and_gradient,
value_derivative_and_second_derivative!,
value_derivative_and_second_derivative,
gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
-            hvp, jacobian
+            hvp, jacobian, pullback!
using ADTypes, SciMLBase

function generate_adtype(adtype)
@@ -150,12 +151,13 @@ function instantiate_function(
end

if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
-        extras_pullback = prepare_pullback(cons_oop, adtype, x)
-        function cons_vjp!(J, θ, v)
+        dy = similar(x)
+        extras_pullback = prepare_pullback(cons_oop, adtype, x, dy)
+        function cons_vjp!(J, v, θ)
pullback!(cons_oop, J, adtype, θ, v, extras_pullback)
end
elseif cons_vjp == true && cons !== nothing
-        cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
+        cons_vjp! = (J, v, θ) -> f.cons_vjp(J, v, θ, p)
Member:
Actually, I don't think this is correct. We don't want this order as the API either, since it would be inconsistent with jvp and add mental load.

What makes you think the current implementation is incorrect? https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/dev/api/#DifferentiationInterface.pullback and https://github.com/jump-dev/MathOptInterface.jl/blob/cb1ad41c3a5ce1d12e83f2be60f10b191b0e22c5/src/nlp.jl#L762-L809 both match my understanding as well

I did spend a bit of time making sure the dimensions here were correct

Contributor Author:
When you call DI, the argument order is pullback!(cons_oop, J, adtype, θ, v, extras_pullback), but in SciML, the argument order is

- `cons_jvp(Jv,v,x,p)` or `Jv=cons_jvp(v,x,p)`: the Jacobian-vector product of the constraints.
- `cons_vjp(Jv,v,x,p)` or `Jv=cons_vjp(v,x,p)`: the vector-Jacobian (transposed-Jacobian-vector) product of the constraints.

https://github.com/SciML/SciMLBase.jl/blob/master/src/scimlfunctions.jl#L1833C1-L1834C96
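To make the dimensions under discussion concrete, here is a small hand-worked sketch (hypothetical toy constraints, not code from this PR) of what the vjp and jvp of a constraint function look like:

```julia
# Toy constraint function with 2 constraints and 3 parameters, to
# illustrate the dimensions involved (illustrative, not from the PR).
cons(θ) = [θ[1]^2 + θ[2], θ[2] * θ[3]]

θ = [1.0, 2.0, 3.0]
v = [1.0, 0.5]            # vjp seed: one entry per constraint

# Jacobian of cons at θ, written out by hand:
J = [2θ[1]  1.0   0.0;
     0.0    θ[3]  θ[2]]

vjp = J' * v              # length(θ) == 3; equals [2.0, 2.5, 1.0]

u = [1.0, 1.0, 1.0]       # jvp seed: one entry per parameter
jvp = J * u               # length(cons(θ)) == 2; equals [3.0, 5.0]
```

So the vjp seed lives in constraint space and the result in parameter space, and vice versa for the jvp; the argument-order debate above is only about where that seed appears in the signature.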

Member:
Ah lol, so I missed the reference closest to home; I'll change that. This order makes less sense to me than the one used here, though. What do you think?

Contributor Author (@baggepinnen, Aug 28, 2024):
I don't have a strong opinion here. I'd almost expect `v'J(x)` to take args `(v, x)` and vice versa for `J(x)*v`, but having them differ might also be confusing. As long as SciML is internally consistent everywhere, it can be whatever you prefer.

`hv(Hv,u,v,p)` or `Hv=hv(u,v,p)`: the Hessian-vector product `(d^2 f / du^2) * v` takes `v` after `u`; maybe change `u` to `x` to be consistent as well?
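For reference, the Hessian-vector product signature mentioned above can be sketched on a toy objective (hypothetical code, not from the PR):

```julia
# Toy objective to make the hv signature concrete (illustrative only).
f(u) = u[1]^2 * u[2] + u[2]^3

u = [1.0, 2.0]            # point, first argument in hv(Hv, u, v, p)
v = [0.0, 1.0]            # seed, second argument

# Hessian of f at u, written out by hand:
# ∂²f/∂u₁² = 2u₂, ∂²f/∂u₁∂u₂ = 2u₁, ∂²f/∂u₂² = 6u₂
H = [2u[2]  2u[1];
     2u[1]  6u[2]]

Hv = H * v                # what hv(Hv, u, v, p) writes into Hv: [2.0, 12.0]
```

Here the seed `v` follows the point `u`, which is the ordering the comment suggests aligning `cons_jvp`/`cons_vjp` with (or at least renaming `u` to `x` for consistency).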

Member:
makes sense, thanks for pointing these out 🙌

else
cons_vjp! = nothing
end
@@ -166,7 +168,7 @@ function instantiate_function(
pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward)
end
elseif cons_jvp == true && cons !== nothing
-        cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
+        cons_jvp! = (J, v, θ) -> f.cons_jvp(J, v, θ, p)
else
cons_jvp! = nothing
end
@@ -240,7 +242,7 @@ end

function instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
-        adtype::ADTypes.AbstractADType, num_cons = 0,
+        adtype::ADTypes.AbstractADType, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
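The `,` → `;` change after `num_cons = 0` is meaningful in Julia: it turns `g`, `h`, and the rest from optional positional arguments into keyword arguments. A minimal sketch of the difference (hypothetical function names):

```julia
# With a comma: optional *positional* arguments, matched by position.
pos(x, n = 0, g = false) = (n, g)
pos(1, 2, true)        # returns (2, true)

# With a semicolon: *keyword* arguments, matched by name only.
kw(x, n = 0; g = false) = (n, g)
kw(1, 2; g = true)     # returns (2, true)
# kw(1, 2, true)       # MethodError: keyword args can't be passed positionally
```

With many boolean flags like `g`, `h`, `hv`, `cons_j`, …, keyword arguments are the safer interface, since callers must name the flag they are setting.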
11 changes: 6 additions & 5 deletions src/OptimizationDISparseExt.jl
@@ -8,7 +8,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
value_derivative_and_second_derivative!,
value_and_gradient, value_derivative_and_second_derivative,
gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
-            hvp, jacobian
+            hvp, jacobian, prepare_pullback
using ADTypes
using SparseConnectivityTracer, SparseMatrixColorings

@@ -233,12 +233,13 @@ function instantiate_function(
end

if f.cons_vjp === nothing && cons_vjp == true
-        extras_pullback = prepare_pullback(cons_oop, adtype, x)
-        function cons_vjp!(J, θ, v)
+        dy = similar(x)
+        extras_pullback = prepare_pullback(cons_oop, adtype, x, dy)
+        function cons_vjp!(J, v, θ)
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
end
elseif cons_vjp === true && cons !== nothing
-        cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
+        cons_vjp! = (J, v, θ) -> f.cons_vjp(J, v, θ, p)
else
cons_vjp! = nothing
end
@@ -249,7 +250,7 @@ function instantiate_function(
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
end
elseif cons_jvp === true && cons !== nothing
-        cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
+        cons_jvp! = (J, v, θ) -> f.cons_jvp(J, v, θ, p)
else
cons_jvp! = nothing
end