-
-
Notifications
You must be signed in to change notification settings - Fork 5.7k
LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul #54303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fda8c87
3073352
54a4038
108db04
a2dd315
5db39fb
7a8bb75
1e06a1e
cefa1b1
28d4fd8
b24cc1e
9e66d1a
541a76d
31692d5
311382b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -364,28 +364,32 @@ lmul!(A, B) | |
| # aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly | ||
| Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, | ||
| _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} | ||
| if all(in(('N', 'T', 'C')), (tA, tB)) | ||
| if tA == 'T' && tB == 'N' && A === B | ||
| # We convert the chars to uppercase to potentially unwrap a WrapperChar, | ||
| # and extract the char corresponding to the wrapper type | ||
| tA_uc, tB_uc = uppercase(tA), uppercase(tB) | ||
| # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

If this is true, this should be a giant contribution to the reduction of compile times, right? If we land in this branch, then we don't need to compile `symm` and `hemm`, or in the other case `syrk`/`herk`/`gemm_wrapper`.

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

There is some compile-time improvement indeed, although it's not dramatic.

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);
julia> @time mul!(C, A, B);
0.847057 seconds (3.39 M allocations: 171.963 MiB, 24.97% gc time, 100.00% compilation time) # nightly
0.757433 seconds (3.94 M allocations: 202.922 MiB, 4.65% gc time, 100.00% compilation time) # This PR
julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);
julia> @time mul!(C, A, B);
1.098831 seconds (3.68 M allocations: 189.159 MiB, 24.52% gc time, 99.99% compilation time) # nightly
0.687847 seconds (4.72 M allocations: 238.864 MiB, 7.04% gc time, 99.99% compilation time) # This PR

Descending into […]

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The […]

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);
julia> @code_typed mul!(C, A, B)
CodeInfo(
1 ─ %1 = invoke LinearAlgebra.gemm_wrapper!(C::Matrix{Float64}, 'N'::Char, 'N'::Char, A::Matrix{Float64}, B::Matrix{Float64}, $(QuoteNode(LinearAlgebra.MulAddMul{true, true, Bool, Bool}(true, false)))::LinearAlgebra.MulAddMul{true, true, Bool, Bool})::Matrix{Float64}
└── return %1
) => Matrix{Float64}

I'm not certain why there's a compile-time improvement here (perhaps noise?). In this case, the […]

