diff --git a/Project.toml b/Project.toml index 28d6e75..89848eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.16" +version = "0.1.17" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -9,38 +9,30 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] -SimpleBatchedNonlinearSolveExt = "NNlib" +SimpleNonlinearSolveNNlibExt = "NNlib" [compat] ArrayInterface = "6, 7" -DiffEqBase = "6.123.0" +DiffEqBase = "6.126" FiniteDiff = "2" ForwardDiff = "0.10.3" NNlib = "0.8, 0.9" +PackageExtensionCompat = "1" +PrecompileTools = "1" Reexport = "0.2, 1" -Requires = "1" SciMLBase = "1.73" -PrecompileTools = "1" StaticArraysCore = "1.4" julia = "1.6" [extras] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"] diff --git a/ext/SimpleBatchedNonlinearSolveExt.jl b/ext/SimpleBatchedNonlinearSolveExt.jl deleted file mode 100644 index dd76288..0000000 --- a/ext/SimpleBatchedNonlinearSolveExt.jl +++ /dev/null @@ -1,90 +0,0 @@ -module SimpleBatchedNonlinearSolveExt - -using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase - -isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib) - -_batch_transpose(x) = reshape(x, 1, size(x)...) - -_batched_mul(x, y) = x * y - -function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T} - return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2) -end - -function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T} - return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y) -end - -function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2} - return batched_mul(x, y) -end - -function _init_J_batched(x::AbstractMatrix{T}) where {T} - J = ArrayInterface.zeromatrix(x[:, 1]) - if ismutable(x) - J[diagind(J)] .= one(eltype(x)) - else - J += I - end - return repeat(J, 1, 1, size(x, 2)) -end - -function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...; - abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) - tc = alg.termination_condition - mode = DiffEqBase.get_termination_mode(tc) - f = Base.Fix2(prob.f, prob.p) - x = float(prob.u0) - - if ndims(x) != 2 - error("`batch` mode works only if `ndims(prob.u0) == 2`") - end - - fₙ = f(x) - T = eltype(x) - J⁻¹ = _init_J_batched(x) - - if SciMLBase.isinplace(prob) - error("Broyden currently only supports out-of-place nonlinear problems") - end - - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) - - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES - error("Broyden currently doesn't support SAFE_BEST termination modes") - end - - storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : - nothing - termination_condition = tc(storage) - - xₙ = x - xₙ₋₁ = x - fₙ₋₁ = fₙ - for i in 1:maxiters - xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁) - fₙ = f(xₙ) - Δxₙ = xₙ .- xₙ₋₁ - Δfₙ = fₙ .- fₙ₋₁ - J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ) - J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./ - (_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))), - _batched_mul(_batch_transpose(Δxₙ), J⁻¹)) - - if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) - end - - xₙ₋₁ = xₙ - fₙ₋₁ = fₙ - end - - return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters) -end - -end diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl new file mode 100644 index 0000000..5b06530 --- /dev/null +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -0,0 +1,81 @@ +module SimpleNonlinearSolveNNlibExt + +using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase +import SimpleNonlinearSolve: _construct_batched_problem_structure, + _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace + +function __init__() + SimpleNonlinearSolve.NNlibExtLoaded[] = true + return +end + +@views function SciMLBase.__solve(prob::NonlinearProblem, + alg::BatchedBroyden; + abstol = nothing, + reltol = nothing, + maxiters = 1000, + kwargs...) + iip = isinplace(prob) + + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4) + T = eltype(u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + 𝓙⁻¹ = _init_𝓙(xₙ) # L × L × N + 𝓙⁻¹f, xᵀ𝓙⁻¹δf, xᵀ𝓙⁻¹ = similar(𝓙⁻¹, L, N), similar(𝓙⁻¹, 1, N), similar(𝓙⁻¹, 1, L, N) + + @maybeinplace iip fₙ₋₁=f(xₙ) u + iip && (fₙ = copy(fₙ₋₁)) + for n in 1:maxiters + batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(fₙ₋₁, L, 1, N)) + xₙ .= xₙ₋₁ .- 𝓙⁻¹f + + @maybeinplace iip fₙ=f(xₙ) + δx .= xₙ .- xₙ₋₁ + δf .= fₙ .- fₙ₋₁ + + batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N)) + δxᵀ = reshape(δx, 1, L, N) + + batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N)) + batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹) + δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5)) + batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T)) + + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode) + end + + xₙ₋₁ .= xₙ + fₙ₋₁ .= fₙ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xₙ = storage.u + @maybeinplace iip fₙ=f(xₙ) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode = ReturnCode.MaxIters) +end + +end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 8749aa7..bc57d12 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -10,22 +10,19 @@ using DiffEqBase @reexport using SciMLBase -if !isdefined(Base, :get_extension) - using Requires -end - +using PackageExtensionCompat function __init__() - @static if !isdefined(Base, :get_extension) - @require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin - include("../ext/SimpleBatchedNonlinearSolveExt.jl") - end - end + @require_extensions end +const NNlibExtLoaded = Ref{Bool}(false) + abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end +abstract type AbstractBatchedNonlinearSolveAlgorithm <: + AbstractSimpleNonlinearSolveAlgorithm end include("utils.jl") include("bisection.jl") @@ -42,6 +39,12 @@ include("ad.jl") include("halley.jl") include("alefeld.jl") +# Batched Solver Support +include("batched/utils.jl") +include("batched/raphson.jl") +include("batched/dfsane.jl") +include("batched/broyden.jl") + import PrecompileTools PrecompileTools.@compile_workload begin @@ -74,5 +77,6 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld +export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane end # module diff --git a/src/batched/broyden.jl b/src/batched/broyden.jl new file mode 100644 index 0000000..ed3cd5d --- /dev/null +++ b/src/batched/broyden.jl @@ -0,0 +1,6 @@ +struct BatchedBroyden{TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + termination_condition::TC +end + +# Implementation of solve using Package Extensions diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl new file mode 100644 index 0000000..60bb6ae --- /dev/null +++ b/src/batched/dfsane.jl @@ -0,0 +1,141 @@ +Base.@kwdef struct BatchedSimpleDFSane{T, F, TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + σₘᵢₙ::T = 1.0f-10 + σₘₐₓ::T = 1.0f+10 + σ₁::T = 1.0f0 + M::Int = 10 + γ::T = 1.0f-4 + τₘᵢₙ::T = 0.1f0 + τₘₐₓ::T = 0.5f0 + nₑₓₚ::Int = 2 + ηₛ::F = (f₍ₙₒᵣₘ₎₁, n, xₙ, fₙ) -> f₍ₙₒᵣₘ₎₁ ./ n .^ 2 + termination_condition::TC = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing) + max_inner_iterations::Int = 1000 +end + +function SciMLBase.__solve(prob::NonlinearProblem, + alg::BatchedSimpleDFSane, + args...; + abstol = nothing, + reltol = nothing, + maxiters = 100, + kwargs...) + iip = isinplace(prob) + + u, f, reconstruct = _construct_batched_problem_structure(prob) + L, N = size(u) + T = eltype(u) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σₘᵢₙ), T(alg.σₘₐₓ), T(alg.γ), T(alg.τₘᵢₙ), T(alg.τₘₐₓ) + α₁ = one(T) + α₊, α₋ = similar(u, 1, N), similar(u, 1, N) + σₙ = fill(T(alg.σ₁), 1, N) + 𝒹 = similar(σₙ, L, N) + M = alg.M + nₑₓₚ = alg.nₑₓₚ + + xₙ, xₙ₋₁, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎ₙ = copy(u), copy(u), similar(u, 1, N), similar(u, 1, N) + + function ff!(fₓ, fₙₒᵣₘ, x) + f(fₓ, x) + sum!(abs2, fₙₒᵣₘ, fₓ) + fₙₒᵣₘ .^= (nₑₓₚ / 2) + return fₓ + end + + function ff!(fₙₒᵣₘ, x) + fₓ = f(x) + sum!(abs2, fₙₒᵣₘ, fₓ) + fₙₒᵣₘ .^= (nₑₓₚ / 2) + return fₓ + end + + @maybeinplace iip fₙ₋₁=ff!(f₍ₙₒᵣₘ₎ₙ₋₁, xₙ) xₙ + iip && (fₙ = similar(fₙ₋₁)) + ℋ = repeat(f₍ₙₒᵣₘ₎ₙ₋₁, M, 1) + f̄ = similar(ℋ, 1, N) + ηₛ = (n, xₙ, fₙ) -> alg.ηₛ(f₍ₙₒᵣₘ₎ₙ₋₁, n, xₙ, fₙ) + + for n in 1:maxiters + # Spectral parameter range check + @. σₙ = sign(σₙ) * clamp(abs(σₙ), σₘᵢₙ, σₘₐₓ) + + # Line search direction + @. 𝒹 = -σₙ * fₙ₋₁ + + η = ηₛ(n, xₙ₋₁, fₙ₋₁) + maximum!(f̄, ℋ) + fill!(α₊, α₁) + fill!(α₋, α₁) + @. xₙ = xₙ₋₁ + α₊ * 𝒹 + + @maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ) + + for _ in 1:(alg.max_inner_iterations) + 𝒸 = @. f̄ + η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ + + (sum(f₍ₙₒᵣₘ₎ₙ .≤ 𝒸) ≥ N ÷ 2) && break + + @. α₊ = clamp(α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), + τₘᵢₙ * α₊, + τₘₐₓ * α₊) + @. xₙ = xₙ₋₁ - α₋ * 𝒹 + @maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ) + + (sum(f₍ₙₒᵣₘ₎ₙ .≤ 𝒸) ≥ N ÷ 2) && break + + @. α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁), + τₘᵢₙ * α₋, + τₘₐₓ * α₋) + @. xₙ = xₙ₋₁ + α₊ * 𝒹 + @maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ) + end + + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode) + end + + # Update spectral parameter + @. xₙ₋₁ = xₙ - xₙ₋₁ + @. fₙ₋₁ = fₙ - fₙ₋₁ + + sum!(abs2, α₊, xₙ₋₁) + sum!(α₋, xₙ₋₁ .* fₙ₋₁) + σₙ .= α₊ ./ (α₋ .+ T(1e-5)) + + # Take step + @. xₙ₋₁ = xₙ + @. fₙ₋₁ = fₙ + @. f₍ₙₒᵣₘ₎ₙ₋₁ = f₍ₙₒᵣₘ₎ₙ + + # Update history + ℋ[n % M + 1, :] .= view(f₍ₙₒᵣₘ₎ₙ, 1, :) + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xₙ = storage.u + @maybeinplace iip fₙ=f(xₙ) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode = ReturnCode.MaxIters) +end diff --git a/src/batched/raphson.jl b/src/batched/raphson.jl new file mode 100644 index 0000000..323c07e --- /dev/null +++ b/src/batched/raphson.jl @@ -0,0 +1,77 @@ +struct BatchedSimpleNewtonRaphson{CS, AD, FDT, TC <: NLSolveTerminationCondition} <: + AbstractBatchedNonlinearSolveAlgorithm + termination_condition::TC +end + +alg_autodiff(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = AD +diff_type(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = FDT + +function BatchedSimpleNewtonRaphson(; chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + return BatchedSimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), + SciMLBase._unwrap_val(diff_type), typeof(termination_condition)}(termination_condition) +end + +function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson; + abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) + iip = SciMLBase.isinplace(prob) + @assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems." + u, f, reconstruct = _construct_batched_problem_structure(prob) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + storage = _get_storage(mode, u) + + xₙ, xₙ₋₁ = copy(u), copy(u) + T = eltype(u) + + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) + termination_condition = tc(storage) + + for i in 1:maxiters + if alg_autodiff(alg) + fₙ, 𝓙 = value_derivative(f, xₙ) + else + fₙ = f(xₙ) + 𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ) + end + + iszero(fₙ) && return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode = ReturnCode.Success) + + δx = reshape(𝓙 \ vec(fₙ), size(xₙ)) + xₙ .-= δx + + if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) + retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode) + end + + xₙ₋₁ .= xₙ + end + + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + xₙ = storage.u + fₙ = f(xₙ) + end + + return DiffEqBase.build_solution(prob, + alg, + reconstruct(xₙ), + reconstruct(fₙ); + retcode = ReturnCode.MaxIters) +end diff --git a/src/batched/utils.jl b/src/batched/utils.jl new file mode 100644 index 0000000..7b85011 --- /dev/null +++ b/src/batched/utils.jl @@ -0,0 +1,79 @@ +macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing} = nothing) + @assert expr.head == :(=) + x1, x2 = expr.args + @assert x2.head == :call + f, x... = x2.args + define_expr = u0 === nothing ? :() : :($(x1) = similar($(u0))) + return quote + if $(esc(iip)) + $(esc(define_expr)) + $(esc(f))($(esc(x1)), $(esc.(x)...)) + else + $(esc(expr)) + end + end +end + +function _get_tolerance(η, tc_η, ::Type{T}) where {T} + fallback_η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return ifelse(η !== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η)) +end + +function _construct_batched_problem_structure(prob) + return _construct_batched_problem_structure(prob.u0, + prob.f, + prob.p, + Val(SciMLBase.isinplace(prob))) +end + +function _construct_batched_problem_structure(u0::AbstractArray{T, N}, + f, + p, + ::Val{iip}) where {T, N, iip} + # Reconstruct `u` + reconstruct = N == 2 ? identity : Base.Fix2(reshape, size(u0)) + # Standardize `u` + standardize = N == 2 ? identity : + (N == 1 ? Base.Fix2(reshape, (:, 1)) : + Base.Fix2(reshape, (:, size(u0, ndims(u0))))) + # Updated Function + f_modified = if iip + function f_modified_iip(du, u) + f(reconstruct(du), reconstruct(u), p) + return standardize(du) + end + else + f_modified_oop(u) = standardize(f(reconstruct(u), p)) + end + return standardize(u0), f_modified, reconstruct +end + +@views function _init_𝓙(x::AbstractMatrix) + 𝓙 = ArrayInterface.zeromatrix(x[:, 1]) + if ismutable(x) + 𝓙[diagind(𝓙)] .= one(eltype(x)) + else + 𝓙 .+= I + end + return repeat(𝓙, 1, 1, size(x, 2)) +end + +_result_from_storage(::Nothing, xₙ, fₙ, args...) = ReturnCode.Success, xₙ, fₙ +function _result_from_storage(storage::NLSolveSafeTerminationResult, xₙ, fₙ, f, mode, iip) + if storage.return_code == DiffEqBase.NLSolveSafeTerminationReturnCode.Success + return ReturnCode.Success, xₙ, fₙ + else + if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES + @maybeinplace iip fₙ=f(xₙ) + return ReturnCode.Terminated, storage.u, fₙ + else + return ReturnCode.Terminated, xₙ, fₙ + end + end +end + +function _get_storage(mode, u) + return mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? + NLSolveSafeTerminationResult(mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES ? u : + nothing) : nothing +end diff --git a/src/broyden.jl b/src/broyden.jl index 8ce0d66..6c5c3ce 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -1,7 +1,8 @@ """ Broyden(; batched = false, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, reltol = nothing)) + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) A low-overhead implementation of Broyden. This method is non-allocating on scalar and static array problems. @@ -11,19 +12,23 @@ and static array problems. To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or `import NNlib` must be present in your code. """ -struct Broyden{batched, TC <: NLSolveTerminationCondition} <: +struct Broyden{TC <: NLSolveTerminationCondition} <: AbstractSimpleNonlinearSolveAlgorithm termination_condition::TC +end - function Broyden(; batched = false, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing)) - return new{batched, typeof(termination_condition)}(termination_condition) +function Broyden(; batched = false, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing)) + if batched + @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + return BatchedBroyden(termination_condition) end + return Broyden(termination_condition) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; +function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) @@ -38,11 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...; error("Broyden currently only supports out-of-place nonlinear problems") end - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES error("Broyden currently doesn't support SAFE_BEST termination modes") diff --git a/src/dfsane.jl b/src/dfsane.jl index b5e2b82..2e52cde 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -51,7 +51,7 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_ - `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the algorithm. Used exclusively in `batched` mode. Defaults to `1000`. """ -struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm +struct SimpleDFSane{T, TC} <: AbstractSimpleNonlinearSolveAlgorithm σ_min::T σ_max::T σ_1::T @@ -62,47 +62,54 @@ struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm nexp::Int η_strategy::Function termination_condition::TC - max_inner_iterations::Int +end - function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, - M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, - nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, - termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing), - batched::Bool = false, - max_inner_iterations = 1000) - return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min, - σ_max, - σ_1, +function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, + M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, + nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), + batched::Bool = false, + max_inner_iterations = 1000) + if batched + return BatchedSimpleDFSane(; σₘᵢₙ = σ_min, + σₘₐₓ = σ_max, + σ₁ = σ_1, M, γ, - τ_min, - τ_max, - nexp, - η_strategy, + τₘᵢₙ = τ_min, + τₘₐₓ = τ_max, + nₑₓₚ = nexp, + ηₛ = η_strategy, termination_condition, max_inner_iterations) end + return SimpleDFSane{typeof(σ_min), typeof(termination_condition)}(σ_min, + σ_max, + σ_1, + M, + γ, + τ_min, + τ_max, + nexp, + η_strategy, + termination_condition) end -function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, +function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - kwargs...) where {batched} + kwargs...) tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) f = Base.Fix2(prob.f, prob.p) x = float(prob.u0) - if batched - batch_size = size(x, 2) - end - T = eltype(x) σ_min = float(alg.σ_min) σ_max = float(alg.σ_max) - σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1) + σ_k = float(alg.σ_1) M = alg.M γ = float(alg.γ) @@ -111,17 +118,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, nexp = alg.nexp η_strategy = alg.η_strategy - batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays" - if SciMLBase.isinplace(prob) error("SimpleDFSane currently only supports out-of-place nonlinear problems") end - atol = abstol !== nothing ? abstol : - (tc.abstol !== nothing ? tc.abstol : - real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)) - rtol = reltol !== nothing ? reltol : - (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5)) + atol = _get_tolerance(abstol, tc.abstol, T) + rtol = _get_tolerance(reltol, tc.reltol, T) if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES error("SimpleDFSane currently doesn't support SAFE_BEST termination modes") @@ -133,22 +135,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, function ff(x) F = f(x) - f_k = if batched - sum(abs2, F; dims = 1) .^ (nexp / 2) - else - norm(F)^nexp - end + f_k = norm(F)^nexp return f_k, F end function generate_history(f_k, M) - if batched - history = similar(f_k, (M, length(f_k))) - history .= reshape(f_k, 1, :) - return history - else - return fill(f_k, M) - end + return fill(f_k, M) end f_k, F_k = ff(x) @@ -158,17 +150,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, for k in 1:maxiters # Spectral parameter range check - if batched - @. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) - else - σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) - end + σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max) # Line search direction d = -σ_k .* F_k η = η_strategy(f_1, k, x, F_k) - f̄ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k) + f̄ = maximum(history_f_k) α_p = α_1 α_m = α_1 x_new = @. x + α_p * d @@ -179,38 +167,20 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, while true inner_iterations += 1 - if batched - criteria = @. f̄ + η - γ * α_p^2 * f_k - # NOTE: This is simply a heuristic, ideally we check using `all` but that is - # typically very expensive for large problems - (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break - else - criteria = f̄ + η - γ * α_p^2 * f_k - f_new ≤ criteria && break - end + criteria = f̄ + η - γ * α_p^2 * f_k + f_new ≤ criteria && break α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k) x_new = @. x - α_m * d f_new, F_new = ff(x_new) - if batched - # NOTE: This is simply a heuristic, ideally we check using `all` but that is - # typically very expensive for large problems - (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break - else - f_new ≤ criteria && break - end + f_new ≤ criteria && break α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k) α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p) α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m) x_new = @. x + α_p * d f_new, F_new = ff(x_new) - - # NOTE: The original algorithm runs till either condition is satisfied, however, - # for most batched problems like neural networks we only care about - # approximate convergence - batched && (inner_iterations ≥ alg.max_inner_iterations) && break end if termination_condition(F_new, x_new, x, atol, rtol) @@ -225,11 +195,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, s_k = @. x_new - x y_k = @. F_new - F_k - if batched - σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5)) - else - σ_k = (s_k' * s_k) / (s_k' * y_k) - end + σ_k = (s_k' * s_k) / (s_k' * y_k) # Take step x = x_new @@ -237,11 +203,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched}, f_k = f_new # Store function value - if batched - history_f_k[k % M + 1, :] .= vec(f_new) - else - history_f_k[k % M + 1] = f_new - end + history_f_k[k % M + 1] = f_new end return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters) end diff --git a/src/raphson.jl b/src/raphson.jl index 386c350..48b8f75 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,8 +1,9 @@ """ -```julia -SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), - diff_type = Val{:forward}) -``` + SimpleNewtonRaphson(; batched = false, + chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = missing) A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar and static array problems. @@ -27,13 +28,37 @@ and static array problems. - `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to `Val{:forward}` for forward finite differences. For more details on the choices, see the [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation. +- `termination_condition`: control the termination of the algorithm. (Only works for batched + problems) """ -struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} - function SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(), - diff_type = Val{:forward}) - new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff), +struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end + +function SimpleNewtonRaphson(; batched = false, + chunk_size = Val{0}(), + autodiff = Val{true}(), + diff_type = Val{:forward}, + termination_condition = missing) + if !ismissing(termination_condition) && !batched + throw(ArgumentError("`termination_condition` is currently only supported for batched problems")) + end + if batched + # @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." + termination_condition = ismissing(termination_condition) ? + NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing) : + termination_condition + return BatchedSimpleNewtonRaphson(; chunk_size, + autodiff, + diff_type, + termination_condition) + return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), SciMLBase._unwrap_val(diff_type)}() end + return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size), + SciMLBase._unwrap_val(autodiff), + SciMLBase._unwrap_val(diff_type)}() end function SciMLBase.__solve(prob::NonlinearProblem, diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..469f302 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,10 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/basictests.jl b/test/basictests.jl index aea1731..d5f85cd 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,16 +1,14 @@ -using SimpleNonlinearSolve -using StaticArrays -using BenchmarkTools -using DiffEqBase -using LinearAlgebra -using Test - -const BATCHED_BROYDEN_SOLVERS = Broyden[] -const BROYDEN_SOLVERS = Broyden[] -const BATCHED_LBROYDEN_SOLVERS = LBroyden[] -const LBROYDEN_SOLVERS = LBroyden[] -const BATCHED_DFSANE_SOLVERS = SimpleDFSane[] -const DFSANE_SOLVERS = SimpleDFSane[] +using SimpleNonlinearSolve, + StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, + NNlib + +const BATCHED_BROYDEN_SOLVERS = [] +const BROYDEN_SOLVERS = [] +const BATCHED_LBROYDEN_SOLVERS = [] +const LBROYDEN_SOLVERS = [] +const BATCHED_DFSANE_SOLVERS = [] +const DFSANE_SOLVERS = [] +const BATCHED_RAPHSON_SOLVERS = [] for mode in instances(NLSolveTerminationMode.T) if mode ∈ @@ -27,6 +25,12 @@ for mode in instances(NLSolveTerminationMode.T) push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition)) push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition)) push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition)) + push!(BATCHED_RAPHSON_SOLVERS, + SimpleNewtonRaphson(; batched = true, + termination_condition)) + push!(BATCHED_RAPHSON_SOLVERS, + SimpleNewtonRaphson(; batched = true, autodiff = false, + termination_condition)) end # SimpleNewtonRaphson @@ -476,9 +480,6 @@ for options in list_of_options @test all(abs.(f(u, p)) .< 1e-10) end -# Batched Broyden -using NNlib - f, u0 = (u, p) -> u .* u .- p, randn(1, 3) p = [2.0 1.0 5.0]; @@ -488,9 +489,10 @@ sol = solve(probN, Broyden(batched = true)) @test abs.(sol.u) ≈ sqrt.(p) -for alg in (BATCHED_BROYDEN_SOLVERS..., +@testset "Batched Solver: $(nameof(typeof(alg)))" for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS..., - BATCHED_DFSANE_SOLVERS...) + BATCHED_DFSANE_SOLVERS..., + BATCHED_RAPHSON_SOLVERS...) sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3) @test sol.retcode == ReturnCode.Success diff --git a/test/inplace.jl b/test/inplace.jl new file mode 100644 index 0000000..886e820 --- /dev/null +++ b/test/inplace.jl @@ -0,0 +1,53 @@ +using SimpleNonlinearSolve, + StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, + NNlib + +# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane +function f!(du::AbstractArray{<:Number, N}, + u::AbstractArray{<:Number, N}, + p::AbstractVector) where {N} + u_ = reshape(u, :, size(u, N)) + du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :), + ntuple(_ -> 1, N - 1)..., + size(u, N)) + return du +end + +function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector) + du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :) + return du +end + +function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector) + du .= sum(abs2, u) .- p + return du +end + +@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true), + SimpleDFSane(batched = true)) + @testset "T: $T" for T in (Float32, Float64) + p = rand(T, 5) + @testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.resid≈zero(sol.resid) atol=5e-3 + end + + p = rand(T, 1) + @testset "size(u0): $sz" for sz in ((3,), (5,), (10,)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.resid≈zero(sol.resid) atol=5e-3 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 94a0086..bea57ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,15 @@ -using Pkg using SafeTestsets -const LONGER_TESTS = false const GROUP = get(ENV, "GROUP", "All") -const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time begin if GROUP == "All" || GROUP == "Core" @time @safetestset "Basic Tests + Some AD" begin include("basictests.jl") end + + @time @safetestset "Inplace Tests" begin + include("inplace.jl") + end end end