Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 061ec7d

Browse files
committed
LBroyden
1 parent 409d36c commit 061ec7d

File tree

6 files changed

+210
-119
lines changed

6 files changed

+210
-119
lines changed

ext/SimpleNonlinearSolveADLinearSolveExt.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
module SimpleNonlinearSolveADLinearSolveExt
22

3-
using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
3+
using AbstractDifferentiation,
4+
ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
45
SimpleNonlinearSolve, SciMLBase
5-
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
6+
import SimpleNonlinearSolve: _construct_batched_problem_structure,
7+
_get_storage, _result_from_storage, _get_tolerance, @maybeinplace
68

79
const AD = AbstractDifferentiation
810

@@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
2022
# TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
2123
chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size
2224
ad = SciMLBase._unwrap_val(autodiff) ?
23-
AD.ForwardDiffBackend(; chunksize) :
24-
AD.FiniteDifferencesBackend()
25-
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(
26-
ad,
25+
AD.ForwardDiffBackend(; chunksize) :
26+
AD.FiniteDifferencesBackend()
27+
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad,
2728
nothing,
2829
termination_condition)
2930
end
3031

3132
function SciMLBase.__solve(prob::NonlinearProblem,
3233
alg::SimpleBatchedNewtonRaphson;
33-
abstol=nothing,
34-
reltol=nothing,
35-
maxiters=1000,
34+
abstol = nothing,
35+
reltol = nothing,
36+
maxiters = 1000,
3637
kwargs...)
3738
iip = isinplace(prob)
3839
@assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
@@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
5758
alg,
5859
reconstruct(xₙ),
5960
reconstruct(fₙ);
60-
retcode=ReturnCode.Success)
61+
retcode = ReturnCode.Success)
6162

62-
solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...)
63+
solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...)
6364
xₙ .-= δx
6465

6566
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
@@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
8384
alg,
8485
reconstruct(xₙ),
8586
reconstruct(fₙ);
86-
retcode=ReturnCode.MaxIters)
87+
retcode = ReturnCode.MaxIters)
8788
end
8889

8990
end

ext/SimpleNonlinearSolveNNlibExt.jl

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
module SimpleNonlinearSolveNNlibExt
22

33
using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
4-
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
4+
import SimpleNonlinearSolve: _construct_batched_problem_structure,
5+
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
56

67
function __init__()
78
SimpleNonlinearSolve.NNlibExtLoaded[] = true
89
return
910
end
1011

12+
# Broyden's method
1113
@views function SciMLBase.__solve(prob::NonlinearProblem,
1214
alg::BatchedBroyden;
13-
abstol=nothing,
14-
reltol=nothing,
15-
maxiters=1000,
15+
abstol = nothing,
16+
reltol = nothing,
17+
maxiters = 1000,
1618
kwargs...)
1719
iip = isinplace(prob)
1820

@@ -24,7 +26,7 @@ end
2426

2527
storage = _get_storage(mode, u)
2628

27-
xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
29+
xₙ, xₙ₋₁, δxₙ, δf = ntuple(_ -> copy(u), 4)
2830
T = eltype(u)
2931

3032
atol = _get_tolerance(abstol, tc.abstol, T)
@@ -41,16 +43,16 @@ end
4143
xₙ .= xₙ₋₁ .- 𝓙⁻¹f
4244

4345
@maybeinplace iip fₙ=f(xₙ)
44-
δx .= xₙ .- xₙ₋₁
46+
δxₙ .= xₙ .- xₙ₋₁
4547
δf .= fₙ .- fₙ₋₁
4648

4749
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
48-
δxᵀ = reshape(δx, 1, L, N)
50+
δxₙᵀ = reshape(δxₙ, 1, L, N)
4951

