From c1239ca2cde7e650f51ab94f91b4a3e90df8897a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 11:21:41 -0400 Subject: [PATCH 01/12] Use PackageExtensionCompat --- Project.toml | 10 +++++----- src/SimpleNonlinearSolve.jl | 11 ++--------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index cf5a675..d255103 100644 --- a/Project.toml +++ b/Project.toml @@ -9,10 +9,10 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] @@ -23,14 +23,14 @@ SimpleBatchedNonlinearSolveExt = "NNlib" [compat] ArrayInterface = "6, 7" -DiffEqBase = "6.123.0" +DiffEqBase = "6.126" FiniteDiff = "2" ForwardDiff = "0.10.3" NNlib = "0.8" +PackageExtensionCompat = "1" +PrecompileTools = "1" Reexport = "0.2, 1" -Requires = "1" SciMLBase = "1.73" -PrecompileTools = "1" StaticArraysCore = "1.4" julia = "1.6" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 8749aa7..97ce203 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -10,16 +10,9 @@ using DiffEqBase @reexport using SciMLBase -if !isdefined(Base, :get_extension) - using Requires -end - +using PackageExtensionCompat function __init__() - @static if !isdefined(Base, :get_extension) - @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin - include("../ext/SimpleBatchedNonlinearSolveExt.jl") - end - end + @require_extensions end abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end From b1563f0434cc049cd37e3921d83260beede070ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 11:51:36 -0400 Subject: [PATCH 02/12] Rework BatchedBroyden to be more efficient --- Project.toml | 4 +- ext/SimpleBatchedNonlinearSolveExt.jl | 120 ++++++++++++-------------- src/SimpleNonlinearSolve.jl | 9 ++ src/batched/broyden.jl | 20 +++++ src/batched/dfsane.jl | 0 src/batched/lbroyden.jl | 0 src/batched/raphson.jl | 0 src/batched/utils.jl | 78 +++++++++++++++++ src/broyden.jl | 16 ++-- test/basictests.jl | 12 +-- 10 files changed, 176 insertions(+), 83 deletions(-) create mode 100644 src/batched/broyden.jl create mode 100644 src/batched/dfsane.jl create mode 100644 src/batched/lbroyden.jl create mode 100644 src/batched/raphson.jl create mode 100644 src/batched/utils.jl diff --git a/Project.toml b/Project.toml index d255103..b1b3a42 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.16" +version = "0.1.17" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -26,7 +26,7 @@ ArrayInterface = "6, 7" DiffEqBase = "6.126" FiniteDiff = "2" ForwardDiff = "0.10.3" -NNlib = "0.8" +NNlib = "0.8, 0.9" PackageExtensionCompat = "1" PrecompileTools = "1" Reexport = "0.2, 1" diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl index dd76288..9e1ba59 100644 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ b/ext/SimpleBatchedNonlinearSolveExt.jl @@ -1,90 +1,76 @@ module SimpleBatchedNonlinearSolveExt -using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase +using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase, NNlib +import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace -isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) +@views function SciMLBase.__solve(prob::NonlinearProblem, + alg::BatchedBroyden; + abstol=nothing, + reltol=nothing, + maxiters=1000, + kwargs...) + iip = isinplace(prob) + u0 = prob.u0 -_batch_transpose(x) = reshape(x, 1, size(x)...) + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) -_batched_mul(x, y) = x * y - -function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T} - return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) -end - -function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T} - return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) -end - -function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2} - return batched_mul(x, y) -end - -function _init_J_batched(x::AbstractMatrix{T}) where {T} - J = ArrayInterface.zeromatrix(x[:, 1]) - if ismutable(x) - J[diagind(J)] .= one(eltype(x)) - else - J += I - end - return repeat(J, 1, 1, size(x, 2)) -end - -function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) - f = Base.Fix2(prob.f, prob.p) - x = float(prob.u0) - if ndims(x) != 2 - error("`batch` mode works only if `ndims(prob.u0) == 2`") - end + storage = _get_storage(mode, u) - fā‚™ = f(x) - T = eltype(x) - J⁻¹ = _init_J_batched(x) + xā‚™, xₙ₋₁, Ī“x, Ī“f = ntuple(_ -> copy(u), 4) + T = eltype(u) - if SciMLBase.isinplace(prob) - error("Broyden currently only supports out-of-place nonlinear problems") - end + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + š“™ā»Ā¹ = _init_š“™(xā‚™) # L Ɨ L Ɨ N + š“™ā»Ā¹f, xįµ€š“™ā»Ā¹Ī“f, xįµ€š“™ā»Ā¹ = similar(š“™ā»Ā¹, L, N), similar(š“™ā»Ā¹, 1, N), similar(š“™ā»Ā¹, 1, L, N) - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - error("Broyden currently doesn't support SAFE_BEST termination modes") - end + @maybeinplace iip fₙ₋₁=f(xā‚™) u + iip && (fā‚™ = copy(fₙ₋₁)) + for n in 1:maxiters + batched_mul!(reshape(š“™ā»Ā¹f, L, 1, N), š“™ā»Ā¹, reshape(fₙ₋₁, L, 1, N)) + xā‚™ .= xₙ₋₁ .- š“™ā»Ā¹f - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : - nothing - termination_condition = tc(storage) + @maybeinplace iip fā‚™=f(xā‚™) + Ī“x .= xā‚™ .- xₙ₋₁ + Ī“f .= fā‚™ .- fₙ₋₁ + + batched_mul!(reshape(š“™ā»Ā¹f, L, 1, N), š“™ā»Ā¹, reshape(Ī“f, L, 1, N)) + Ī“xįµ€ = reshape(Ī“x, 1, L, N) - xā‚™ = x - xₙ₋₁ = x - fₙ₋₁ = fā‚™ - for i in 1:maxiters - xā‚™ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁) - fā‚™ = f(xā‚™) - Ī”xā‚™ = xā‚™ .- xₙ₋₁ - Ī”fā‚™ = fā‚™ .- fₙ₋₁ - J⁻¹Δfā‚™ = _batched_mul(J⁻¹, Ī”fā‚™) - J⁻¹ += _batched_mul(((Ī”xā‚™ .- J⁻¹Δfā‚™) ./ - (_batched_mul(_batch_transpose(Ī”xā‚™), J⁻¹Δfā‚™) .+ T(1e-5))), - _batched_mul(_batch_transpose(Ī”xā‚™), J⁻¹)) + batched_mul!(reshape(xįµ€š“™ā»Ā¹Ī“f, 1, 1, N), Ī“xįµ€, reshape(š“™ā»Ā¹f, L, 1, N)) + batched_mul!(xįµ€š“™ā»Ā¹, Ī“xįµ€, š“™ā»Ā¹) + Ī“x .= (Ī“x .- š“™ā»Ā¹f) ./ (xįµ€š“™ā»Ā¹Ī“f .+ T(1e-5)) + batched_mul!(š“™ā»Ā¹, reshape(Ī“x, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) - return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.Success) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode) end - xₙ₋₁ = xā‚™ - fₙ₋₁ = fā‚™ + xₙ₋₁ .= xā‚™ + fₙ₋₁ .= fā‚™ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + @maybeinplace iip fā‚™=f(xā‚™) end - return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.MaxIters) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode=ReturnCode.MaxIters) end end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 97ce203..ab20336 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -19,6 +19,7 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonline abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end +abstract type AbstractBatchedNonlinearSolveAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end include("utils.jl") include("bisection.jl") @@ -35,6 +36,13 @@ include("ad.jl") include("halley.jl") include("alefeld.jl") +# Batched Solver Support +include("batched/utils.jl") +include("batched/raphson.jl") +include("batched/dfsane.jl") +include("batched/broyden.jl") +include("batched/lbroyden.jl") + import PrecompileTools PrecompileTools.@compile_workload begin @@ -67,5 +75,6 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld +export BatchedBroyden end # module diff --git a/src/batched/broyden.jl b/src/batched/broyden.jl new file mode 100644 index 0000000..a301c4e --- /dev/null +++ b/src/batched/broyden.jl @@ -0,0 +1,20 @@ +""" + BatchedBroyden(; + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + +A low-overhead batched implementation of Broyden capable of solving multiple nonlinear +problems simultaneously. + +!!! note + + To use this version, remember to load `NNlib`, i.e., `using NNlib` or + `import NNlib` must be present in your code. +""" +struct BatchedBroyden{TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + termination_condition::TC +end + +# Implementation of solve using Package Extensions diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl new file mode 100644 index 0000000..e69de29 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl new file mode 100644 index 0000000..e69de29 diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl new file mode 100644 index 0000000..e69de29 diff --git a/src/batched/utils.jl b/src/batched/utils.jl new file mode 100644 index 0000000..9ed9134 --- /dev/null +++ b/src/batched/utils.jl @@ -0,0 +1,78 @@ +function _get_tolerance(Ī·, tc_Ī·, ::Type{T}) where {T} + fallback_Ī· = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return ifelse(Ī· !== nothing, Ī·, ifelse(tc_Ī· !== nothing, tc_Ī·, fallback_Ī·)) +end + +function _construct_batched_problem_structure(prob) + return _construct_batched_problem_structure(prob.u0, + prob.f, + prob.p, + Val(SciMLBase.isinplace(prob))) +end + +function _construct_batched_problem_structure(u0::AbstractArray{T, N}, + f, + p, + ::Val{iip}) where {T, N, iip} + # Reconstruct `u` + reconstruct = N == 2 ? identity : Base.Fix2(reshape, size(u0)) + # Standardize `u` + standardize = N == 2 ? identity : + (N == 1 ? Base.Fix2(reshape, (:, 1)) : + Base.Fix2(reshape, (:, size(u0, ndims(u0))))) + # Updated Function + f_modified = if iip + function f_modified_iip(du, u) + f(reconstruct(du), reconstruct(u), p) + return standardize(du) + end + else + f_modified_oop(u) = standardize(f(reconstruct(u), p)) + end + return standardize(u0), f_modified, reconstruct +end + +@views function _init_š“™(x::AbstractMatrix) + š“™ = ArrayInterface.zeromatrix(x[:, 1]) + if ismutable(x) + š“™[diagind(š“™)] .= one(eltype(x)) + else + š“™ .+= I + end + return repeat(š“™, 1, 1, size(x, 2)) +end + +_result_from_storage(::Nothing, xā‚™, fā‚™, f, mode) = ReturnCode.Success, xā‚™, fā‚™ +function _result_from_storage(storage::NLSolveSafeTerminationResult, xā‚™, fā‚™, f, mode) + if storage.return_code == DiffEqBase.NLSolveSafeTerminationReturnCode.Success + return ReturnCode.Success, xā‚™, fā‚™ + else + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + return ReturnCode.Terminated, storage.u, f(storage.u) + else + return ReturnCode.Terminated, xā‚™, fā‚™ + end + end +end + +function _get_storage(mode, u) + return mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? + NLSolveSafeTerminationResult(mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES ? u : + nothing) : nothing +end + +macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing) + @assert expr.head == :(=) + x1, x2 = expr.args + @assert x2.head == :call + f, x = x2.args + define_expr = u0 === nothing ? :() : :($(x1) = similar($(u0))) + return quote + if $(esc(iip)) + $(esc(define_expr)) + $(esc(f))($(esc(x1)), $(esc(x))) + else + $(esc(expr)) + end + end +end diff --git a/src/broyden.jl b/src/broyden.jl index 8ce0d66..56fa2ef 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -11,19 +11,19 @@ and static array problems. To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or `import NNlib` must be present in your code. """ -struct Broyden{batched, TC <: NLSolveTerminationCondition} <: +struct Broyden{TC <: NLSolveTerminationCondition} <: AbstractSimpleNonlinearSolveAlgorithm termination_condition::TC +end - function Broyden(; batched = false, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - return new{batched, typeof(termination_condition)}(termination_condition) - end +function Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return (batched ? BatchedBroyden : Broyden)(termination_condition) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; +function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) diff --git a/test/basictests.jl b/test/basictests.jl index aea1731..dd54527 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -5,12 +5,12 @@ using DiffEqBase using LinearAlgebra using Test -const BATCHED_BROYDEN_SOLVERS = Broyden[] -const BROYDEN_SOLVERS = Broyden[] -const BATCHED_LBROYDEN_SOLVERS = LBroyden[] -const LBROYDEN_SOLVERS = LBroyden[] -const BATCHED_DFSANE_SOLVERS = SimpleDFSane[] -const DFSANE_SOLVERS = SimpleDFSane[] +const BATCHED_BROYDEN_SOLVERS = [] +const BROYDEN_SOLVERS = [] +const BATCHED_LBROYDEN_SOLVERS = [] +const LBROYDEN_SOLVERS = [] +const BATCHED_DFSANE_SOLVERS = [] +const DFSANE_SOLVERS = [] for mode in instances(NLSolveTerminationMode.T) if mode ∈ From 9f09d65c471004e79d5c4b3d34b85e0c884f5fd8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 15:29:45 -0400 Subject: [PATCH 03/12] Add Batching to Newton Raphson --- Project.toml | 9 +- ext/SimpleNonlinearSolveADLinearSolveExt.jl | 89 +++++++++++++++++++ ...Ext.jl => SimpleNonlinearSolveNNlibExt.jl} | 12 ++- src/SimpleNonlinearSolve.jl | 6 +- src/batched/raphson.jl | 24 +++++ src/batched/utils.jl | 39 ++++---- src/broyden.jl | 11 ++- src/raphson.jl | 42 +++++++-- 8 files changed, 195 insertions(+), 37 deletions(-) create mode 100644 ext/SimpleNonlinearSolveADLinearSolveExt.jl rename ext/{SimpleBatchedNonlinearSolveExt.jl => SimpleNonlinearSolveNNlibExt.jl} (90%) diff --git a/Project.toml b/Project.toml index b1b3a42..9017085 100644 --- a/Project.toml +++ b/Project.toml @@ -16,16 +16,21 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] -SimpleBatchedNonlinearSolveExt = "NNlib" +SimpleNonlinearSolveNNlibExt = "NNlib" +SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"] [compat] +AbstractDifferentiation = "0.5" ArrayInterface = "6, 7" DiffEqBase = "6.126" FiniteDiff = "2" ForwardDiff = "0.10.3" +LinearSolve = "2" NNlib = "0.8, 0.9" PackageExtensionCompat = "1" PrecompileTools = "1" @@ -35,7 +40,9 @@ StaticArraysCore = "1.4" julia = "1.6" [extras] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl new file mode 100644 index 0000000..96ba916 --- /dev/null +++ b/ext/SimpleNonlinearSolveADLinearSolveExt.jl @@ -0,0 +1,89 @@ +module SimpleNonlinearSolveADLinearSolveExt + +using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, + SimpleNonlinearSolve, SciMLBase +import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace + +const AD = AbstractDifferentiation + +function __init__() + SimpleNonlinearSolve.ADLinearSolveExtLoaded[] = true + return +end + +function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl + chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size + ad = SciMLBase._unwrap_val(autodiff) ? + AD.ForwardDiffBackend(; chunksize) : + AD.FiniteDifferencesBackend() + return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}( + ad, + nothing, + termination_condition) +end + +function SciMLBase.__solve(prob::NonlinearProblem, + alg::SimpleBatchedNewtonRaphson; + abstol=nothing, + reltol=nothing, + maxiters=1000, + kwargs...) + iip = isinplace(prob) + @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." + u, f, reconstruct = _construct_batched_problem_structure(prob) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + xā‚™, xₙ₋₁, Ī“x = copy(u), copy(u), copy(u) + T = eltype(u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + for i in 1:maxiters + fā‚™, (š“™,) = AD.value_and_jacobian(alg.autodiff, f, xā‚™) + + iszero(fā‚™) && return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode=ReturnCode.Success) + + solve(LinearProblem(š“™, vec(fā‚™); u0=vec(Ī“x)), alg.linsolve; kwargs...) + xā‚™ .-= Ī“x + + if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode) + end + + xₙ₋₁ .= xā‚™ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + fā‚™ = f(xā‚™) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode=ReturnCode.MaxIters) +end + +end diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl similarity index 90% rename from ext/SimpleBatchedNonlinearSolveExt.jl rename to ext/SimpleNonlinearSolveNNlibExt.jl index 9e1ba59..e62e2bd 100644 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -1,8 +1,13 @@ -module SimpleBatchedNonlinearSolveExt +module SimpleNonlinearSolveNNlibExt -using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase, NNlib +using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace +function __init__() + SimpleNonlinearSolve.NNlibExtLoaded[] = true + return +end + @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; abstol=nothing, @@ -10,7 +15,6 @@ import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, maxiters=1000, kwargs...) iip = isinplace(prob) - u0 = prob.u0 u, f, reconstruct = _construct_batched_problem_structure(prob) L, N = size(u) @@ -49,7 +53,7 @@ import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, batched_mul!(š“™ā»Ā¹, reshape(Ī“x, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) - retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) return DiffEqBase.build_solution(prob, alg, reconstruct(xā‚™), diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index ab20336..49aef9f 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -15,6 +15,9 @@ function __init__() @require_extensions end +const ADLinearSolveExtLoaded = Ref{Bool}(false) +const NNlibExtLoaded = Ref{Bool}(false) + abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end @@ -75,6 +78,7 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld -export BatchedBroyden +export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane, + BatchedLBroyden end # module diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl index e69de29..0a28e33 100644 --- a/src/batched/raphson.jl +++ b/src/batched/raphson.jl @@ -0,0 +1,24 @@ +""" + SimpleBatchedNewtonRaphson(; chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + +A low-overhead batched implementation of Newton-Raphson capable of solving multiple +nonlinear problems simultaneously. + +!!! note + + To use the `batched` version, remember to load `AbstractDifferentiation` and + `LinearSolve`. +""" +struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + autodiff::AD + linsolve::LS + termination_condition::TC +end + +# Implementation of solve using Package Extensions diff --git a/src/batched/utils.jl b/src/batched/utils.jl index 9ed9134..33441e9 100644 --- a/src/batched/utils.jl +++ b/src/batched/utils.jl @@ -1,3 +1,19 @@ +macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing) + @assert expr.head == :(=) + x1, x2 = expr.args + @assert x2.head == :call + f, x = x2.args + define_expr = u0 === nothing ? :() : :($(x1) = similar($(u0))) + return quote + if $(esc(iip)) + $(esc(define_expr)) + $(esc(f))($(esc(x1)), $(esc(x))) + else + $(esc(expr)) + end + end +end + function _get_tolerance(Ī·, tc_Ī·, ::Type{T}) where {T} fallback_Ī· = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) return ifelse(Ī· !== nothing, Ī·, ifelse(tc_Ī· !== nothing, tc_Ī·, fallback_Ī·)) @@ -42,13 +58,14 @@ end return repeat(š“™, 1, 1, size(x, 2)) end -_result_from_storage(::Nothing, xā‚™, fā‚™, f, mode) = ReturnCode.Success, xā‚™, fā‚™ -function _result_from_storage(storage::NLSolveSafeTerminationResult, xā‚™, fā‚™, f, mode) +_result_from_storage(::Nothing, xā‚™, fā‚™, args...) = ReturnCode.Success, xā‚™, fā‚™ +function _result_from_storage(storage::NLSolveSafeTerminationResult, xā‚™, fā‚™, f, mode, iip) if storage.return_code == DiffEqBase.NLSolveSafeTerminationReturnCode.Success return ReturnCode.Success, xā‚™, fā‚™ else if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - return ReturnCode.Terminated, storage.u, f(storage.u) + @maybeinplace iip fā‚™ = f(xā‚™) + return ReturnCode.Terminated, storage.u, fā‚™ else return ReturnCode.Terminated, xā‚™, fā‚™ end @@ -60,19 +77,3 @@ function _get_storage(mode, u) NLSolveSafeTerminationResult(mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES ? u : nothing) : nothing end - -macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing) - @assert expr.head == :(=) - x1, x2 = expr.args - @assert x2.head == :call - f, x = x2.args - define_expr = u0 === nothing ? :() : :($(x1) = similar($(u0))) - return quote - if $(esc(iip)) - $(esc(define_expr)) - $(esc(f))($(esc(x1)), $(esc(x))) - else - $(esc(expr)) - end - end -end diff --git a/src/broyden.jl b/src/broyden.jl index 56fa2ef..74566e8 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,7 +1,8 @@ """ Broyden(; batched = false, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, reltol = nothing)) + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. @@ -20,7 +21,11 @@ function Broyden(; batched = false, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; abstol = nothing, reltol = nothing)) - return (batched ? BatchedBroyden : Broyden)(termination_condition) + if batched + @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + return BatchedBroyden(termination_condition) + end + return Broyden(termination_condition) end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; diff --git a/src/raphson.jl b/src/raphson.jl index 386c350..59eb1ed 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,8 +1,8 @@ """ -```julia -SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), - diff_type = Val{:forward}) -``` + SimpleNewtonRaphson(; batched = false, + chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}) A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar and static array problems. @@ -27,13 +27,37 @@ and static array problems. - `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to `Val{:forward}` for forward finite differences. For more details on the choices, see the [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation. + +!!! note + + To use the `batched` version, remember to load `AbstractDifferentiation` and + `LinearSolve`. """ -struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} - function SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), - diff_type = Val{:forward}) - new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff), - SciMLBase._unwrap_val(diff_type)}() +struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end + +function SimpleNewtonRaphson(; batched=false, + chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = missing) + if !ismissing(termination_condition) && !batched + throw(ArgumentError("`termination_condition` is currently only supported for batched problems")) + end + if batched + @assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." + termination_condition = ismissing(termination_condition) ? + NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing) : + termination_condition + return SimpleBatchedNewtonRaphson(; chunk_size, + autodiff, + diff_type, + termination_condition) end + return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), + SciMLBase._unwrap_val(diff_type)}() end function SciMLBase.__solve(prob::NonlinearProblem, From f4b36b0f8e905ca899baacf13ee0dec12365725a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 16:59:54 -0400 Subject: [PATCH 04/12] Update SimpleDFSane Batching --- Project.toml | 8 --- src/batched/dfsane.jl | 140 ++++++++++++++++++++++++++++++++++++++++++ src/batched/utils.jl | 4 +- src/broyden.jl | 7 +-- src/dfsane.jl | 122 +++++++++++++----------------------- test/Project.toml | 12 ++++ test/basictests.jl | 13 +--- 7 files changed, 201 insertions(+), 105 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 9017085..f1f472d 100644 --- a/Project.toml +++ b/Project.toml @@ -41,13 +41,5 @@ julia = "1.6" [extras] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"] diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index e69de29..fe7cbcd 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -0,0 +1,140 @@ +@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 + Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 + Ļƒā‚::T = 1.0f0 + M::Int = 10 + γ::T = 1.0f-4 + Ļ„ā‚˜įµ¢ā‚™::T = 0.1f0 + Ļ„ā‚˜ā‚ā‚“::T = 0.5f0 + nā‚‘ā‚“ā‚š::Int = 2 + Ī·ā‚›::F = (fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚, n, xā‚™, fā‚™) -> fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚ ./ n .^ 2 + termination_condition::TC = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol=nothing, + reltol=nothing) + max_inner_iterations::Int = 1000 +end + +function SciMLBase.__solve(prob::NonlinearProblem, + alg::SimpleBatchedDFSane, + args...; + abstol=nothing, + reltol=nothing, + maxiters=100, + kwargs...) + iip = isinplace(prob) + + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) + T = eltype(u) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + Ļƒā‚˜įµ¢ā‚™, Ļƒā‚˜ā‚ā‚“, γ, Ļ„ā‚˜įµ¢ā‚™, Ļ„ā‚˜ā‚ā‚“ = T(alg.Ļƒā‚˜įµ¢ā‚™), T(alg.Ļƒā‚˜ā‚ā‚“), T(alg.γ), T(alg.Ļ„ā‚˜įµ¢ā‚™), T(alg.Ļ„ā‚˜ā‚ā‚“) + α₁ = one(T) + Ī±ā‚Š, α₋ = similar(u, 1, N), similar(u, 1, N) + Ļƒā‚™ = fill(T(alg.Ļƒā‚), 1, N) + š’¹ = similar(Ļƒā‚™, L, N) + (; M, nā‚‘ā‚“ā‚š) = alg + + xā‚™, xₙ₋₁, fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ = copy(u), copy(u), similar(u, 1, N), similar(u, 1, N) + + function ff!(fā‚“, fā‚™ā‚’įµ£ā‚˜, x) + f(fā‚“, x) + sum!(abs2, fā‚™ā‚’įµ£ā‚˜, fā‚“) + fā‚™ā‚’įµ£ā‚˜ .^= (nā‚‘ā‚“ā‚š / 2) + return fā‚“ + end + + function ff!(fā‚™ā‚’įµ£ā‚˜, x) + fā‚“ = f(x) + sum!(abs2, fā‚™ā‚’įµ£ā‚˜, fā‚“) + fā‚™ā‚’įµ£ā‚˜ .^= (nā‚‘ā‚“ā‚š / 2) + return fā‚“ + end + + @maybeinplace iip fₙ₋₁ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, xā‚™) xā‚™ + iip && (fā‚™ = similar(fₙ₋₁)) + ā„‹ = repeat(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, M, 1) + fĢ„ = similar(ā„‹, 1, N) + Ī·ā‚› = (n, xā‚™, fā‚™) -> alg.Ī·ā‚›(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, n, xā‚™, fā‚™) + + for n in 1:maxiters + # Spectral parameter range check + @. Ļƒā‚™ = sign(Ļƒā‚™) * clamp(abs(Ļƒā‚™), Ļƒā‚˜įµ¢ā‚™, Ļƒā‚˜ā‚ā‚“) + + # Line search direction + @. š’¹ = -Ļƒā‚™ * fₙ₋₁ + + Ī· = Ī·ā‚›(n, xₙ₋₁, fₙ₋₁) + maximum!(fĢ„, ā„‹) + fill!(Ī±ā‚Š, α₁) + fill!(α₋, α₁) + @. xā‚™ = xₙ₋₁ + Ī±ā‚Š * š’¹ + + @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + + for _ in 1:(alg.max_inner_iterations) + š’ø = @. fĢ„ + Ī· - γ * Ī±ā‚Š^2 * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚ + + (sum(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ .≤ š’ø) ≄ N Ć· 2) && break + + @. Ī±ā‚Š = clamp(Ī±ā‚Š^2 * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚ / (fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ + (T(2) * Ī±ā‚Š - T(1)) * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚), + Ļ„ā‚˜įµ¢ā‚™ * Ī±ā‚Š, + Ļ„ā‚˜ā‚ā‚“ * Ī±ā‚Š) + @. xā‚™ = xₙ₋₁ - α₋ * š’¹ + @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + + (sum(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ .≤ š’ø) ≄ N Ć· 2) && break + + @. α₋ = clamp(α₋^2 * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚ / (fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ + (T(2) * α₋ - T(1)) * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚), + Ļ„ā‚˜įµ¢ā‚™ * α₋, + Ļ„ā‚˜ā‚ā‚“ * α₋) + @. xā‚™ = xₙ₋₁ + Ī±ā‚Š * š’¹ + @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + end + + if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode) + end + + # Update spectral parameter + @. xₙ₋₁ = xā‚™ - xₙ₋₁ + @. fₙ₋₁ = fā‚™ - fₙ₋₁ + + sum!(abs2, Ī±ā‚Š, xₙ₋₁) + sum!(α₋, xₙ₋₁ .* fₙ₋₁) + Ļƒā‚™ .= Ī±ā‚Š ./ (α₋ .+ T(1e-5)) + + # Take step + @. xₙ₋₁ = xā‚™ + @. fₙ₋₁ = fā‚™ + @. fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚ = fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ + + # Update history + ā„‹[n % M + 1, :] .= view(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, 1, :) + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + @maybeinplace iip fā‚™ = f(xā‚™) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode=ReturnCode.MaxIters) +end diff --git a/src/batched/utils.jl b/src/batched/utils.jl index 33441e9..b5dbd59 100644 --- a/src/batched/utils.jl +++ b/src/batched/utils.jl @@ -2,12 +2,12 @@ macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing) @assert expr.head == :(=) x1, x2 = expr.args @assert x2.head == :call - f, x = x2.args + f, x... = x2.args define_expr = u0 === nothing ? :() : :($(x1) = similar($(u0))) return quote if $(esc(iip)) $(esc(define_expr)) - $(esc(f))($(esc(x1)), $(esc(x))) + $(esc(f))($(esc(x1)), $(esc.(x)...)) else $(esc(expr)) end diff --git a/src/broyden.jl b/src/broyden.jl index 74566e8..6c5c3ce 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -43,11 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; error("Broyden currently only supports out-of-place nonlinear problems") end - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES error("Broyden currently doesn't support SAFE_BEST termination modes") diff --git a/src/dfsane.jl b/src/dfsane.jl index b5e2b82..d4bc777 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -51,7 +51,7 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_ - `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the algorithm. Used exclusively in `batched` mode. Defaults to `1000`. """ -struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm +struct SimpleDFSane{T, TC} <: AbstractSimpleNonlinearSolveAlgorithm σ_min::T σ_max::T σ_1::T @@ -62,47 +62,54 @@ struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm nexp::Int Ī·_strategy::Function termination_condition::TC - max_inner_iterations::Int +end - function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, - M::Int = 10, γ::Real = 1e-4, Ļ„_min::Real = 0.1, Ļ„_max::Real = 0.5, - nexp::Int = 2, Ī·_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing), - batched::Bool = false, - max_inner_iterations = 1000) - return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min, - σ_max, - σ_1, +function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, + M::Int = 10, γ::Real = 1e-4, Ļ„_min::Real = 0.1, Ļ„_max::Real = 0.5, + nexp::Int = 2, Ī·_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), + batched::Bool = false, + max_inner_iterations = 1000) + if batched + return SimpleBatchedDFSane(; Ļƒā‚˜įµ¢ā‚™ = σ_min, + Ļƒā‚˜ā‚ā‚“ = σ_max, + Ļƒā‚ = σ_1, M, γ, - Ļ„_min, - Ļ„_max, - nexp, - Ī·_strategy, + Ļ„ā‚˜įµ¢ā‚™ = Ļ„_min, + Ļ„ā‚˜ā‚ā‚“ = Ļ„_max, + nā‚‘ā‚“ā‚š = nexp, + Ī·ā‚› = Ī·_strategy, termination_condition, max_inner_iterations) end + return SimpleDFSane{typeof(σ_min), typeof(termination_condition)}(σ_min, + σ_max, + σ_1, + M, + γ, + Ļ„_min, + Ļ„_max, + nexp, + Ī·_strategy, + termination_condition) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, +function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) where {batched} + kwargs...) tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) - if batched - batch_size = size(x, 2) - end - T = eltype(x) σ_min = float(alg.σ_min) σ_max = float(alg.σ_max) - σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1) + σ_k = float(alg.σ_1) M = alg.M γ = float(alg.γ) @@ -111,17 +118,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, nexp = alg.nexp Ī·_strategy = alg.Ī·_strategy - batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays" - if SciMLBase.isinplace(prob) error("SimpleDFSane currently only supports out-of-place nonlinear problems") end - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES error("SimpleDFSane currently doesn't support SAFE_BEST termination modes") @@ -133,22 +135,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, function ff(x) F = f(x) - f_k = if batched - sum(abs2, F; dims = 1) .^ (nexp / 2) - else - norm(F)^nexp - end + f_k = norm(F)^nexp return f_k, F end function generate_history(f_k, M) - if batched - history = similar(f_k, (M, length(f_k))) - history .= reshape(f_k, 1, :) - return history - else - return fill(f_k, M) - end + return fill(f_k, M) end f_k, F_k = ff(x) @@ -158,17 +150,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, for k in 1:maxiters # Spectral parameter range check - if batched - @. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) - else - σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) - end + σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) # Line search direction d = -σ_k .* F_k Ī· = Ī·_strategy(f_1, k, x, F_k) - fĢ„ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k) + fĢ„ = maximum(history_f_k) α_p = α_1 α_m = α_1 x_new = @. x + α_p * d @@ -179,38 +167,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, while true inner_iterations += 1 - if batched - criteria = @. fĢ„ + Ī· - γ * α_p^2 * f_k - # NOTE: This is simply a heuristic, ideally we check using `all` but that is - # typically very expensive for large problems - (sum(f_new .≤ criteria) ≄ batch_size Ć· 2) && break - else - criteria = fĢ„ + Ī· - γ * α_p^2 * f_k - f_new ≤ criteria && break - end + criteria = fĢ„ + Ī· - γ * α_p^2 * f_k + f_new ≤ criteria && break α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k) x_new = @. x - α_m * d f_new, F_new = ff(x_new) - if batched - # NOTE: This is simply a heuristic, ideally we check using `all` but that is - # typically very expensive for large problems - (sum(f_new .≤ criteria) ≄ batch_size Ć· 2) && break - else - f_new ≤ criteria && break - end + f_new ≤ criteria && break α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k) α_p = @. clamp(α_tp, Ļ„_min * α_p, Ļ„_max * α_p) α_m = @. clamp(α_tm, Ļ„_min * α_m, Ļ„_max * α_m) x_new = @. x + α_p * d f_new, F_new = ff(x_new) - - # NOTE: The original algorithm runs till either condition is satisfied, however, - # for most batched problems like neural networks we only care about - # approximate convergence - batched && (inner_iterations ≄ alg.max_inner_iterations) && break end if termination_condition(F_new, x_new, x, atol, rtol) @@ -225,11 +195,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, s_k = @. x_new - x y_k = @. F_new - F_k - if batched - σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5)) - else - σ_k = (s_k' * s_k) / (s_k' * y_k) - end + σ_k = (s_k' * s_k) / (s_k' * y_k) # Take step x = x_new @@ -237,11 +203,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, f_k = f_new # Store function value - if batched - history_f_k[k % M + 1, :] .= vec(f_new) - else - history_f_k[k % M + 1] = f_new - end + history_f_k[k % M + 1] = f_new end return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters) end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..280bf2a --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,12 @@ +[deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/basictests.jl b/test/basictests.jl index dd54527..4fe316b 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,9 +1,5 @@ -using SimpleNonlinearSolve -using StaticArrays -using BenchmarkTools -using DiffEqBase -using LinearAlgebra -using Test +using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, + NNlib, AbstractDifferentiation, LinearSolve const BATCHED_BROYDEN_SOLVERS = [] const BROYDEN_SOLVERS = [] @@ -476,9 +472,6 @@ for options in list_of_options @test all(abs.(f(u, p)) .< 1e-10) end -# Batched Broyden -using NNlib - f, u0 = (u, p) -> u .* u .- p, randn(1, 3) p = [2.0 1.0 5.0]; @@ -488,7 +481,7 @@ sol = solve(probN, Broyden(batched = true)) @test abs.(sol.u) ā‰ˆ sqrt.(p) -for alg in (BATCHED_BROYDEN_SOLVERS..., +@testset "Batched Solver: $(nameof(typeof(alg)))" for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS..., BATCHED_DFSANE_SOLVERS...) sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3) From 947a09b9b95895380c19fe326160fd05fd4b3137 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 17:01:35 -0400 Subject: [PATCH 05/12] Remove docstrings to ensure users don't directly construct the batched solvers --- src/batched/broyden.jl | 14 -------------- src/batched/raphson.jl | 16 ---------------- 2 files changed, 30 deletions(-) diff --git a/src/batched/broyden.jl b/src/batched/broyden.jl index a301c4e..8575488 100644 --- a/src/batched/broyden.jl +++ b/src/batched/broyden.jl @@ -1,17 +1,3 @@ -""" - BatchedBroyden(; - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - -A low-overhead batched implementation of Broyden capable of solving multiple nonlinear -problems simultaneously. - -!!! note - - To use this version, remember to load `NNlib`, i.e., `using NNlib` or - `import NNlib` must be present in your code. -""" struct BatchedBroyden{TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm termination_condition::TC diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl index 0a28e33..2be565f 100644 --- a/src/batched/raphson.jl +++ b/src/batched/raphson.jl @@ -1,19 +1,3 @@ -""" - SimpleBatchedNewtonRaphson(; chunk_size = Val{0}(), - autodiff = Val{true}(), - diff_type = Val{:forward}, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - -A low-overhead batched implementation of Newton-Raphson capable of solving multiple -nonlinear problems simultaneously. - -!!! note - - To use the `batched` version, remember to load `AbstractDifferentiation` and - `LinearSolve`. -""" struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm autodiff::AD From e90573724d401d51eb2eebd08e5210d0199e17e4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 27 Jun 2023 18:38:35 -0400 Subject: [PATCH 06/12] LBroyden --- ext/SimpleNonlinearSolveADLinearSolveExt.jl | 25 +-- ext/SimpleNonlinearSolveNNlibExt.jl | 122 +++++++++++++-- src/batched/dfsane.jl | 2 +- src/batched/lbroyden.jl | 7 + src/broyden.jl | 10 +- src/lbroyden.jl | 163 +++++++++----------- 6 files changed, 210 insertions(+), 119 deletions(-) diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl index 96ba916..d0a97eb 100644 --- a/ext/SimpleNonlinearSolveADLinearSolveExt.jl +++ b/ext/SimpleNonlinearSolveADLinearSolveExt.jl @@ -1,8 +1,10 @@ module SimpleNonlinearSolveADLinearSolveExt -using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, +using AbstractDifferentiation, + ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _result_from_storage, _get_tolerance, @maybeinplace const AD = AbstractDifferentiation @@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}() # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size ad = SciMLBase._unwrap_val(autodiff) ? - AD.ForwardDiffBackend(; chunksize) : - AD.FiniteDifferencesBackend() - return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}( - ad, + AD.ForwardDiffBackend(; chunksize) : + AD.FiniteDifferencesBackend() + return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad, nothing, termination_condition) end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBatchedNewtonRaphson; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) iip = isinplace(prob) @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." @@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.Success) + retcode = ReturnCode.Success) - solve(LinearProblem(š“™, vec(fā‚™); u0=vec(Ī“x)), alg.linsolve; kwargs...) + solve(LinearProblem(š“™, vec(fā‚™); u0 = vec(Ī“x)), alg.linsolve; kwargs...) xā‚™ .-= Ī“x if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) @@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) end end diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index e62e2bd..c0faefd 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -1,18 +1,20 @@ module SimpleNonlinearSolveNNlibExt using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace function __init__() SimpleNonlinearSolve.NNlibExtLoaded[] = true return end +# Broyden's method @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) iip = isinplace(prob) @@ -24,7 +26,7 @@ end storage = _get_storage(mode, u) - xā‚™, xₙ₋₁, Ī“x, Ī“f = ntuple(_ -> copy(u), 4) + xā‚™, xₙ₋₁, Ī“xā‚™, Ī“f = ntuple(_ -> copy(u), 4) T = eltype(u) atol = _get_tolerance(abstol, tc.abstol, T) @@ -41,16 +43,16 @@ end xā‚™ .= xₙ₋₁ .- š“™ā»Ā¹f @maybeinplace iip fā‚™=f(xā‚™) - Ī“x .= xā‚™ .- xₙ₋₁ + Ī“xā‚™ .= xā‚™ .- xₙ₋₁ Ī“f .= fā‚™ .- fₙ₋₁ batched_mul!(reshape(š“™ā»Ā¹f, L, 1, N), š“™ā»Ā¹, reshape(Ī“f, L, 1, N)) - Ī“xįµ€ = reshape(Ī“x, 1, L, N) + Ī“xₙᵀ = reshape(Ī“xā‚™, 1, L, N) - batched_mul!(reshape(xįµ€š“™ā»Ā¹Ī“f, 1, 1, N), Ī“xįµ€, reshape(š“™ā»Ā¹f, L, 1, N)) - batched_mul!(xįµ€š“™ā»Ā¹, Ī“xįµ€, š“™ā»Ā¹) - Ī“x .= (Ī“x .- š“™ā»Ā¹f) ./ (xįµ€š“™ā»Ā¹Ī“f .+ T(1e-5)) - batched_mul!(š“™ā»Ā¹, reshape(Ī“x, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) + batched_mul!(reshape(xįµ€š“™ā»Ā¹Ī“f, 1, 1, N), Ī“xₙᵀ, reshape(š“™ā»Ā¹f, L, 1, N)) + batched_mul!(xįµ€š“™ā»Ā¹, Ī“xₙᵀ, š“™ā»Ā¹) + Ī“xā‚™ .= (Ī“xā‚™ .- š“™ā»Ā¹f) ./ (xįµ€š“™ā»Ā¹Ī“f .+ T(1e-5)) + batched_mul!(š“™ā»Ā¹, reshape(Ī“xā‚™, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) @@ -74,7 +76,103 @@ end alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) +end + +# Limited Memory Broyden's method +@views function SciMLBase.__solve(prob::NonlinearProblem, + alg::BatchedLBroyden; + abstol = nothing, + reltol = nothing, + maxiters = 1000, + kwargs...) + iip = isinplace(prob) + + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) + T = eltype(u) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + Ī· = min(maxiters, alg.threshold) + U = fill!(similar(u, (Ī·, L, N)), zero(T)) + Vįµ€ = fill!(similar(u, (L, Ī·, N)), zero(T)) + + xā‚™, xₙ₋₁, Ī“fā‚™ = ntuple(_ -> copy(u), 3) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + @maybeinplace iip fₙ₋₁=f(xā‚™) u + iip && (fā‚™ = copy(fₙ₋₁)) + Ī“xā‚™ = -copy(fₙ₋₁) + Ī·Nx = similar(xā‚™, Ī·, N) + + for i in 1:maxiters + @. xā‚™ = xₙ₋₁ - Ī“xā‚™ + @maybeinplace iip fā‚™=f(xā‚™) + @. Ī“xā‚™ = xā‚™ - xₙ₋₁ + @. Ī“fā‚™ = fā‚™ - fₙ₋₁ + + if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode) + end + + _L = min(i, Ī·) + _U = U[1:_L, :, :] + _Vįµ€ = Vįµ€[:, 1:_L, :] + + idx = mod1(i, Ī·) + + if i > 1 + partial_Ī·Nx = Ī·Nx[1:_L, :] + + _Ī·Nx = reshape(partial_Ī·Nx, 1, :, N) + batched_mul!(_Ī·Nx, reshape(Ī“xā‚™, 1, L, N), _Vįµ€) + batched_mul!(Vįµ€[:, idx:idx, :], _Ī·Nx, _U) + Vįµ€[:, idx, :] .-= Ī“xā‚™ + + _Ī·Nx = reshape(partial_Ī·Nx, :, 1, N) + batched_mul!(_Ī·Nx, _U, reshape(Ī“fā‚™, L, 1, N)) + batched_mul!(U[idx:idx, :, :], _Vįµ€, _Ī·Nx) + U[idx, :, :] .-= Ī“fā‚™ + else + Vįµ€[:, idx, :] .= -Ī“xā‚™ + U[idx, :, :] .= -Ī“fā‚™ + end + + U[idx, :, :] .= (Ī“xā‚™ .- U[idx, :, :]) ./ + (sum(Vįµ€[:, idx, :] .* Ī“fā‚™; dims = 1) .+ + convert(T, 1e-5)) + + _L = min(i + 1, Ī·) + _Ī·Nx = reshape(Ī·Nx[1:_L, :], :, 1, N) + batched_mul!(_Ī·Nx, U[1:_L, :, :], reshape(Ī“fā‚™, L, 1, N)) + batched_mul!(reshape(Ī“xā‚™, L, 1, N), Vįµ€[:, 1:_L, :], _Ī·Nx) + + xₙ₋₁ .= xā‚™ + fₙ₋₁ .= fā‚™ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + @maybeinplace iip fā‚™=f(xā‚™) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode = ReturnCode.MaxIters) end end diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index fe7cbcd..88f02eb 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,4 +1,4 @@ -@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: +Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl index e69de29..5934c8f 100644 --- a/src/batched/lbroyden.jl +++ b/src/batched/lbroyden.jl @@ -0,0 +1,7 @@ +struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + termination_condition::TC + threshold::Int +end + +# Implementation of solve using Package Extensions \ No newline at end of file diff --git a/src/broyden.jl b/src/broyden.jl index 6c5c3ce..adf94b0 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -30,6 +30,9 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + if SciMLBase.isinplace(prob) + error("Broyden currently only supports out-of-place nonlinear problems") + end tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) @@ -39,10 +42,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; T = eltype(x) J⁻¹ = init_J(x) - if SciMLBase.isinplace(prob) - error("Broyden currently only supports out-of-place nonlinear problems") - end - atol = _get_tolerance(abstol, tc.abstol, T) rtol = _get_tolerance(reltol, tc.reltol, T) @@ -50,8 +49,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; error("Broyden currently doesn't support SAFE_BEST termination modes") end - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : - nothing + storage = _get_storage(mode, x) termination_condition = tc(storage) xā‚™ = x diff --git a/src/lbroyden.jl b/src/lbroyden.jl index fc2b51a..95ec389 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -11,134 +11,121 @@ Broyden's method. This method is not very stable and can diverge even for very simple problems. This has mostly been tested for neural networks in DeepEquilibriumNetworks.jl. + +!!! note + + To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or + `import NNlib` must be present in your code. """ struct LBroyden{batched, TC <: NLSolveTerminationCondition} <: AbstractSimpleNonlinearSolveAlgorithm termination_condition::TC threshold::Int +end - function LBroyden(; batched = false, threshold::Int = 27, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - return new{batched, typeof(termination_condition)}(termination_condition, threshold) +function LBroyden(; batched = false, threshold::Int = 27, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + if batched + @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + return BatchedLBroyden(termination_condition, threshold) end + return LBroyden{true, typeof(termination_condition)}(termination_condition, threshold) end -@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden{batched}, args...; +@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) where {batched} + kwargs...) + if SciMLBase.isinplace(prob) + error("LBroyden currently only supports out-of-place nonlinear problems") + end tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) - threshold = min(maxiters, alg.threshold) + Ī· = min(maxiters, alg.threshold) x = float(prob.u0) - batched && @assert ndims(x)==2 "Batched LBroyden only supports 2D arrays" - + # FIXME: The scalar case currently is very inefficient if x isa Number restore_scalar = true x = [x] - f = u -> prob.f(u[], prob.p) + f = u -> [prob.f(u[], prob.p)] else f = Base.Fix2(prob.f, prob.p) restore_scalar = false end - fā‚™ = f(x) + L = length(x) T = eltype(x) - if SciMLBase.isinplace(prob) - error("LBroyden currently only supports out-of-place nonlinear problems") - end - - U, Vįµ€ = _init_lbroyden_state(batched, x, threshold) + U = fill!(similar(x, (Ī·, L)), zero(T)) + Vįµ€ = fill!(similar(x, (L, Ī·)), zero(T)) - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) - - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - error("LBroyden currently doesn't support SAFE_BEST termination modes") - end - - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : - nothing + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + storage = _get_storage(mode, x) termination_condition = tc(storage) - xā‚™ = x - xₙ₋₁ = x - fₙ₋₁ = fā‚™ - update = fā‚™ + xā‚™, xₙ₋₁, Ī“fā‚™ = ntuple(_ -> copy(x), 3) + fₙ₋₁ = f(x) + Ī“xā‚™ = -copy(fₙ₋₁) + Ī·Nx = similar(xā‚™, Ī·) + for i in 1:maxiters - xā‚™ = xₙ₋₁ .+ update + @. xā‚™ = xₙ₋₁ - Ī“xā‚™ fā‚™ = f(xā‚™) - Ī”xā‚™ = xā‚™ .- xₙ₋₁ - Ī”fā‚™ = fā‚™ .- fₙ₋₁ + @. Ī“xā‚™ = xā‚™ - xₙ₋₁ + @. Ī“fā‚™ = fā‚™ - fₙ₋₁ - if termination_condition(restore_scalar ? [fā‚™] : fā‚™, xā‚™, xₙ₋₁, atol, rtol) + if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, Val(false)) xā‚™ = restore_scalar ? xā‚™[] : xā‚™ - return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.Success) + fā‚™ = restore_scalar ? fā‚™[] : fā‚™ + return DiffEqBase.build_solution(prob, alg, xā‚™, fā‚™; retcode) end - _U = selectdim(U, 1, 1:min(threshold, i)) - _Vįµ€ = selectdim(Vįµ€, 2, 1:min(threshold, i)) - - vįµ€ = _rmatvec(_U, _Vįµ€, Ī”xā‚™) - mvec = _matvec(_U, _Vįµ€, Ī”fā‚™) - u = (Ī”xā‚™ .- mvec) ./ (sum(vįµ€ .* Ī”fā‚™) .+ convert(T, 1e-5)) + _L = min(i, Ī·) + _U = U[1:_L, :] + _Vįµ€ = Vįµ€[:, 1:_L] - selectdim(Vįµ€, 2, mod1(i, threshold)) .= vįµ€ - selectdim(U, 1, mod1(i, threshold)) .= u + idx = mod1(i, Ī·) - update = -_matvec(selectdim(U, 1, 1:min(threshold, i + 1)), - selectdim(Vįµ€, 2, 1:min(threshold, i + 1)), fā‚™) + partial_Ī·Nx = Ī·Nx[1:_L] - xₙ₋₁ = xā‚™ - fₙ₋₁ = fā‚™ - end + if i > 1 + _Ī·Nx = reshape(partial_Ī·Nx, 1, :) + mul!(_Ī·Nx, reshape(Ī“xā‚™, 1, L), _Vįµ€) + mul!(Vįµ€[:, idx:idx], _Ī·Nx, _U) + Vįµ€[:, idx] .-= Ī“xā‚™ - xā‚™ = restore_scalar ? xā‚™[] : xā‚™ - return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.MaxIters) -end + _Ī·Nx = reshape(partial_Ī·Nx, :, 1) + mul!(_Ī·Nx, _U, reshape(Ī“fā‚™, L, 1)) + mul!(U[idx:idx, :], _Vįµ€, _Ī·Nx) + U[idx, :] .-= Ī“fā‚™ + else + Vįµ€[:, idx] .= -Ī“xā‚™ + U[idx, :] .= -Ī“fā‚™ + end -function _init_lbroyden_state(batched::Bool, x, threshold) - T = eltype(x) - if batched - U = fill!(similar(x, (threshold, size(x, 1), size(x, 2))), zero(T)) - Vįµ€ = fill!(similar(x, (size(x, 1), threshold, size(x, 2))), zero(T)) - else - U = fill!(similar(x, (threshold, length(x))), zero(T)) - Vįµ€ = fill!(similar(x, (length(x), threshold)), zero(T)) - end - return U, Vįµ€ -end + U[idx, :] .= (Ī“xā‚™ .- U[idx, :]) ./ + (sum(Vįµ€[:, idx] .* Ī“fā‚™) .+ + convert(T, 1e-5)) -function _rmatvec(U::AbstractMatrix, Vįµ€::AbstractMatrix, - x::Union{<:AbstractVector, <:Number}) - length(U) == 0 && return x - return -x .+ vec((x' * Vįµ€) * U) -end + _L = min(i + 1, Ī·) + _Ī·Nx = reshape(Ī·Nx[1:_L], :, 1) + mul!(_Ī·Nx, U[1:_L, :], reshape(Ī“fā‚™, L, 1)) + mul!(reshape(Ī“xā‚™, L, 1), Vįµ€[:, 1:_L], _Ī·Nx) -function _rmatvec(U::AbstractArray{T1, 3}, Vįµ€::AbstractArray{T2, 3}, - x::AbstractMatrix) where {T1, T2} - length(U) == 0 && return x - Vįµ€x = sum(Vįµ€ .* reshape(x, size(x, 1), 1, size(x, 2)); dims = 1) - return -x .+ _drdims_sum(U .* permutedims(Vįµ€x, (2, 1, 3)); dims = 1) -end + xₙ₋₁ .= xā‚™ + fₙ₋₁ .= fā‚™ + end -function _matvec(U::AbstractMatrix, Vįµ€::AbstractMatrix, - x::Union{<:AbstractVector, <:Number}) - length(U) == 0 && return x - return -x .+ vec(Vįµ€ * (U * x)) -end + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + fā‚™ = f(xā‚™) + end -function _matvec(U::AbstractArray{T1, 3}, Vįµ€::AbstractArray{T2, 3}, - x::AbstractMatrix) where {T1, T2} - length(U) == 0 && return x - xUįµ€ = sum(reshape(x, size(x, 1), 1, size(x, 2)) .* permutedims(U, (2, 1, 3)); dims = 1) - return -x .+ _drdims_sum(xUįµ€ .* Vįµ€; dims = 2) + xā‚™ = restore_scalar ? xā‚™[] : xā‚™ + fā‚™ = restore_scalar ? fā‚™[] : fā‚™ + return DiffEqBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.MaxIters) end - -_drdims_sum(args...; dims = :) = dropdims(sum(args...; dims); dims) From cb3723ff3c9e08f066183d52fc6ff9dd2190c831 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Jun 2023 11:31:09 -0400 Subject: [PATCH 07/12] Revert "LBroyden" This reverts commit e90573724d401d51eb2eebd08e5210d0199e17e4. --- ext/SimpleNonlinearSolveADLinearSolveExt.jl | 25 ++- ext/SimpleNonlinearSolveNNlibExt.jl | 122 ++------------- src/batched/dfsane.jl | 2 +- src/batched/lbroyden.jl | 7 - src/broyden.jl | 10 +- src/lbroyden.jl | 163 +++++++++++--------- 6 files changed, 119 insertions(+), 210 deletions(-) diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl index d0a97eb..96ba916 100644 --- a/ext/SimpleNonlinearSolveADLinearSolveExt.jl +++ b/ext/SimpleNonlinearSolveADLinearSolveExt.jl @@ -1,10 +1,8 @@ module SimpleNonlinearSolveADLinearSolveExt -using AbstractDifferentiation, - ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, +using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, - _get_storage, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace const AD = AbstractDifferentiation @@ -22,18 +20,19 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}() # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size ad = SciMLBase._unwrap_val(autodiff) ? - AD.ForwardDiffBackend(; chunksize) : - AD.FiniteDifferencesBackend() - return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad, + AD.ForwardDiffBackend(; chunksize) : + AD.FiniteDifferencesBackend() + return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}( + ad, nothing, termination_condition) end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBatchedNewtonRaphson; - abstol = nothing, - reltol = nothing, - maxiters = 1000, + abstol=nothing, + reltol=nothing, + maxiters=1000, kwargs...) iip = isinplace(prob) @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." @@ -58,9 +57,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode = ReturnCode.Success) + retcode=ReturnCode.Success) - solve(LinearProblem(š“™, vec(fā‚™); u0 = vec(Ī“x)), alg.linsolve; kwargs...) + solve(LinearProblem(š“™, vec(fā‚™); u0=vec(Ī“x)), alg.linsolve; kwargs...) xā‚™ .-= Ī“x if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) @@ -84,7 +83,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode = ReturnCode.MaxIters) + retcode=ReturnCode.MaxIters) end end diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index c0faefd..e62e2bd 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -1,20 +1,18 @@ module SimpleNonlinearSolveNNlibExt using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, - _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace function __init__() SimpleNonlinearSolve.NNlibExtLoaded[] = true return end -# Broyden's method @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; - abstol = nothing, - reltol = nothing, - maxiters = 1000, + abstol=nothing, + reltol=nothing, + maxiters=1000, kwargs...) iip = isinplace(prob) @@ -26,7 +24,7 @@ end storage = _get_storage(mode, u) - xā‚™, xₙ₋₁, Ī“xā‚™, Ī“f = ntuple(_ -> copy(u), 4) + xā‚™, xₙ₋₁, Ī“x, Ī“f = ntuple(_ -> copy(u), 4) T = eltype(u) atol = _get_tolerance(abstol, tc.abstol, T) @@ -43,16 +41,16 @@ end xā‚™ .= xₙ₋₁ .- š“™ā»Ā¹f @maybeinplace iip fā‚™=f(xā‚™) - Ī“xā‚™ .= xā‚™ .- xₙ₋₁ + Ī“x .= xā‚™ .- xₙ₋₁ Ī“f .= fā‚™ .- fₙ₋₁ batched_mul!(reshape(š“™ā»Ā¹f, L, 1, N), š“™ā»Ā¹, reshape(Ī“f, L, 1, N)) - Ī“xₙᵀ = reshape(Ī“xā‚™, 1, L, N) + Ī“xįµ€ = reshape(Ī“x, 1, L, N) - batched_mul!(reshape(xįµ€š“™ā»Ā¹Ī“f, 1, 1, N), Ī“xₙᵀ, reshape(š“™ā»Ā¹f, L, 1, N)) - batched_mul!(xįµ€š“™ā»Ā¹, Ī“xₙᵀ, š“™ā»Ā¹) - Ī“xā‚™ .= (Ī“xā‚™ .- š“™ā»Ā¹f) ./ (xįµ€š“™ā»Ā¹Ī“f .+ T(1e-5)) - batched_mul!(š“™ā»Ā¹, reshape(Ī“xā‚™, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) + batched_mul!(reshape(xįµ€š“™ā»Ā¹Ī“f, 1, 1, N), Ī“xįµ€, reshape(š“™ā»Ā¹f, L, 1, N)) + batched_mul!(xįµ€š“™ā»Ā¹, Ī“xįµ€, š“™ā»Ā¹) + Ī“x .= (Ī“x .- š“™ā»Ā¹f) ./ (xįµ€š“™ā»Ā¹Ī“f .+ T(1e-5)) + batched_mul!(š“™ā»Ā¹, reshape(Ī“x, L, 1, N), xįµ€š“™ā»Ā¹, one(T), one(T)) if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) @@ -76,103 +74,7 @@ end alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode = ReturnCode.MaxIters) -end - -# Limited Memory Broyden's method -@views function SciMLBase.__solve(prob::NonlinearProblem, - alg::BatchedLBroyden; - abstol = nothing, - reltol = nothing, - maxiters = 1000, - kwargs...) - iip = isinplace(prob) - - u, f, reconstruct = _construct_batched_problem_structure(prob) - L, N = size(u) - T = eltype(u) - - tc = alg.termination_condition - mode = DiffEqBase.get_termination_mode(tc) - - storage = _get_storage(mode, u) - - Ī· = min(maxiters, alg.threshold) - U = fill!(similar(u, (Ī·, L, N)), zero(T)) - Vįµ€ = fill!(similar(u, (L, Ī·, N)), zero(T)) - - xā‚™, xₙ₋₁, Ī“fā‚™ = ntuple(_ -> copy(u), 3) - - atol = _get_tolerance(abstol, tc.abstol, T) - rtol = _get_tolerance(reltol, tc.reltol, T) - termination_condition = tc(storage) - - @maybeinplace iip fₙ₋₁=f(xā‚™) u - iip && (fā‚™ = copy(fₙ₋₁)) - Ī“xā‚™ = -copy(fₙ₋₁) - Ī·Nx = similar(xā‚™, Ī·, N) - - for i in 1:maxiters - @. xā‚™ = xₙ₋₁ - Ī“xā‚™ - @maybeinplace iip fā‚™=f(xā‚™) - @. Ī“xā‚™ = xā‚™ - xₙ₋₁ - @. Ī“fā‚™ = fā‚™ - fₙ₋₁ - - if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) - retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) - return DiffEqBase.build_solution(prob, - alg, - reconstruct(xā‚™), - reconstruct(fā‚™); - retcode) - end - - _L = min(i, Ī·) - _U = U[1:_L, :, :] - _Vįµ€ = Vįµ€[:, 1:_L, :] - - idx = mod1(i, Ī·) - - if i > 1 - partial_Ī·Nx = Ī·Nx[1:_L, :] - - _Ī·Nx = reshape(partial_Ī·Nx, 1, :, N) - batched_mul!(_Ī·Nx, reshape(Ī“xā‚™, 1, L, N), _Vįµ€) - batched_mul!(Vįµ€[:, idx:idx, :], _Ī·Nx, _U) - Vįµ€[:, idx, :] .-= Ī“xā‚™ - - _Ī·Nx = reshape(partial_Ī·Nx, :, 1, N) - batched_mul!(_Ī·Nx, _U, reshape(Ī“fā‚™, L, 1, N)) - batched_mul!(U[idx:idx, :, :], _Vįµ€, _Ī·Nx) - U[idx, :, :] .-= Ī“fā‚™ - else - Vįµ€[:, idx, :] .= -Ī“xā‚™ - U[idx, :, :] .= -Ī“fā‚™ - end - - U[idx, :, :] .= (Ī“xā‚™ .- U[idx, :, :]) ./ - (sum(Vįµ€[:, idx, :] .* Ī“fā‚™; dims = 1) .+ - convert(T, 1e-5)) - - _L = min(i + 1, Ī·) - _Ī·Nx = reshape(Ī·Nx[1:_L, :], :, 1, N) - batched_mul!(_Ī·Nx, U[1:_L, :, :], reshape(Ī“fā‚™, L, 1, N)) - batched_mul!(reshape(Ī“xā‚™, L, 1, N), Vįµ€[:, 1:_L, :], _Ī·Nx) - - xₙ₋₁ .= xā‚™ - fₙ₋₁ .= fā‚™ - end - - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - xā‚™ = storage.u - @maybeinplace iip fā‚™=f(xā‚™) - end - - return DiffEqBase.build_solution(prob, - alg, - reconstruct(xā‚™), - reconstruct(fā‚™); - retcode = ReturnCode.MaxIters) + retcode=ReturnCode.MaxIters) end end diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index 88f02eb..fe7cbcd 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,4 +1,4 @@ -Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: +@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl index 5934c8f..e69de29 100644 --- a/src/batched/lbroyden.jl +++ b/src/batched/lbroyden.jl @@ -1,7 +0,0 @@ -struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <: - AbstractBatchedNonlinearSolveAlgorithm - termination_condition::TC - threshold::Int -end - -# Implementation of solve using Package Extensions \ No newline at end of file diff --git a/src/broyden.jl b/src/broyden.jl index adf94b0..6c5c3ce 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -30,9 +30,6 @@ end function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) - if SciMLBase.isinplace(prob) - error("Broyden currently only supports out-of-place nonlinear problems") - end tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) @@ -42,6 +39,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; T = eltype(x) J⁻¹ = init_J(x) + if SciMLBase.isinplace(prob) + error("Broyden currently only supports out-of-place nonlinear problems") + end + atol = _get_tolerance(abstol, tc.abstol, T) rtol = _get_tolerance(reltol, tc.reltol, T) @@ -49,7 +50,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; error("Broyden currently doesn't support SAFE_BEST termination modes") end - storage = _get_storage(mode, x) + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing termination_condition = tc(storage) xā‚™ = x diff --git a/src/lbroyden.jl b/src/lbroyden.jl index 95ec389..fc2b51a 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -11,121 +11,134 @@ Broyden's method. This method is not very stable and can diverge even for very simple problems. This has mostly been tested for neural networks in DeepEquilibriumNetworks.jl. - -!!! note - - To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or - `import NNlib` must be present in your code. """ struct LBroyden{batched, TC <: NLSolveTerminationCondition} <: AbstractSimpleNonlinearSolveAlgorithm termination_condition::TC threshold::Int -end -function LBroyden(; batched = false, threshold::Int = 27, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - if batched - @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." - return BatchedLBroyden(termination_condition, threshold) + function LBroyden(; batched = false, threshold::Int = 27, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return new{batched, typeof(termination_condition)}(termination_condition, threshold) end - return LBroyden{true, typeof(termination_condition)}(termination_condition, threshold) end -@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...; +@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden{batched}, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) - if SciMLBase.isinplace(prob) - error("LBroyden currently only supports out-of-place nonlinear problems") - end + kwargs...) where {batched} tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) - Ī· = min(maxiters, alg.threshold) + threshold = min(maxiters, alg.threshold) x = float(prob.u0) - # FIXME: The scalar case currently is very inefficient + batched && @assert ndims(x)==2 "Batched LBroyden only supports 2D arrays" + if x isa Number restore_scalar = true x = [x] - f = u -> [prob.f(u[], prob.p)] + f = u -> prob.f(u[], prob.p) else f = Base.Fix2(prob.f, prob.p) restore_scalar = false end - L = length(x) + fā‚™ = f(x) T = eltype(x) - U = fill!(similar(x, (Ī·, L)), zero(T)) - Vįµ€ = fill!(similar(x, (L, Ī·)), zero(T)) + if SciMLBase.isinplace(prob) + error("LBroyden currently only supports out-of-place nonlinear problems") + end + + U, Vįµ€ = _init_lbroyden_state(batched, x, threshold) - atol = _get_tolerance(abstol, tc.abstol, T) - rtol = _get_tolerance(reltol, tc.reltol, T) - storage = _get_storage(mode, x) - termination_condition = tc(storage) + atol = abstol !== nothing ? abstol : + (tc.abstol !== nothing ? tc.abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) + rtol = reltol !== nothing ? reltol : + (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) - xā‚™, xₙ₋₁, Ī“fā‚™ = ntuple(_ -> copy(x), 3) - fₙ₋₁ = f(x) - Ī“xā‚™ = -copy(fₙ₋₁) - Ī·Nx = similar(xā‚™, Ī·) + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + error("LBroyden currently doesn't support SAFE_BEST termination modes") + end + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + termination_condition = tc(storage) + xā‚™ = x + xₙ₋₁ = x + fₙ₋₁ = fā‚™ + update = fā‚™ for i in 1:maxiters - @. xā‚™ = xₙ₋₁ - Ī“xā‚™ + xā‚™ = xₙ₋₁ .+ update fā‚™ = f(xā‚™) - @. Ī“xā‚™ = xā‚™ - xₙ₋₁ - @. Ī“fā‚™ = fā‚™ - fₙ₋₁ + Ī”xā‚™ = xā‚™ .- xₙ₋₁ + Ī”fā‚™ = fā‚™ .- fₙ₋₁ - if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) - retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, Val(false)) + if termination_condition(restore_scalar ? [fā‚™] : fā‚™, xā‚™, xₙ₋₁, atol, rtol) xā‚™ = restore_scalar ? xā‚™[] : xā‚™ - fā‚™ = restore_scalar ? fā‚™[] : fā‚™ - return DiffEqBase.build_solution(prob, alg, xā‚™, fā‚™; retcode) + return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.Success) end - _L = min(i, Ī·) - _U = U[1:_L, :] - _Vįµ€ = Vįµ€[:, 1:_L] - - idx = mod1(i, Ī·) + _U = selectdim(U, 1, 1:min(threshold, i)) + _Vįµ€ = selectdim(Vįµ€, 2, 1:min(threshold, i)) - partial_Ī·Nx = Ī·Nx[1:_L] + vįµ€ = _rmatvec(_U, _Vįµ€, Ī”xā‚™) + mvec = _matvec(_U, _Vįµ€, Ī”fā‚™) + u = (Ī”xā‚™ .- mvec) ./ (sum(vįµ€ .* Ī”fā‚™) .+ convert(T, 1e-5)) - if i > 1 - _Ī·Nx = reshape(partial_Ī·Nx, 1, :) - mul!(_Ī·Nx, reshape(Ī“xā‚™, 1, L), _Vįµ€) - mul!(Vįµ€[:, idx:idx], _Ī·Nx, _U) - Vįµ€[:, idx] .-= Ī“xā‚™ + selectdim(Vįµ€, 2, mod1(i, threshold)) .= vįµ€ + selectdim(U, 1, mod1(i, threshold)) .= u - _Ī·Nx = reshape(partial_Ī·Nx, :, 1) - mul!(_Ī·Nx, _U, reshape(Ī“fā‚™, L, 1)) - mul!(U[idx:idx, :], _Vįµ€, _Ī·Nx) - U[idx, :] .-= Ī“fā‚™ - else - Vįµ€[:, idx] .= -Ī“xā‚™ - U[idx, :] .= -Ī“fā‚™ - end + update = -_matvec(selectdim(U, 1, 1:min(threshold, i + 1)), + selectdim(Vįµ€, 2, 1:min(threshold, i + 1)), fā‚™) - U[idx, :] .= (Ī“xā‚™ .- U[idx, :]) ./ - (sum(Vįµ€[:, idx] .* Ī“fā‚™) .+ - convert(T, 1e-5)) + xₙ₋₁ = xā‚™ + fₙ₋₁ = fā‚™ + end - _L = min(i + 1, Ī·) - _Ī·Nx = reshape(Ī·Nx[1:_L], :, 1) - mul!(_Ī·Nx, U[1:_L, :], reshape(Ī“fā‚™, L, 1)) - mul!(reshape(Ī“xā‚™, L, 1), Vįµ€[:, 1:_L], _Ī·Nx) + xā‚™ = restore_scalar ? xā‚™[] : xā‚™ + return SciMLBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.MaxIters) +end - xₙ₋₁ .= xā‚™ - fₙ₋₁ .= fā‚™ +function _init_lbroyden_state(batched::Bool, x, threshold) + T = eltype(x) + if batched + U = fill!(similar(x, (threshold, size(x, 1), size(x, 2))), zero(T)) + Vįµ€ = fill!(similar(x, (size(x, 1), threshold, size(x, 2))), zero(T)) + else + U = fill!(similar(x, (threshold, length(x))), zero(T)) + Vįµ€ = fill!(similar(x, (length(x), threshold)), zero(T)) end + return U, Vįµ€ +end - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - xā‚™ = storage.u - fā‚™ = f(xā‚™) - end +function _rmatvec(U::AbstractMatrix, Vįµ€::AbstractMatrix, + x::Union{<:AbstractVector, <:Number}) + length(U) == 0 && return x + return -x .+ vec((x' * Vįµ€) * U) +end - xā‚™ = restore_scalar ? xā‚™[] : xā‚™ - fā‚™ = restore_scalar ? fā‚™[] : fā‚™ - return DiffEqBase.build_solution(prob, alg, xā‚™, fā‚™; retcode = ReturnCode.MaxIters) +function _rmatvec(U::AbstractArray{T1, 3}, Vįµ€::AbstractArray{T2, 3}, + x::AbstractMatrix) where {T1, T2} + length(U) == 0 && return x + Vįµ€x = sum(Vįµ€ .* reshape(x, size(x, 1), 1, size(x, 2)); dims = 1) + return -x .+ _drdims_sum(U .* permutedims(Vįµ€x, (2, 1, 3)); dims = 1) end + +function _matvec(U::AbstractMatrix, Vįµ€::AbstractMatrix, + x::Union{<:AbstractVector, <:Number}) + length(U) == 0 && return x + return -x .+ vec(Vįµ€ * (U * x)) +end + +function _matvec(U::AbstractArray{T1, 3}, Vįµ€::AbstractArray{T2, 3}, + x::AbstractMatrix) where {T1, T2} + length(U) == 0 && return x + xUįµ€ = sum(reshape(x, size(x, 1), 1, size(x, 2)) .* permutedims(U, (2, 1, 3)); dims = 1) + return -x .+ _drdims_sum(xUįµ€ .* Vįµ€; dims = 2) +end + +_drdims_sum(args...; dims = :) = dropdims(sum(args...; dims); dims) From 908a98e4c031804f01e921fc84e252f5e54e0047 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Jun 2023 11:31:24 -0400 Subject: [PATCH 08/12] Format --- ext/SimpleNonlinearSolveADLinearSolveExt.jl | 25 +++++++++++---------- ext/SimpleNonlinearSolveNNlibExt.jl | 11 ++++----- src/SimpleNonlinearSolve.jl | 3 ++- src/batched/broyden.jl | 2 +- src/batched/dfsane.jl | 24 ++++++++++---------- src/batched/lbroyden.jl | 1 + src/batched/raphson.jl | 2 +- src/batched/utils.jl | 4 ++-- src/raphson.jl | 10 ++++----- 9 files changed, 43 insertions(+), 39 deletions(-) diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl index 96ba916..d0a97eb 100644 --- a/ext/SimpleNonlinearSolveADLinearSolveExt.jl +++ b/ext/SimpleNonlinearSolveADLinearSolveExt.jl @@ -1,8 +1,10 @@ module SimpleNonlinearSolveADLinearSolveExt -using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, +using AbstractDifferentiation, + ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _result_from_storage, _get_tolerance, @maybeinplace const AD = AbstractDifferentiation @@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}() # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size ad = SciMLBase._unwrap_val(autodiff) ? - AD.ForwardDiffBackend(; chunksize) : - AD.FiniteDifferencesBackend() - return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}( - ad, + AD.ForwardDiffBackend(; chunksize) : + AD.FiniteDifferencesBackend() + return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad, nothing, termination_condition) end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBatchedNewtonRaphson; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) iip = isinplace(prob) @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." @@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.Success) + retcode = ReturnCode.Success) - solve(LinearProblem(š“™, vec(fā‚™); u0=vec(Ī“x)), alg.linsolve; kwargs...) + solve(LinearProblem(š“™, vec(fā‚™); u0 = vec(Ī“x)), alg.linsolve; kwargs...) xā‚™ .-= Ī“x if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) @@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) end end diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index e62e2bd..5b06530 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -1,7 +1,8 @@ module SimpleNonlinearSolveNNlibExt using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _init_š“™, _result_from_storage, _get_tolerance, @maybeinplace function __init__() SimpleNonlinearSolve.NNlibExtLoaded[] = true @@ -10,9 +11,9 @@ end @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; - abstol=nothing, - reltol=nothing, - maxiters=1000, + abstol = nothing, + reltol = nothing, + maxiters = 1000, kwargs...) iip = isinplace(prob) @@ -74,7 +75,7 @@ end alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) end end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 49aef9f..23b332f 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -22,7 +22,8 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonline abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end -abstract type AbstractBatchedNonlinearSolveAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end +abstract type AbstractBatchedNonlinearSolveAlgorithm <: + AbstractSimpleNonlinearSolveAlgorithm end include("utils.jl") include("bisection.jl") diff --git a/src/batched/broyden.jl b/src/batched/broyden.jl index 8575488..ed3cd5d 100644 --- a/src/batched/broyden.jl +++ b/src/batched/broyden.jl @@ -1,5 +1,5 @@ struct BatchedBroyden{TC <: NLSolveTerminationCondition} <: - AbstractBatchedNonlinearSolveAlgorithm + AbstractBatchedNonlinearSolveAlgorithm termination_condition::TC end diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index fe7cbcd..09fc37f 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,4 +1,4 @@ -@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: +Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 @@ -10,17 +10,17 @@ nā‚‘ā‚“ā‚š::Int = 2 Ī·ā‚›::F = (fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚, n, xā‚™, fā‚™) -> fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚ ./ n .^ 2 termination_condition::TC = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol=nothing, - reltol=nothing) + abstol = nothing, + reltol = nothing) max_inner_iterations::Int = 1000 end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBatchedDFSane, args...; - abstol=nothing, - reltol=nothing, - maxiters=100, + abstol = nothing, + reltol = nothing, + maxiters = 100, kwargs...) iip = isinplace(prob) @@ -60,7 +60,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, return fā‚“ end - @maybeinplace iip fₙ₋₁ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, xā‚™) xā‚™ + @maybeinplace iip fₙ₋₁=ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, xā‚™) xā‚™ iip && (fā‚™ = similar(fₙ₋₁)) ā„‹ = repeat(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, M, 1) fĢ„ = similar(ā„‹, 1, N) @@ -79,7 +79,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, fill!(α₋, α₁) @. xā‚™ = xₙ₋₁ + Ī±ā‚Š * š’¹ - @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + @maybeinplace iip fā‚™=ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) for _ in 1:(alg.max_inner_iterations) š’ø = @. fĢ„ + Ī· - γ * Ī±ā‚Š^2 * fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚ @@ -90,7 +90,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, Ļ„ā‚˜įµ¢ā‚™ * Ī±ā‚Š, Ļ„ā‚˜ā‚ā‚“ * Ī±ā‚Š) @. xā‚™ = xₙ₋₁ - α₋ * š’¹ - @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + @maybeinplace iip fā‚™=ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) (sum(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ .≤ š’ø) ≄ N Ć· 2) && break @@ -98,7 +98,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, Ļ„ā‚˜įµ¢ā‚™ * α₋, Ļ„ā‚˜ā‚ā‚“ * α₋) @. xā‚™ = xₙ₋₁ + Ī±ā‚Š * š’¹ - @maybeinplace iip fā‚™ = ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) + @maybeinplace iip fā‚™=ff!(fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™, xā‚™) end if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) @@ -129,12 +129,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES xā‚™ = storage.u - @maybeinplace iip fā‚™ = f(xā‚™) + @maybeinplace iip fā‚™=f(xā‚™) end return DiffEqBase.build_solution(prob, alg, reconstruct(xā‚™), reconstruct(fā‚™); - retcode=ReturnCode.MaxIters) + retcode = ReturnCode.MaxIters) end diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl index e69de29..8b13789 100644 --- a/src/batched/lbroyden.jl +++ b/src/batched/lbroyden.jl @@ -0,0 +1 @@ + diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl index 2be565f..63c03ec 100644 --- a/src/batched/raphson.jl +++ b/src/batched/raphson.jl @@ -1,5 +1,5 @@ struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <: - AbstractBatchedNonlinearSolveAlgorithm + AbstractBatchedNonlinearSolveAlgorithm autodiff::AD linsolve::LS termination_condition::TC diff --git a/src/batched/utils.jl b/src/batched/utils.jl index b5dbd59..7b85011 100644 --- a/src/batched/utils.jl +++ b/src/batched/utils.jl @@ -1,4 +1,4 @@ -macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing) +macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing} = nothing) @assert expr.head == :(=) x1, x2 = expr.args @assert x2.head == :call @@ -64,7 +64,7 @@ function _result_from_storage(storage::NLSolveSafeTerminationResult, xā‚™, fā‚™, return ReturnCode.Success, xā‚™, fā‚™ else if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - @maybeinplace iip fā‚™ = f(xā‚™) + @maybeinplace iip fā‚™=f(xā‚™) return ReturnCode.Terminated, storage.u, fā‚™ else return ReturnCode.Terminated, xā‚™, fā‚™ diff --git a/src/raphson.jl b/src/raphson.jl index 59eb1ed..2621c58 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -35,7 +35,7 @@ and static array problems. """ struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end -function SimpleNewtonRaphson(; batched=false, +function SimpleNewtonRaphson(; batched = false, chunk_size = Val{0}(), autodiff = Val{true}(), diff_type = Val{:forward}, @@ -46,10 +46,10 @@ function SimpleNewtonRaphson(; batched=false, if batched @assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." termination_condition = ismissing(termination_condition) ? - NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing) : - termination_condition + NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing) : + termination_condition return SimpleBatchedNewtonRaphson(; chunk_size, autodiff, diff_type, From aedccfb4bd54d62a2e245e91c685adc2affe3cf5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 15:15:54 -0400 Subject: [PATCH 09/12] Add Inplace tests --- Project.toml | 2 +- src/SimpleNonlinearSolve.jl | 4 +-- src/batched/dfsane.jl | 2 +- src/batched/lbroyden.jl | 1 - test/inplace.jl | 52 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 ++--- 6 files changed, 59 insertions(+), 9 deletions(-) delete mode 100644 src/batched/lbroyden.jl create mode 100644 test/inplace.jl diff --git a/Project.toml b/Project.toml index f1f472d..258f009 100644 --- a/Project.toml +++ b/Project.toml @@ -21,8 +21,8 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] -SimpleNonlinearSolveNNlibExt = "NNlib" SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"] +SimpleNonlinearSolveNNlibExt = "NNlib" [compat] AbstractDifferentiation = "0.5" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 23b332f..cd48556 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -45,7 +45,6 @@ include("batched/utils.jl") include("batched/raphson.jl") include("batched/dfsane.jl") include("batched/broyden.jl") -include("batched/lbroyden.jl") import PrecompileTools @@ -79,7 +78,6 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld -export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane, - BatchedLBroyden +export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane end # module diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index 09fc37f..a394517 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,5 +1,5 @@ Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: - AbstractBatchedNonlinearSolveAlgorithm + AbstractBatchedNonlinearSolveAlgorithm Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 Ļƒā‚::T = 1.0f0 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl deleted file mode 100644 index 8b13789..0000000 --- a/src/batched/lbroyden.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/inplace.jl b/test/inplace.jl new file mode 100644 index 0000000..4c43a1d --- /dev/null +++ b/test/inplace.jl @@ -0,0 +1,52 @@ +using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, + NNlib, AbstractDifferentiation, LinearSolve + +# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane +function f!(du::AbstractArray{<:Number, N}, + u::AbstractArray{<:Number, N}, + p::AbstractVector) where {N} + u_ = reshape(u, :, size(u, N)) + du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :), + ntuple(_ -> 1, N - 1)..., + size(u, N)) + return du +end + +function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector) + du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :) + return du +end + +function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector) + du .= sum(abs2, u) .- p + return du +end + +@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true), + SimpleDFSane(batched = true)) + @testset "T: $T" for T in (Float32, Float64) + p = rand(T, 5) + @testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.residā‰ˆzero(sol.resid) atol=5e-3 + end + + p = rand(T, 1) + @testset "size(u0): $sz" for sz in ((3,), (5,), (10,)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.residā‰ˆzero(sol.resid) atol=5e-3 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 94a0086..bea57ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,15 @@ -using Pkg using SafeTestsets -const LONGER_TESTS = false const GROUP = get(ENV, "GROUP", "All") -const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time begin if GROUP == "All" || GROUP == "Core" @time @safetestset "Basic Tests + Some AD" begin include("basictests.jl") end + + @time @safetestset "Inplace Tests" begin + include("inplace.jl") + end end end From 534e1dbf1da01bca27472756430e5c760b829332 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 15:29:46 -0400 Subject: [PATCH 10/12] Allow v1.6 to work --- src/batched/dfsane.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index a394517..a76c768 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -42,7 +42,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, Ī±ā‚Š, α₋ = similar(u, 1, N), similar(u, 1, N) Ļƒā‚™ = fill(T(alg.Ļƒā‚), 1, N) š’¹ = similar(Ļƒā‚™, L, N) - (; M, nā‚‘ā‚“ā‚š) = alg + M = alg.M + nā‚‘ā‚“ā‚š = alg.nā‚‘ā‚“ā‚š xā‚™, xₙ₋₁, fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ā‚‹ā‚, fā‚ā‚™ā‚’įµ£ā‚˜ā‚Žā‚™ = copy(u), copy(u), similar(u, 1, N), similar(u, 1, N) From 9cb861a1b8c45909e9582b843d423a10d0a23405 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 10 Jul 2023 12:47:02 -0400 Subject: [PATCH 11/12] Update batched Raphson to use same parameters as unbatched --- Project.toml | 7 -- ext/SimpleNonlinearSolveADLinearSolveExt.jl | 90 --------------------- src/SimpleNonlinearSolve.jl | 3 +- src/batched/dfsane.jl | 4 +- src/batched/raphson.jl | 77 +++++++++++++++++- src/dfsane.jl | 2 +- src/raphson.jl | 17 ++-- test/Project.toml | 2 - test/basictests.jl | 12 ++- test/inplace.jl | 4 +- 10 files changed, 98 insertions(+), 120 deletions(-) delete mode 100644 ext/SimpleNonlinearSolveADLinearSolveExt.jl diff --git a/Project.toml b/Project.toml index 258f009..89848eb 100644 --- a/Project.toml +++ b/Project.toml @@ -16,21 +16,16 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] -SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"] SimpleNonlinearSolveNNlibExt = "NNlib" [compat] -AbstractDifferentiation = "0.5" ArrayInterface = "6, 7" DiffEqBase = "6.126" FiniteDiff = "2" ForwardDiff = "0.10.3" -LinearSolve = "2" NNlib = "0.8, 0.9" PackageExtensionCompat = "1" PrecompileTools = "1" @@ -40,6 +35,4 @@ StaticArraysCore = "1.4" julia = "1.6" [extras] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/ext/SimpleNonlinearSolveADLinearSolveExt.jl b/ext/SimpleNonlinearSolveADLinearSolveExt.jl deleted file mode 100644 index d0a97eb..0000000 --- a/ext/SimpleNonlinearSolveADLinearSolveExt.jl +++ /dev/null @@ -1,90 +0,0 @@ -module SimpleNonlinearSolveADLinearSolveExt - -using AbstractDifferentiation, - ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve, - SimpleNonlinearSolve, SciMLBase -import SimpleNonlinearSolve: _construct_batched_problem_structure, - _get_storage, _result_from_storage, _get_tolerance, @maybeinplace - -const AD = AbstractDifferentiation - -function __init__() - SimpleNonlinearSolve.ADLinearSolveExtLoaded[] = true - return -end - -function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}(), - autodiff = Val{true}(), - diff_type = Val{:forward}, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - # TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl - chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size - ad = SciMLBase._unwrap_val(autodiff) ? - AD.ForwardDiffBackend(; chunksize) : - AD.FiniteDifferencesBackend() - return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad, - nothing, - termination_condition) -end - -function SciMLBase.__solve(prob::NonlinearProblem, - alg::SimpleBatchedNewtonRaphson; - abstol = nothing, - reltol = nothing, - maxiters = 1000, - kwargs...) - iip = isinplace(prob) - @assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems." - u, f, reconstruct = _construct_batched_problem_structure(prob) - - tc = alg.termination_condition - mode = DiffEqBase.get_termination_mode(tc) - - storage = _get_storage(mode, u) - - xā‚™, xₙ₋₁, Ī“x = copy(u), copy(u), copy(u) - T = eltype(u) - - atol = _get_tolerance(abstol, tc.abstol, T) - rtol = _get_tolerance(reltol, tc.reltol, T) - termination_condition = tc(storage) - - for i in 1:maxiters - fā‚™, (š“™,) = AD.value_and_jacobian(alg.autodiff, f, xā‚™) - - iszero(fā‚™) && return DiffEqBase.build_solution(prob, - alg, - reconstruct(xā‚™), - reconstruct(fā‚™); - retcode = ReturnCode.Success) - - solve(LinearProblem(š“™, vec(fā‚™); u0 = vec(Ī“x)), alg.linsolve; kwargs...) - xā‚™ .-= Ī“x - - if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) - retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) - return DiffEqBase.build_solution(prob, - alg, - reconstruct(xā‚™), - reconstruct(fā‚™); - retcode) - end - - xₙ₋₁ .= xā‚™ - end - - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - xā‚™ = storage.u - fā‚™ = f(xā‚™) - end - - return DiffEqBase.build_solution(prob, - alg, - reconstruct(xā‚™), - reconstruct(fā‚™); - retcode = ReturnCode.MaxIters) -end - -end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index cd48556..bc57d12 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -15,7 +15,6 @@ function __init__() @require_extensions end -const ADLinearSolveExtLoaded = Ref{Bool}(false) const NNlibExtLoaded = Ref{Bool}(false) abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end @@ -78,6 +77,6 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld -export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane +export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane end # module diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index a76c768..60bb6ae 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,4 +1,4 @@ -Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: +Base.@kwdef struct BatchedSimpleDFSane{T, F, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm Ļƒā‚˜įµ¢ā‚™::T = 1.0f-10 Ļƒā‚˜ā‚ā‚“::T = 1.0f+10 @@ -16,7 +16,7 @@ Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} end function SciMLBase.__solve(prob::NonlinearProblem, - alg::SimpleBatchedDFSane, + alg::BatchedSimpleDFSane, args...; abstol = nothing, reltol = nothing, diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl index 63c03ec..323c07e 100644 --- a/src/batched/raphson.jl +++ b/src/batched/raphson.jl @@ -1,8 +1,77 @@ -struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <: +struct BatchedSimpleNewtonRaphson{CS, AD, FDT, TC <: NLSolveTerminationCondition} <: AbstractBatchedNonlinearSolveAlgorithm - autodiff::AD - linsolve::LS termination_condition::TC end -# Implementation of solve using Package Extensions +alg_autodiff(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = AD +diff_type(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = FDT + +function BatchedSimpleNewtonRaphson(; chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return BatchedSimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), + SciMLBase._unwrap_val(diff_type), typeof(termination_condition)}(termination_condition) +end + +function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson; + abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + iip = SciMLBase.isinplace(prob) + @assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems." + u, f, reconstruct = _construct_batched_problem_structure(prob) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + xā‚™, xₙ₋₁ = copy(u), copy(u) + T = eltype(u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + for i in 1:maxiters + if alg_autodiff(alg) + fā‚™, š“™ = value_derivative(f, xā‚™) + else + fā‚™ = f(xā‚™) + š“™ = FiniteDiff.finite_difference_jacobian(f, xā‚™, diff_type(alg), eltype(xā‚™), fā‚™) + end + + iszero(fā‚™) && return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode = ReturnCode.Success) + + Ī“x = reshape(š“™ \ vec(fā‚™), size(xā‚™)) + xā‚™ .-= Ī“x + + if termination_condition(fā‚™, xā‚™, xₙ₋₁, atol, rtol) + retcode, xā‚™, fā‚™ = _result_from_storage(storage, xā‚™, fā‚™, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode) + end + + xₙ₋₁ .= xā‚™ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xā‚™ = storage.u + fā‚™ = f(xā‚™) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xā‚™), + reconstruct(fā‚™); + retcode = ReturnCode.MaxIters) +end diff --git a/src/dfsane.jl b/src/dfsane.jl index d4bc777..2e52cde 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -73,7 +73,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = batched::Bool = false, max_inner_iterations = 1000) if batched - return SimpleBatchedDFSane(; Ļƒā‚˜įµ¢ā‚™ = σ_min, + return BatchedSimpleDFSane(; Ļƒā‚˜įµ¢ā‚™ = σ_min, Ļƒā‚˜ā‚ā‚“ = σ_max, Ļƒā‚ = σ_1, M, diff --git a/src/raphson.jl b/src/raphson.jl index 2621c58..48b8f75 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -2,7 +2,8 @@ SimpleNewtonRaphson(; batched = false, chunk_size = Val{0}(), autodiff = Val{true}(), - diff_type = Val{:forward}) + diff_type = Val{:forward}, + termination_condition = missing) A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar and static array problems. @@ -27,11 +28,8 @@ and static array problems. - `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to `Val{:forward}` for forward finite differences. For more details on the choices, see the [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation. - -!!! note - - To use the `batched` version, remember to load `AbstractDifferentiation` and - `LinearSolve`. +- `termination_condition`: control the termination of the algorithm. (Only works for batched + problems) """ struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end @@ -44,16 +42,19 @@ function SimpleNewtonRaphson(; batched = false, throw(ArgumentError("`termination_condition` is currently only supported for batched problems")) end if batched - @assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." + # @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." termination_condition = ismissing(termination_condition) ? NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; abstol = nothing, reltol = nothing) : termination_condition - return SimpleBatchedNewtonRaphson(; chunk_size, + return BatchedSimpleNewtonRaphson(; chunk_size, autodiff, diff_type, termination_condition) + return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), + SciMLBase._unwrap_val(diff_type)}() end return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff), diff --git a/test/Project.toml b/test/Project.toml index 280bf2a..469f302 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,8 @@ [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/test/basictests.jl b/test/basictests.jl index 4fe316b..4b47c6a 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,5 +1,5 @@ using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, - NNlib, AbstractDifferentiation, LinearSolve + NNlib const BATCHED_BROYDEN_SOLVERS = [] const BROYDEN_SOLVERS = [] @@ -7,6 +7,7 @@ const BATCHED_LBROYDEN_SOLVERS = [] const LBROYDEN_SOLVERS = [] const BATCHED_DFSANE_SOLVERS = [] const DFSANE_SOLVERS = [] +const BATCHED_RAPHSON_SOLVERS = [] for mode in instances(NLSolveTerminationMode.T) if mode ∈ @@ -23,6 +24,12 @@ for mode in instances(NLSolveTerminationMode.T) push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition)) push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition)) push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition)) + push!(BATCHED_RAPHSON_SOLVERS, + SimpleNewtonRaphson(; batched = true, + termination_condition)) + push!(BATCHED_RAPHSON_SOLVERS, + SimpleNewtonRaphson(; batched = true, autodiff = false, + termination_condition)) end # SimpleNewtonRaphson @@ -483,7 +490,8 @@ sol = solve(probN, Broyden(batched = true)) @testset "Batched Solver: $(nameof(typeof(alg)))" for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS..., - BATCHED_DFSANE_SOLVERS...) + BATCHED_DFSANE_SOLVERS..., + BATCHED_RAPHSON_SOLVERS...) sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3) @test sol.retcode == ReturnCode.Success diff --git a/test/inplace.jl b/test/inplace.jl index 4c43a1d..0f2d747 100644 --- a/test/inplace.jl +++ b/test/inplace.jl @@ -1,7 +1,7 @@ using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, - NNlib, AbstractDifferentiation, LinearSolve + NNlib -# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane +# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane function f!(du::AbstractArray{<:Number, N}, u::AbstractArray{<:Number, N}, p::AbstractVector) where {N} From 87c9c0d668cd2a1c44aeb4d9e13821d20e98aa76 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 10 Jul 2023 12:49:53 -0400 Subject: [PATCH 12/12] formatting --- test/basictests.jl | 3 ++- test/inplace.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/basictests.jl b/test/basictests.jl index 4b47c6a..d5f85cd 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,4 +1,5 @@ -using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, +using SimpleNonlinearSolve, + StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, NNlib const BATCHED_BROYDEN_SOLVERS = [] diff --git a/test/inplace.jl b/test/inplace.jl index 0f2d747..886e820 100644 --- a/test/inplace.jl +++ b/test/inplace.jl @@ -1,4 +1,5 @@ -using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, +using SimpleNonlinearSolve, + StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, NNlib # Supported Solvers: BatchedBroyden, BatchedSimpleDFSane