diff --git a/Project.toml b/Project.toml index a6080955a..98c82d698 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,8 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32" +libflame_jll = "8e9d65e3-b2b8-5a9c-baa2-617b4576f0b9" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" @@ -47,6 +49,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac" [extensions] +LinearSolveBLISExt = ["blis_jll", "libflame_jll"] LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" @@ -77,8 +80,8 @@ ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" -ExplicitImports = "1" EnzymeCore = "0.8.1" +ExplicitImports = "1" FastAlmostBandedMatrices = "0.1" FastLapackInterface = "2" FiniteDiff = "2.22" @@ -118,14 +121,16 @@ StaticArraysCore = "1.4.2" Test = "1" UnPack = "1" Zygote = "0.7" +blis_jll = "0.9.0" julia = "1.10" +libflame_jll = "5.2.0" [extras] AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -150,6 +155,8 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +blis_jll = 
"6136c539-28a5-5bf0-87cc-b183200dce32" +libflame_jll = "8e9d65e3-b2b8-5a9c-baa2-617b4576f0b9" [targets] -test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface", "SparseArrays", "ExplicitImports"] +test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "blis_jll", "libflame_jll", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface", "SparseArrays", "ExplicitImports"] diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md index 02825b3fa..eb51cb8f2 100644 --- a/docs/src/solvers/solvers.md +++ b/docs/src/solvers/solvers.md @@ -16,10 +16,12 @@ the best choices, with SVD being the slowest but most precise. For efficiency, `RFLUFactorization` is the fastest for dense LU-factorizations until around 150x150 matrices, though this can be dependent on the exact details of the hardware. After this -point, `MKLLUFactorization` is usually faster on most hardware. Note that on Mac computers -that `AppleAccelerateLUFactorization` is generally always the fastest. `LUFactorization` will -use your base system BLAS which can be fast or slow depending on the hardware configuration. -`SimpleLUFactorization` will be fast only on very small matrices but can cut down on compile times. +point, `MKLLUFactorization` is usually faster on most hardware. 
"""
LinearSolveBLISExt

Extension module that implements `BLISLUFactorization` for LinearSolve.jl.

It is loaded when both `blis_jll` and `libflame_jll` are available and provides:

- `getrf!` / `getrs!`: thin `ccall` wrappers around the LAPACK `?getrf_` / `?getrs_`
  entry points of the shared library shipped by `libflame_jll`, for the element types
  `Float32`, `Float64`, `ComplexF32` and `ComplexF64`.
- `LinearSolve.do_factorization`, `LinearSolve.init_cacheval` and `SciMLBase.solve!`
  methods for `BLISLUFactorization` built on top of those wrappers.

NOTE(review): `blis_jll` is loaded and its library handle is kept in `libblis`, but no
routine in this module calls into it directly; presumably libflame uses it as its BLAS
backend -- confirm against the `libflame_jll` build.
"""
module LinearSolveBLISExt

using Libdl
using blis_jll
using libflame_jll
using LinearAlgebra
using LinearSolve

using LinearAlgebra: BlasInt, LU, libblastrampoline
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
                            @blasfunc, chkargsok
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache,
                   SciMLBase, do_factorization

# Handles of the shared libraries provided by the two JLL packages.
const libblis = blis_jll.blis
const libflame = libflame_jll.libflame

"""
    getrf!(A; ipiv, info = Ref{BlasInt}(), check = false)

Compute the LU factorization of `A` in place by calling `?getrf_` from `libflame`.

# Keywords
- `ipiv`: pivot buffer. It is used as-is when its length equals `min(size(A)...)`;
  otherwise a fresh buffer of that length is allocated, so that the LAPACK routine never
  writes past the end of a stale buffer.
- `info`: `Ref{BlasInt}` that receives the LAPACK status code.
- `check`: when `true`, `chkfinite(A)` is called before factorizing.

# Returns
The tuple `(A, ipiv, info[], info)`. Only invalid-argument status codes are raised here
(through `chkargsok`); the status code itself is returned so the caller can store it in
the `LU` object.
"""
function getrf! end

"""
    getrs!(trans, A, ipiv, B; info = Ref{BlasInt}())

Solve the system described by `trans` using the LU factors `A` and pivots `ipiv`
(as produced by `getrf!`), overwriting `B` with the solution, by calling `?getrs_` from
`libflame`. Returns `B`.

# Throws
- `DimensionMismatch` if `A` is not square, or if `size(B, 1)` or `length(ipiv)` differ
  from the order of `A`.
- Whatever `LinearAlgebra.LAPACK.chklapackerror` raises for a non-zero status code.
"""
function getrs! end

