Skip to content

Commit 5c89c59

Browse files
Merge pull request #229 from avik-pal/ap/concrete_jac
Fix Jacobian Construction for concrete_jac
2 parents fd4ae4b + b5b9298 commit 5c89c59

File tree

6 files changed

+83
-11
lines changed

6 files changed

+83
-11
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.0.1"
4+
version = "2.0.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -35,6 +35,7 @@ FiniteDiff = "2"
3535
ForwardDiff = "0.10.3"
3636
LineSearches = "7"
3737
LinearSolve = "2"
38+
NonlinearProblemLibrary = "0.1"
3839
PrecompileTools = "1"
3940
RecursiveArrayTools = "2"
4041
Reexport = "0.2, 1"
@@ -52,6 +53,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5253
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5354
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5455
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
56+
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
5557
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5658
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5759
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -62,4 +64,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6264
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6365

6466
[targets]
65-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools"]
67+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
44
@eval Base.Experimental.@max_methods 1
55
end
66

7-
using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
7+
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
88
import ForwardDiff
99

1010
import ADTypes: AbstractFiniteDifferencesMode

src/jacobian.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@concrete struct JacobianWrapper{iip}
1+
@concrete struct JacobianWrapper{iip} <: Function
22
f
33
p
44
end
@@ -73,9 +73,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
7373
jac_cache = nothing
7474
end
7575

76-
J = if !linsolve_needs_jac
76+
J = if !(linsolve_needs_jac || alg_wants_jac)
7777
# We don't need to construct the Jacobian
78-
JacVec(uf, u; autodiff = alg.ad)
78+
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
7979
else
8080
if has_analytic_jac
8181
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
@@ -98,6 +98,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9898
return uf, linsolve, J, fu, jac_cache, du
9999
end
100100

101+
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
102+
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
103+
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
104+
__get_nonsparse_ad(ad) = ad
105+
101106
## Special Handling for Scalars
102107
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
103108
::Val{false}; kwargs...)

src/levenberg.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
216216
# The following lines do: cache.a = -J \ cache.fu_tmp
217217
mul!(cache.du, J, v)
218218
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.du)
219-
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(cache.fu_tmp),
219+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
220220
linu = _vec(cache.du), p = p, reltol = cache.abstol)
221221
cache.linsolve = linres.cache
222222
@. cache.a = -cache.du
@@ -225,7 +225,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
225225

226226
# Require acceptable steps to satisfy the following condition.
227227
norm_v = norm(v)
228-
if (2 * norm(cache.a) / norm_v) < α_geodesic
228+
if 2 * norm(cache.a) α_geodesic * norm_v
229229
@. cache.δ = v + cache.a / 2
230230
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
231231
f(cache.fu_tmp, u .+ δ, p)
@@ -274,18 +274,19 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
274274
end
275275
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
276276

277+
cache.mat_tmp = JᵀJ + λ * DᵀD
277278
# Usual Levenberg-Marquardt step ("velocity").
278-
cache.v = -(JᵀJ + λ * DᵀD) \ (J' * fu1)
279+
cache.v = -cache.mat_tmp \ (J' * fu1)
279280

280281
@unpack v, h, α_geodesic = cache
281282
# Geodesic acceleration (step_size = v + a / 2).
282-
cache.a = -J \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v))
283+
cache.a = -cache.mat_tmp \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v))
283284
cache.stats.nsolve += 1
284285
cache.stats.nfactors += 1
285286

286287
# Require acceptable steps to satisfy the following condition.
287288
norm_v = norm(v)
288-
if (2 * norm(cache.a) / norm_v) < α_geodesic
289+
if 2 * norm(cache.a) α_geodesic * norm_v
289290
cache.δ = v .+ cache.a ./ 2
290291
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
291292
fu_new = f(u .+ δ, p)

test/23_test_problems.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using NonlinearSolve, LinearAlgebra, NonlinearProblemLibrary, Test
2+
3+
problems = NonlinearProblemLibrary.problems
4+
dicts = NonlinearProblemLibrary.dicts
5+
6+
function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-5)
7+
for (idx, (problem, dict)) in enumerate(zip(problems, dicts))
8+
x = dict["start"]
9+
res = similar(x)
10+
nlprob = NonlinearProblem(problem, x)
11+
@testset "$(dict["title"])" begin
12+
for alg in alg_ops
13+
sol = solve(nlprob, alg, abstol = 1e-15, reltol = 1e-15)
14+
problem(res, sol.u, nothing)
15+
broken = idx in broken_tests[alg] ? true : false
16+
@test norm(res)ϵ broken=broken
17+
end
18+
end
19+
end
20+
end
21+
22+
# NewtonRaphson
23+
@testset "NewtonRaphson test problem library" begin
24+
alg_ops = (NewtonRaphson(),)
25+
26+
# dictionary with indices of test problems where method does not converge to small residual
27+
broken_tests = Dict(alg => Int[] for alg in alg_ops)
28+
broken_tests[alg_ops[1]] = [1, 6]
29+
30+
test_on_library(problems, dicts, alg_ops, broken_tests)
31+
end
32+
33+
@testset "TrustRegion test problem library" begin
34+
alg_ops = (TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Simple),
35+
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan),
36+
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Hei),
37+
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan),
38+
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin),
39+
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.NLsolve))
40+
41+
# dictionary with indices of test problems where method does not converge to small residual
42+
broken_tests = Dict(alg => Int[] for alg in alg_ops)
43+
broken_tests[alg_ops[1]] = [6, 11, 21]
44+
broken_tests[alg_ops[2]] = [6, 11, 21]
45+
broken_tests[alg_ops[3]] = [1, 6, 11, 12, 15, 16, 21]
46+
broken_tests[alg_ops[4]] = [1, 6, 8, 11, 15, 16, 21, 22]
47+
broken_tests[alg_ops[5]] = [6, 21]
48+
broken_tests[alg_ops[6]] = [6, 21]
49+
50+
test_on_library(problems, dicts, alg_ops, broken_tests)
51+
end
52+
53+
@testset "TrustRegion test problem library" begin
54+
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.5))
55+
56+
# dictionary with indices of test problems where method does not converge to small residual
57+
broken_tests = Dict(alg => Int[] for alg in alg_ops)
58+
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
59+
broken_tests[alg_ops[2]] = [3, 6, 11, 21]
60+
61+
test_on_library(problems, dicts, alg_ops, broken_tests)
62+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ end
1515
if GROUP == "All" || GROUP == "Core"
1616
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
18+
19+
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
1820
end
1921

2022
if GROUP == "GPU"

0 commit comments

Comments
 (0)