Skip to content

Commit 95d9480

Browse files
committed
Line Search for Gauss Newton
1 parent 674b872 commit 95d9480

File tree

4 files changed

+28
-16
lines changed

4 files changed

+28
-16
lines changed

src/gaussnewton.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
33
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced GaussNewton implementation with support for efficient handling of sparse
@@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3030
preconditioners. For more information on specifying preconditioners for LinearSolve
3131
algorithms, consult the
3232
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
33+
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
34+
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
35+
used here directly, and they will be converted to the correct `LineSearch`.
3336
3437
!!! warning
3538
@@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
4043
ad::AD
4144
linsolve
4245
precs
46+
linesearch
4347
end
4448

4549
# Return a new `GaussNewton` with the same type parameter `CJ` as `alg` but using `ad`
# in place of its current AD setting. The remaining fields are copied from `alg`; the
# added (`+`) version of the return line also forwards `alg.linesearch`.
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
46-
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
50+
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
4751
end
4852

4953
# Keyword constructor for `GaussNewton`.
# - `concrete_jac` is passed through `_unwrap_val` and used as the type parameter.
# - `adkwargs...` are forwarded to `default_adargs_to_adtype` to obtain `ad`.
# - `linesearch` (added in this change) is used as-is when it is already a `LineSearch`;
#   any other value is wrapped as `LineSearch(; method = linesearch)`.
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
50-
precs = DEFAULT_PRECS, adkwargs...)
54+
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
5155
ad = default_adargs_to_adtype(; adkwargs...)
52-
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
56+
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
57+
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
5358
end
5459

5560
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -78,6 +83,7 @@ end
7883
stats::NLStats
7984
tc_cache_1
8085
tc_cache_2
86+
ls_cache
8187
end
8288

8389
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
@@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
107113

108114
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
109115
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
110-
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
116+
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
117+
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
111118
end
112119

113120
function perform_step!(cache::GaussNewtonCache{true})
@@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
128135
linu = _vec(du), p, reltol = cache.abstol)
129136
end
130137
cache.linsolve = linres.cache
131-
@. u = u - du
138+
α = perform_linesearch!(cache.ls_cache, u, du)
139+
_axpy!(-α, du, u)
132140
f(cache.fu_new, u, p)
133141

134142
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
@@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
169177
end
170178
cache.linsolve = linres.cache
171179
end
172-
cache.u = @. u - cache.du # `u` might not support mutation
180+
α = perform_linesearch!(cache.ls_cache, u, cache.du)
181+
cache.u = @. u - α * cache.du # `u` might not support mutation
173182
cache.fu_new = f(cache.u, p)
174183

175184
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)

src/linesearch.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
122122
end
123123

124124
function g!(u, fu)
125+
# FIXME: Upstream patch to allow non-square Jacobians
125126
op = VecJac((args...) -> f(args..., p), u; autodiff)
126127
if iip
127128
mul!(g₀, op, fu)

src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
2+
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
33
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced NewtonRaphson implementation with support for efficient handling of sparse

test/nonlinear_least_squares.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [
31-
GaussNewton(),
32-
GaussNewton(; linsolve = LUFactorization()),
33-
LevenbergMarquardt(),
34-
LevenbergMarquardt(; linsolve = LUFactorization()),
35-
LeastSquaresOptimJL(:lm),
36-
LeastSquaresOptimJL(:dogleg),
37-
]
30+
# New version: build one `GaussNewton` per (linsolve, linesearch) combination. The
# two-variable comprehension yields a 2x5 array, which `vec` flattens; the `Any`
# element type lets `append!` below add the other solver types to the same vector.
solvers = vec(Any[GaussNewton(; linsolve, linesearch)
31+
for linsolve in [nothing, LUFactorization()],
32+
linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
33+
# The non-GaussNewton solvers are the same four entries the old list contained.
append!(solvers,
34+
[
35+
LevenbergMarquardt(),
36+
LevenbergMarquardt(; linsolve = LUFactorization()),
37+
LeastSquaresOptimJL(:lm),
38+
LeastSquaresOptimJL(:dogleg),
39+
])
3840

3941
for prob in nlls_problems, solver in solvers
4042
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)