# One method of `getrf!` / `getrs!` is generated per BLAS element type. Both routines are
# LAPACK entry points, so both are looked up in `libflame` under their plain Fortran
# names (no libblastrampoline ILP64 suffix), which keeps them consistent with each other.
for (getrf, getrs, elty) in (("dgetrf_", "dgetrs_", :Float64),
                             ("sgetrf_", "sgetrs_", :Float32),
                             ("zgetrf_", "zgetrs_", :ComplexF64),
                             ("cgetrf_", "cgetrs_", :ComplexF32))
    @eval begin
        function getrf!(A::AbstractMatrix{<:$elty};
                ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
                info = Ref{BlasInt}(),
                check = false)
            require_one_based_indexing(A)
            check && chkfinite(A)
            chkstride1(A)
            m, n = size(A)
            lda = max(1, stride(A, 2))
            # A cached pivot vector may be empty (preallocated cache) or sized for a
            # previously factorized matrix; LAPACK writes min(m, n) entries regardless.
            if length(ipiv) != min(m, n)
                ipiv = similar(A, BlasInt, min(m, n))
            end
            ccall(($getrf, libflame), Cvoid,
                (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
                    Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
                m, n, A, lda, ipiv, info)
            chkargsok(info[])
            return A, ipiv, info[], info # Error code is stored in LU factorization type
        end

        function getrs!(trans::AbstractChar,
                A::AbstractMatrix{<:$elty},
                ipiv::AbstractVector{BlasInt},
                B::AbstractVecOrMat{<:$elty};
                info = Ref{BlasInt}())
            require_one_based_indexing(A, ipiv, B)
            LinearAlgebra.LAPACK.chktrans(trans)
            chkstride1(A, B, ipiv)
            n = LinearAlgebra.checksquare(A)
            if n != size(B, 1)
                throw(DimensionMismatch(
                    "B has leading dimension $(size(B,1)), but needs $n"))
            end
            if n != length(ipiv)
                throw(DimensionMismatch(
                    "ipiv has length $(length(ipiv)), but needs to be $n"))
            end
            nrhs = size(B, 2)
            # The trailing `Clong` argument is the hidden Fortran length of `trans`.
            ccall(($getrs, libflame), Cvoid,
                (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
                    Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
                trans, n, nrhs, A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)),
                info, 1)
            LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
            return B
        end
    end
end

"""
    LinearSolve.do_factorization(alg::BLISLUFactorization, A, b, u)

Convert `A` with `convert(AbstractMatrix, A)`, factorize it in place with this module's
`getrf!`, and return the result as a `LinearAlgebra.LU` object that carries the LAPACK
status code. `b` and `u` are not used.
"""
function LinearSolve.do_factorization(alg::BLISLUFactorization, A, b, u)
    A = convert(AbstractMatrix, A)
    factors, ipiv, info_val, _ = getrf!(A)
    return LU(factors, ipiv, info_val)
end

# Extend the host package's aliasing defaults. Unqualified definitions would only create
# new functions local to this extension module.
LinearSolve.default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false
LinearSolve.default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false

# Empty `Float64` LU instance paired with a status `Ref`, returned by the generic
# `init_cacheval` method below. `let` keeps the temporary matrix out of module scope.
# Note that every cache initialised through that method receives this same tuple, and
# therefore the same `Ref{BlasInt}` object.
const PREALLOCATED_BLIS_LU = let A = rand(0, 0)
    (ArrayInterface.lu_instance(A), Ref{BlasInt}())
end

"""
    LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr, maxiters,
        abstol, reltol, verbose, assumptions)

Return the initial cache value for `BLISLUFactorization`: a tuple of an empty `LU`
instance and a `Ref{BlasInt}` for the LAPACK status code. The generic method returns the
shared `PREALLOCATED_BLIS_LU`; the method for `Float32`, `ComplexF32` and `ComplexF64`
matrices builds a fresh tuple with a matching element type.
"""
function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    return PREALLOCATED_BLIS_LU
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization,
        A::AbstractMatrix{<:Union{Float32, ComplexF32, ComplexF64}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    A = rand(eltype(A), 0, 0)
    return ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

"""
    SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization; kwargs...)

Solve the linear system held in `cache`.

When `cache.isfresh` is set, `cache.A` is factorized in place with `getrf!`, reusing the
pivot vector and status `Ref` from the current cache value, and the resulting
`(LU, Ref)` tuple is stored back in `cache.cacheval`. The solution is then computed with
`ldiv!(cache.u, lu, cache.b)` and wrapped by `SciMLBase.build_linear_solution`.
`kwargs` are accepted but not used.
"""
function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;
        kwargs...)
    A = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        cacheval = @get_cacheval(cache, :BLISLUFactorization)
        factors, ipiv, info_val, info_ref = getrf!(
            A; ipiv = cacheval[1].ipiv, info = cacheval[2])
        cache.cacheval = (LU(factors, ipiv, info_val), info_ref)
        cache.isfresh = false
    end

    y = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b)
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

end
note + + Using this solver requires that both blis_jll and libflame_jll packages are available. + The solver will be automatically available when both packages are loaded, i.e., + `using blis_jll, libflame_jll`. + +## Performance Characteristics + +- **Strengths**: Optimized BLAS operations, good performance on modern hardware +- **Use cases**: General dense linear systems where BLAS optimization matters +- **Compatibility**: Supports the BLAS element types Float32, Float64, ComplexF32, and ComplexF64 + +## Example + +```julia +using LinearSolve, blis_jll, libflame_jll +A = rand(100, 100) +b = rand(100) +prob = LinearProblem(A, b) +sol = solve(prob, BLISLUFactorization()) +``` +""" +struct BLISLUFactorization <: AbstractFactorization end diff --git a/test/basictests.jl b/test/basictests.jl index a0dafe1e8..d240d203f 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -2,6 +2,9 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners using Test + +# Import JLL packages for extensions +using blis_jll, libflame_jll import Random const Dual64 = ForwardDiff.Dual{Nothing, Float64, 1} @@ -227,6 +230,13 @@ end if LinearSolve.usemkl push!(test_algs, MKLLUFactorization()) end + + # Add BLIS when the extension is available + try + push!(test_algs, LinearSolve.BLISLUFactorization()) + catch + # BLIS extension not available, skip + end @testset "Concrete Factorizations" begin for alg in test_algs