From 62b7388b00ffe8da7fffbe4417e894929f261ce4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 27 Aug 2023 18:36:19 -0400 Subject: [PATCH 1/2] Fix SteadyStateAdjoint Breakage --- Project.toml | 2 +- src/steadystate_adjoint.jl | 3 ++- test/steady_state.jl | 21 +++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f6a9cb20c..4e5567a42 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLSensitivity" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" authors = ["Christopher Rackauckas ", "Yingbo Ma "] -version = "7.37.1" +version = "7.37.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index af73558a2..e24ba04c1 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -101,7 +101,8 @@ end if !needs_jac # operator = VecJac(f, y, p; Val(DiffEqBase.isinplace(sol.prob))) - operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp)) + __f = y -> f(y, p, nothing) + operator = VecJac(__f, y; autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) else linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ)) diff --git a/test/steady_state.jl b/test/steady_state.jl index 1f50a11aa..b3b4c3380 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -424,6 +424,27 @@ end @test dp1≈dp6 rtol=1e-10 @test dp1≈dp7 rtol=1e-10 @test dp1≈dp8 rtol=1e-10 + + # Larger Batched Problem: For testing the Iterative Solvers Path + u0 = zeros(128) + p = [2.0, 1.0] + + prob = NonlinearProblem((u, p) -> u .- p[1] .+ p[2], u0, p) + solve1 = solve(remake(prob, p = p), NewtonRaphson()) + + function test_loss(p, prob; alg = NewtonRaphson()) + _prob = remake(prob, p = p) + sol = sum(solve(_prob, alg, + sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()))) + return sol + end + + test_loss(p, prob) + + dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1] + + @test dp1[1] ≈ 128 + @test dp1[2] ≈ -128 end @testset "Continuous sensitivity tools" begin From ed284425cf45b1db2bfa62a8b083f28a92b97611 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 28 Aug 2023 07:11:24 -0400 Subject: [PATCH 2/2] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4e5567a42..c850551bf 100644 --- a/Project.toml +++ b/Project.toml @@ -73,7 +73,7 @@ ReverseDiff = "1.9" SciMLBase = "1.66.0" SciMLOperators = "0.1, 0.2, 0.3" SimpleNonlinearSolve = "0.1.8" -SparseDiffTools = "2.4" +SparseDiffTools = "2.5" StaticArraysCore = "1.4" StochasticDiffEq = "6.20" Tracker = "0.2"