Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.
22 changes: 7 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,46 +1,38 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.16"
version = "0.1.17"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[extensions]
SimpleBatchedNonlinearSolveExt = "NNlib"
SimpleNonlinearSolveNNlibExt = "NNlib"

[compat]
ArrayInterface = "6, 7"
DiffEqBase = "6.123.0"
DiffEqBase = "6.126"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
NNlib = "0.8, 0.9"
PackageExtensionCompat = "1"
PrecompileTools = "1"
Reexport = "0.2, 1"
Requires = "1"
SciMLBase = "1.73"
PrecompileTools = "1"
StaticArraysCore = "1.4"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]
90 changes: 0 additions & 90 deletions ext/SimpleBatchedNonlinearSolveExt.jl

This file was deleted.

81 changes: 81 additions & 0 deletions ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
module SimpleNonlinearSolveNNlibExt

using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace

function __init__()
SimpleNonlinearSolve.NNlibExtLoaded[] = true
return
end

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedBroyden;
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

u, f, reconstruct = _construct_batched_problem_structure(prob)
L, N = size(u)

tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

storage = _get_storage(mode, u)

xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
T = eltype(u)

atol = _get_tolerance(abstol, tc.abstol, T)
rtol = _get_tolerance(reltol, tc.reltol, T)
termination_condition = tc(storage)

𝓙⁻¹ = _init_𝓙(xₙ) # L × L × N
𝓙⁻¹f, xᵀ𝓙⁻¹δf, xᵀ𝓙⁻¹ = similar(𝓙⁻¹, L, N), similar(𝓙⁻¹, 1, N), similar(𝓙⁻¹, 1, L, N)

@maybeinplace iip fₙ₋₁=f(xₙ) u
iip && (fₙ = copy(fₙ₋₁))
for n in 1:maxiters
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(fₙ₋₁, L, 1, N))
xₙ .= xₙ₋₁ .- 𝓙⁻¹f

@maybeinplace iip fₙ=f(xₙ)
δx .= xₙ .- xₙ₋₁
δf .= fₙ .- fₙ₋₁

batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
δxᵀ = reshape(δx, 1, L, N)

batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
return DiffEqBase.build_solution(prob,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode)
end

xₙ₋₁ .= xₙ
fₙ₋₁ .= fₙ
end

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
@maybeinplace iip fₙ=f(xₙ)
end

return DiffEqBase.build_solution(prob,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode = ReturnCode.MaxIters)
end

end
22 changes: 13 additions & 9 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,19 @@ using DiffEqBase

@reexport using SciMLBase

if !isdefined(Base, :get_extension)
using Requires
end

using PackageExtensionCompat
function __init__()
@static if !isdefined(Base, :get_extension)
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
end
end
@require_extensions
end

const NNlibExtLoaded = Ref{Bool}(false)

abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractBatchedNonlinearSolveAlgorithm <:
AbstractSimpleNonlinearSolveAlgorithm end

include("utils.jl")
include("bisection.jl")
Expand All @@ -42,6 +39,12 @@ include("ad.jl")
include("halley.jl")
include("alefeld.jl")

# Batched Solver Support
include("batched/utils.jl")
include("batched/raphson.jl")
include("batched/dfsane.jl")
include("batched/broyden.jl")

import PrecompileTools

PrecompileTools.@compile_workload begin
Expand Down Expand Up @@ -74,5 +77,6 @@ end
# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane

end # module
6 changes: 6 additions & 0 deletions src/batched/broyden.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
struct BatchedBroyden{TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
termination_condition::TC
end

# Implementation of solve using Package Extensions
Loading