The second case ([…]

julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);
julia> @code_typed mul!(C, A, B)
CodeInfo(
1 ── %1 = Base.getfield(B, :uplo)::Char
│ %2 = Base.bitcast(Base.UInt32, %1)::UInt32
│ %3 = Base.bitcast(Base.UInt32, 'U')::UInt32
│ %4 = (%2 === %3)::Bool
│ %5 = Base.getfield(B, :data)::Matrix{Float64}
└─── goto #3 if not %4
2 ── goto #4
3 ── goto #4
4 ┄─ %9 = φ (#2 => 'S', #3 => 's')::Char
│ %10 = Base.bitcast(Base.UInt32, %9)::UInt32
│ %11 = Base.bitcast(Base.UInt32, 'S')::UInt32
│ %12 = (%10 === %11)::Bool
└─── goto #5
5 ── goto #7 if not %12
6 ── goto #8
7 ── nothing::Nothing
8 ┄─ %17 = φ (#6 => 'U', #7 => 'L')::Char
│ %18 = invoke LinearAlgebra.BLAS.symm!('R'::Char, %17::Char, 1.0::Float64, %5::Matrix{Float64}, A::Matrix{Float64}, 0.0::Float64, C::Matrix{Float64})::Matrix{Float64}
└─── goto #9
9 ── goto #10
10 ─ goto #11
11 ─ return %18
) => Matrix{Float64}The |
||
| if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) | ||
| if tA_uc == 'T' && tB_uc == 'N' && A === B | ||
| return syrk_wrapper!(C, 'T', A, _add) | ||
| elseif tA == 'N' && tB == 'T' && A === B | ||
| elseif tA_uc == 'N' && tB_uc == 'T' && A === B | ||
| return syrk_wrapper!(C, 'N', A, _add) | ||
| elseif tA == 'C' && tB == 'N' && A === B | ||
| elseif tA_uc == 'C' && tB_uc == 'N' && A === B | ||
| return herk_wrapper!(C, 'C', A, _add) | ||
| elseif tA == 'N' && tB == 'C' && A === B | ||
| elseif tA_uc == 'N' && tB_uc == 'C' && A === B | ||
| return herk_wrapper!(C, 'N', A, _add) | ||
| else | ||
| return gemm_wrapper!(C, tA, tB, A, B, _add) | ||
| end | ||
| end | ||
| alpha, beta = promote(_add.alpha, _add.beta, zero(T)) | ||
| if alpha isa Union{Bool,T} && beta isa Union{Bool,T} | ||
| if (tA == 'S' || tA == 's') && tB == 'N' | ||
| if tA_uc == 'S' && tB_uc == 'N' | ||
| return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) | ||
| elseif (tB == 'S' || tB == 's') && tA == 'N' | ||
| elseif tA_uc == 'N' && tB_uc == 'S' | ||
| return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) | ||
| elseif (tA == 'H' || tA == 'h') && tB == 'N' | ||
| elseif tA_uc == 'H' && tB_uc == 'N' | ||
| return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) | ||
| elseif (tB == 'H' || tB == 'h') && tA == 'N' | ||
| elseif tA_uc == 'N' && tB_uc == 'H' | ||
| return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) | ||
| end | ||
| end | ||
|
|
@@ -395,7 +399,11 @@ end | |
| # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. | ||
| Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, | ||
| _add::MulAddMul=MulAddMul()) where {T<:BlasReal} | ||
| if all(in(('N', 'T', 'C')), (tA, tB)) | ||
| # We convert the chars to uppercase to potentially unwrap a WrapperChar, | ||
| # and extract the char corresponding to the wrapper type | ||
| tA_uc, tB_uc = uppercase(tA), uppercase(tB) | ||
| # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. | ||
| if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) | ||
| gemm_wrapper!(C, tA, tB, A, B, _add) | ||
| else | ||
| _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) | ||
|
|
@@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar | |
| mA == 0 && return y | ||
| nA == 0 && return _rmul_or_fill!(y, β) | ||
| alpha, beta = promote(α, β, zero(T)) | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && | ||
| stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && | ||
| !iszero(stride(x, 1)) && # We only check input's stride here. | ||
| if tA in ('N', 'T', 'C') | ||
| if tA_uc in ('N', 'T', 'C') | ||
| return BLAS.gemv!(tA, alpha, A, x, beta, y) | ||
| elseif tA in ('S', 's') | ||
| elseif tA_uc == 'S' | ||
| return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) | ||
| elseif tA in ('H', 'h') | ||
| elseif tA_uc == 'H' | ||
| return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) | ||
| end | ||
| end | ||
| if tA in ('S', 's', 'H', 'h') | ||
| if tA_uc in ('S', 'H') | ||
| # re-wrap again and use plain ('N') matvec mul algorithm, | ||
| # because _generic_matvecmul! can't handle the HermOrSym cases specifically | ||
| return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) | ||
|
|
@@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs | |
| mA == 0 && return y | ||
| nA == 0 && return _rmul_or_fill!(y, β) | ||
| alpha, beta = promote(α, β, zero(T)) | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && | ||
| stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && | ||
| stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y` | ||
| stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y` | ||
| !iszero(stride(x, 1)) | ||
| BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) | ||
| return y | ||
| else | ||
| Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) | ||
| Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) | ||
| return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) | ||
| end | ||
| end | ||
|
|
@@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs | |
| mA == 0 && return y | ||
| nA == 0 && return _rmul_or_fill!(y, β) | ||
| alpha, beta = promote(α, β, zero(T)) | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && | ||
| stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && | ||
| !iszero(stride(x, 1)) && tA in ('N', 'T', 'C') | ||
| !iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C') | ||
| xfl = reinterpret(reshape, T, x) # Use reshape here. | ||
| yfl = reinterpret(reshape, T, y) | ||
| BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) | ||
| BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) | ||
| return y | ||
| elseif tA in ('S', 's', 'H', 'h') | ||
| elseif tA_uc in ('S', 'H') | ||
| # re-wrap again and use plain ('N') matvec mul algorithm, | ||
| # because _generic_matvecmul! can't handle the HermOrSym cases specifically | ||
| return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) | ||
|
|
@@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs | |
| end | ||
| end | ||
|
|
||
| function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, | ||
| # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it | ||
| # to be concretely inferred | ||
| Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, | ||
| _add = MulAddMul()) where {T<:BlasFloat} | ||
| nC = checksquare(C) | ||
| if tA == 'T' | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| if tA_uc == 'T' | ||
| (nA, mA) = size(A,1), size(A,2) | ||
| tAt = 'N' | ||
| else | ||
|
|
@@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat | |
| return gemm_wrapper!(C, tA, tAt, A, A, _add) | ||
| end | ||
|
|
||
| function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, | ||
| # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it | ||
| # to be concretely inferred | ||
| Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, | ||
| _add = MulAddMul()) where {T<:BlasReal} | ||
| nC = checksquare(C) | ||
| if tA == 'C' | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| if tA_uc == 'C' | ||
| (nA, mA) = size(A,1), size(A,2) | ||
| tAt = 'N' | ||
| else | ||
|
|
@@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA | |
| return gemm_wrapper!(C, tA, tAt, A, A, _add) | ||
| end | ||
|
|
||
| function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, | ||
| # Aggressive constprop helps propagate the values of tA and tB into wrap, which | ||
| # makes the calls concretely inferred | ||
| Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, | ||
dkarrasch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| A::StridedVecOrMat{T}, | ||
| B::StridedVecOrMat{T}) where {T<:BlasFloat} | ||
| mA, nA = lapack_size(tA, A) | ||
| mB, nB = lapack_size(tB, B) | ||
| C = similar(B, T, mA, nB) | ||
| if all(in(('N', 'T', 'C')), (tA, tB)) | ||
| # We convert the chars to uppercase to potentially unwrap a WrapperChar, | ||
| # and extract the char corresponding to the wrapper type | ||
| tA_uc, tB_uc = uppercase(tA), uppercase(tB) | ||
| # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. | ||
| if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) | ||
| gemm_wrapper!(C, tA, tB, A, B) | ||
| else | ||
| _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) | ||
| end | ||
| end | ||
|
|
||
| function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, | ||
| # Aggressive constprop helps propagate the values of tA and tB into wrap, which | ||
| # makes the calls concretely inferred | ||
| Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, | ||
| A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, | ||
| _add = MulAddMul()) where {T<:BlasFloat} | ||
| mA, nA = lapack_size(tA, A) | ||
|
|
@@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar | |
| _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) | ||
| end | ||
|
|
||
| function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, | ||
| # Aggressive constprop helps propagate the values of tA and tB into wrap, which | ||
| # makes the calls concretely inferred | ||
| Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, | ||
| A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, | ||
| _add = MulAddMul()) where {T<:BlasReal} | ||
| mA, nA = lapack_size(tA, A) | ||
|
|
@@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs | |
|
|
||
| alpha, beta = promote(_add.alpha, _add.beta, zero(T)) | ||
|
|
||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
|
|
||
| # Make-sure reinterpret-based optimization is BLAS-compatible. | ||
| if (alpha isa Union{Bool,T} && | ||
| beta isa Union{Bool,T} && | ||
| stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && | ||
| stride(A, 2) >= size(A, 1) && | ||
| stride(B, 2) >= size(B, 1) && | ||
| stride(C, 2) >= size(C, 1) && tA == 'N') | ||
| stride(C, 2) >= size(C, 1) && tA_uc == 'N') | ||
| BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) | ||
| return C | ||
| end | ||
|
|
@@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and | |
| See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). | ||
| """ | ||
| function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) | ||
| if tM == 'N' | ||
| tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char | ||
| if tM_uc == 'N' | ||
| copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src) | ||
| elseif tM == 'T' | ||
| elseif tM_uc == 'T' | ||
| copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src) | ||
| else | ||
| copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src) | ||
|
|
@@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and | |
| See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). | ||
| """ | ||
| function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) | ||
| if tM == 'N' | ||
| tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char | ||
| if tM_uc == 'N' | ||
| copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src) | ||
| else | ||
| copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src) | ||
| tM == 'C' && conj!(@view B[ir_dest, jr_dest]) | ||
| tM_uc == 'C' && conj!(@view B[ir_dest, jr_dest]) | ||
| end | ||
| B | ||
| end | ||
|
|
@@ -751,7 +782,8 @@ end | |
|
|
||
| @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, | ||
| _add::MulAddMul = MulAddMul()) | ||
| Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) | ||
| tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char | ||
| Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) | ||
| return _generic_matvecmul!(C, ta, Anew, B, _add) | ||
| end | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.