50-
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
51-
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
52-
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
53-
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
52+
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxₙᵀ, reshape(𝓙⁻¹f, L, 1, N))
53+
batched_mul!(xᵀ𝓙⁻¹, δxₙᵀ, 𝓙⁻¹)
54+
δxₙ .= (δxₙ .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
55+
batched_mul!(𝓙⁻¹, reshape(δxₙ, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
5456

5557
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
5658
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
@@ -74,7 +76,103 @@ end
7476
alg,
7577
reconstruct(xₙ),
7678
reconstruct(fₙ);
77-
retcode=ReturnCode.MaxIters)
79+
retcode = ReturnCode.MaxIters)
80+
end
81+
82+
# Limited Memory Broyden's method
83+
@views function SciMLBase.__solve(prob::NonlinearProblem,
84+
alg::BatchedLBroyden;
85+
abstol = nothing,
86+
reltol = nothing,
87+
maxiters = 1000,
88+
kwargs...)
89+
iip = isinplace(prob)
90+
91+
u, f, reconstruct = _construct_batched_problem_structure(prob)
92+
L, N = size(u)
93+
T = eltype(u)
94+
95+
tc = alg.termination_condition
96+
mode = DiffEqBase.get_termination_mode(tc)
97+
98+
storage = _get_storage(mode, u)
99+
100+
η = min(maxiters, alg.threshold)
101+
U = fill!(similar(u, (η, L, N)), zero(T))
102+
Vᵀ = fill!(similar(u, (L, η, N)), zero(T))
103+
104+
xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(u), 3)
105+
106+
atol = _get_tolerance(abstol, tc.abstol, T)
107+
rtol = _get_tolerance(reltol, tc.reltol, T)
108+
termination_condition = tc(storage)
109+
110+
@maybeinplace iip fₙ₋₁=f(xₙ) u
111+
iip && (fₙ = copy(fₙ₋₁))
112+
δxₙ = -copy(fₙ₋₁)
113+
ηNx = similar(xₙ, η, N)
114+
115+
for i in 1:maxiters
116+
@. xₙ = xₙ₋₁ - δxₙ
117+
@maybeinplace iip fₙ=f(xₙ)
118+
@. δxₙ = xₙ - xₙ₋₁
119+
@. δfₙ = fₙ - fₙ₋₁
120+
121+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
122+
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
123+
return DiffEqBase.build_solution(prob,
124+
alg,
125+
reconstruct(xₙ),
126+
reconstruct(fₙ);
127+
retcode)
128+
end
129+
130+
_L = min(i, η)
131+
_U = U[1:_L, :, :]
132+
_Vᵀ = Vᵀ[:, 1:_L, :]
133+
134+
idx = mod1(i, η)
135+
136+
if i > 1
137+
partial_ηNx = ηNx[1:_L, :]
138+
139+
_ηNx = reshape(partial_ηNx, 1, :, N)
140+
batched_mul!(_ηNx, reshape(δxₙ, 1, L, N), _Vᵀ)
141+
batched_mul!(Vᵀ[:, idx:idx, :], _ηNx, _U)
142+
Vᵀ[:, idx, :] .-= δxₙ
143+
144+
_ηNx = reshape(partial_ηNx, :, 1, N)
145+
batched_mul!(_ηNx, _U, reshape(δfₙ, L, 1, N))
146+
batched_mul!(U[idx:idx, :, :], _Vᵀ, _ηNx)
147+
U[idx, :, :] .-= δfₙ
148+
else
149+
Vᵀ[:, idx, :] .= -δxₙ
150+
U[idx, :, :] .= -δfₙ
151+
end
152+
153+
U[idx, :, :] .= (δxₙ .- U[idx, :, :]) ./
154+
(sum(Vᵀ[:, idx, :] .* δfₙ; dims = 1) .+
155+
convert(T, 1e-5))
156+
157+
_L = min(i + 1, η)
158+
_ηNx = reshape(ηNx[1:_L, :], :, 1, N)
159+
batched_mul!(_ηNx, U[1:_L, :, :], reshape(δfₙ, L, 1, N))
160+
batched_mul!(reshape(δxₙ, L, 1, N), Vᵀ[:, 1:_L, :], _ηNx)
161+
162+
xₙ₋₁ .= xₙ
163+
fₙ₋₁ .= fₙ
164+
end
165+
166+
if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
167+
xₙ = storage.u
168+
@maybeinplace iip fₙ=f(xₙ)
169+
end
170+
171+
return DiffEqBase.build_solution(prob,
172+
alg,
173+
reconstruct(xₙ),
174+
reconstruct(fₙ);
175+
retcode = ReturnCode.MaxIters)
78176
end
79177

80178
end

src/batched/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
1+
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
22
AbstractBatchedNonlinearSolveAlgorithm
33
σₘᵢₙ::T = 1.0f-10
44
σₘₐₓ::T = 1.0f+10

src/batched/lbroyden.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <:
2+
AbstractBatchedNonlinearSolveAlgorithm
3+
termination_condition::TC
4+
threshold::Int
5+
end
6+
7+
# Implementation of solve using Package Extensions

src/broyden.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ end
3030

3131
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
3232
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
33+
if SciMLBase.isinplace(prob)
34+
error("Broyden currently only supports out-of-place nonlinear problems")
35+
end
3336
tc = alg.termination_condition
3437
mode = DiffEqBase.get_termination_mode(tc)
3538
f = Base.Fix2(prob.f, prob.p)
@@ -39,19 +42,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
3942
T = eltype(x)
4043
J⁻¹ = init_J(x)
4144

42-
if SciMLBase.isinplace(prob)
43-
error("Broyden currently only supports out-of-place nonlinear problems")
44-
end
45-
4645
atol = _get_tolerance(abstol, tc.abstol, T)
4746
rtol = _get_tolerance(reltol, tc.reltol, T)
4847

4948
if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
5049
error("Broyden currently doesn't support SAFE_BEST termination modes")
5150
end
5251

53-
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
54-
nothing
52+
storage = _get_storage(mode, x)
5553
termination_condition = tc(storage)
5654

5755
xₙ = x

0 commit comments

Comments
 (0)