From fda8c87f250dbfa726da36a1d88acb2b66f9ac01 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 29 Apr 2024 17:14:33 +0530 Subject: [PATCH 01/15] LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 57 ++++++++++++++++------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 6ed272ab42f02..6cb6f19b492b0 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -516,32 +516,53 @@ const ⋅ = dot const × = cross export ⋅, × -wrapper_char(::AbstractArray) = 'N' -wrapper_char(::Adjoint) = 'C' -wrapper_char(::Adjoint{<:Real}) = 'T' -wrapper_char(::Transpose) = 'T' -wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h' -wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's' -wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's' +# Separate the char corresponding to the wrapper from that corresponding to the uplo +# In most cases, the former may be constant-propagated, while the latter usually can't be. +# This improves type-inference in wrap for Symmetric/Hermitian matrices +struct WrapperChar <: AbstractChar + wrapperchar :: Char + isuppertri :: Bool +end +function Base.Char(w::WrapperChar) + T = w.wrapperchar + if T ∉ ('S', 'H') + T + else + _isuppertri(w) ? uppercase(T) : lowercase(T) + end +end +Base.codepoint(w::WrapperChar) = codepoint(Char(w)) +WrapperChar(n::UInt32) = WrapperChar(Char(n), true) +_isuppertri(w::WrapperChar) = w.isuppertri +_isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation +_getuplo(x) = _isuppertri(x) ? (:U) : (:L) + +wrapper_char(::AbstractArray) = WrapperChar('N') +wrapper_char(::Adjoint) = WrapperChar('C') +wrapper_char(::Adjoint{<:Real}) = WrapperChar('T') +wrapper_char(::Transpose) = WrapperChar('T') +wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U') +wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U') +wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U') + +_getwrapperchar(x) = x +_getwrapperchar(x::WrapperChar) = x.wrapperchar Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) # merge the result of this before return, so that we can type-assert the return such # that even if the tmerge is inaccurate, inference can still identify that the # `_generic_matmatmul` signature still matches and doesn't require missing backedges - B = if tA == 'N' + tAwc = _getwrapperchar(tA) + B = if tAwc == 'N' A - elseif tA == 'T' + elseif tAwc == 'T' transpose(A) - elseif tA == 'C' + elseif tAwc == 'C' adjoint(A) - elseif tA == 'H' - Hermitian(A, :U) - elseif tA == 'h' - Hermitian(A, :L) - elseif tA == 'S' - Symmetric(A, :U) - else # tA == 's' - Symmetric(A, :L) + elseif tAwc ∈ ('H', 'h') + Hermitian(A, _getuplo(tA) #= unwrap a WrapperChar =#) + elseif tAwc ∈ ('S', 's') + Symmetric(A, _getuplo(tA) #= unwrap a WrapperChar =#) end return B::AbstractVecOrMat end From 307335234cd160a278cc9408bdea6c9344353d29 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 29 Apr 2024 17:29:27 +0530 Subject: [PATCH 02/15] Add inference tests --- stdlib/LinearAlgebra/test/matmul.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index c760f1adeffdd..3925151252913 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -30,6 +30,15 @@ mul_wrappers = [ h(A) = LinearAlgebra.wrap(LinearAlgebra._unwrap(A), LinearAlgebra.wrapper_char(A)) @test @inferred(h(transpose(A))) === transpose(A) @test @inferred(h(adjoint(A))) === transpose(A) + + M = rand(2,2) + for S in (Symmetric(M), Hermitian(M)) + @test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === Symmetric(M) + end + M = rand(ComplexF64,2,2) + for S in (Symmetric(M), Hermitian(M)) + @test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === S + end end @testset "matrices with zero dimensions" begin From 54a4038a951a423d2176763615cf022e1a70d29b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 17:15:11 +0530 Subject: [PATCH 03/15] LinearAlgbebra: constant propagate character in generic_matmatmul checks --- stdlib/LinearAlgebra/src/matmul.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 9ed8bd1b677aa..12eba0fded20b 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -360,11 +360,21 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) +# unroll the in(a, b) computation to enable constant propagation +# This is a 2-valued in implementation that doesn't account for missing values +_in(t::AbstractChar, ::Tuple{}) = false +function _in(t::AbstractChar, chars::Tuple{Vararg{AbstractChar}}) + return t == first(chars) || _in(t, Base.tail(chars)) +end +all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars) + # THE one big BLAS dispatch # aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} - if all(in(('N', 'T', 'C')), (tA, tB)) + # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop + # The check is only on the wrapper type, so we may extract that from a WrapperChar + if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) if tA == 'T' && tB == 'N' && A === B return syrk_wrapper!(C, 'T', A, _add) elseif tA == 'N' && tB == 'T' && A === B @@ -395,7 +405,10 @@ end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} - if all(in(('N', 'T', 'C')), (tA, tB)) + special_cases = ('N', 'T', 'C') + # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop + # The check is only on the wrapper type, so we may extract that from a WrapperChar + if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B, _add) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @@ -587,7 +600,9 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) - if all(in(('N', 'T', 'C')), (tA, tB)) + # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop + # The check is only on the wrapper type, so we may extract that from a WrapperChar + if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) From 108db045d8f18d3604d5cc6312d9860dec98d49b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 18:53:23 +0530 Subject: [PATCH 04/15] Remove unused variable --- stdlib/LinearAlgebra/src/matmul.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 12eba0fded20b..20724f07d6f1f 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -405,7 +405,6 @@ end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} - special_cases = ('N', 'T', 'C') # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # The check is only on the wrapper type, so we may extract that from a WrapperChar if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) From a2dd315f0408beb6d8e3d6538e56b74e504c338c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 20:20:08 +0530 Subject: [PATCH 05/15] Use wrapperchar in checks --- stdlib/LinearAlgebra/src/matmul.jl | 42 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 20724f07d6f1f..f273a6b4b183e 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -376,13 +376,13 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, # The check is only on the wrapper type, so we may extract that from a WrapperChar if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) if tA == 'T' && tB == 'N' && A === B - return syrk_wrapper!(C, 'T', A, _add) + return syrk_wrapper!(C, oftype(tA, 'T'), A, _add) elseif tA == 'N' && tB == 'T' && A === B - return syrk_wrapper!(C, 'N', A, _add) + return syrk_wrapper!(C, oftype(tA, 'N'), A, _add) elseif tA == 'C' && tB == 'N' && A === B - return herk_wrapper!(C, 'C', A, _add) + return herk_wrapper!(C, oftype(tA, 'C'), A, _add) elseif tA == 'N' && tB == 'C' && A === B - return herk_wrapper!(C, 'N', A, _add) + return herk_wrapper!(C, oftype(tA, 'N'), A, _add) else return gemm_wrapper!(C, tA, tB, A, B, _add) end @@ -446,18 +446,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_ = _getwrapperchar(tA) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && !iszero(stride(x, 1)) && # We only check input's stride here. - if tA in ('N', 'T', 'C') + if _in(tA_, ('N', 'T', 'C')) return BLAS.gemv!(tA, alpha, A, x, beta, y) - elseif tA in ('S', 's') + elseif _in(tA_, ('S', 's')) return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) - elseif tA in ('H', 'h') + elseif _in(tA_, ('H', 'h')) return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) end end - if tA in ('S', 's', 'H', 'h') + if _in(tA_, ('S', 's', 'H', 'h')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) @@ -476,14 +477,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_ = _getwrapperchar(tA) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y` + stride(y, 1) == 1 && tA_ == 'N' && # reinterpret-based optimization is valid only for contiguous `y` !iszero(stride(x, 1)) BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + Anew, ta = _in(tA_, ('S', 's', 'H', 'h')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -499,18 +501,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_ = _getwrapperchar(tA) @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) && tA in ('N', 'T', 'C') + !iszero(stride(x, 1)) && _in(tA_, ('N', 'T', 'C')) xfl = reinterpret(reshape, T, x) # Use reshape here. yfl = reinterpret(reshape, T, y) BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y - elseif tA in ('S', 's', 'H', 'h') + elseif _in(tA, ('S', 's', 'H', 'h')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -717,9 +720,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). """ function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - if tM == 'N' + tM_ = _getwrapperchar(tM) + if tM_ == 'N' copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src) - elseif tM == 'T' + elseif tM_ == 'T' copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src) else copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src) @@ -748,11 +752,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). """ function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - if tM == 'N' + tM_ = _getwrapperchar(tM) + if tM_ == 'N' copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src) else copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src) - tM == 'C' && conj!(@view B[ir_dest, jr_dest]) + tM_ == 'C' && conj!(@view B[ir_dest, jr_dest]) end B end @@ -765,7 +770,8 @@ end @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + tA_ = _getwrapperchar(tA) + Anew, ta = _in(tA_, ('S', 's', 'H', 'h')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(C, ta, Anew, B, _add) end From 5db39fbdb1387eb82570d2cd6c290b226b03961c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 20:55:45 +0530 Subject: [PATCH 06/15] Aggressive constprop annotation in gemm_wrapper --- stdlib/LinearAlgebra/src/matmul.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index f273a6b4b183e..234d07be250ef 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -596,7 +596,7 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA return gemm_wrapper!(C, tA, tAt, A, A, _add) end -function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, +Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) @@ -611,7 +611,7 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, end end -function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) @@ -651,7 +651,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end -function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, + Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasReal} mA, nA = lapack_size(tA, A) @@ -681,13 +681,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + tA_ = _getwrapperchar(tA) + # Make-sure reinterpret-based optimization is BLAS-compatible. if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1) && tA == 'N') + stride(C, 2) >= size(C, 1) && tA_ == 'N') BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end From 7a8bb752df3389ea73d32786c709738722960f91 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 22:57:32 +0530 Subject: [PATCH 07/15] uppercase(::WrapperChar) instead of accessor function --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 17 +++---- stdlib/LinearAlgebra/src/matmul.jl | 61 +++++++++++++---------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 6cb6f19b492b0..e64d5f315f0e2 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -533,6 +533,8 @@ function Base.Char(w::WrapperChar) end Base.codepoint(w::WrapperChar) = codepoint(Char(w)) WrapperChar(n::UInt32) = WrapperChar(Char(n), true) +Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar) +Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar) _isuppertri(w::WrapperChar) = w.isuppertri _isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation _getuplo(x) = _isuppertri(x) ? (:U) : (:L) @@ -545,23 +547,20 @@ wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U') wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U') wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U') -_getwrapperchar(x) = x -_getwrapperchar(x::WrapperChar) = x.wrapperchar - Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) # merge the result of this before return, so that we can type-assert the return such # that even if the tmerge is inaccurate, inference can still identify that the # `_generic_matmatmul` signature still matches and doesn't require missing backedges - tAwc = _getwrapperchar(tA) - B = if tAwc == 'N' + tA_uc = uppercase(tA) + B = if tA_uc == 'N' A - elseif tAwc == 'T' + elseif tA_uc == 'T' transpose(A) - elseif tAwc == 'C' + elseif tA_uc == 'C' adjoint(A) - elseif tAwc ∈ ('H', 'h') + elseif tA_uc == 'H' Hermitian(A, _getuplo(tA) #= unwrap a WrapperChar =#) - elseif tAwc ∈ ('S', 's') + elseif tA_uc == 'S' Symmetric(A, _getuplo(tA) #= unwrap a WrapperChar =#) end return B::AbstractVecOrMat diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 234d07be250ef..6ddb3595e5511 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -374,7 +374,7 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # The check is only on the wrapper type, so we may extract that from a WrapperChar - if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) + if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) if tA == 'T' && tB == 'N' && A === B return syrk_wrapper!(C, oftype(tA, 'T'), A, _add) elseif tA == 'N' && tB == 'T' && A === B @@ -389,13 +389,13 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - if (tA == 'S' || tA == 's') && tB == 'N' + if uppercase(tA) == 'S' && tB == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) - elseif (tB == 'S' || tB == 's') && tA == 'N' + elseif uppercase(tB) == 'S' && tA == 'N' return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) - elseif (tA == 'H' || tA == 'h') && tB == 'N' + elseif uppercase(tA) == 'H' && tB == 'N' return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) - elseif (tB == 'H' || tB == 'h') && tA == 'N' + elseif uppercase(tB) == 'H' && tA == 'N' return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end @@ -407,7 +407,7 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Compl _add::MulAddMul=MulAddMul()) where {T<:BlasReal} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # The check is only on the wrapper type, so we may extract that from a WrapperChar - if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) + if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B, _add) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @@ -446,22 +446,22 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) - tA_ = _getwrapperchar(tA) + tA_ = uppercase(tA) # potentially convert a WrapperChar to a Char if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && !iszero(stride(x, 1)) && # We only check input's stride here. if _in(tA_, ('N', 'T', 'C')) return BLAS.gemv!(tA, alpha, A, x, beta, y) - elseif _in(tA_, ('S', 's')) + elseif tA_ == 'S' return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) - elseif _in(tA_, ('H', 'h')) + elseif tA_ == 'H' return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) end end - if _in(tA_, ('S', 's', 'H', 'h')) + if _in(tA_, ('S', 'H')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -477,15 +477,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) - tA_ = _getwrapperchar(tA) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - stride(y, 1) == 1 && tA_ == 'N' && # reinterpret-based optimization is valid only for contiguous `y` + stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y` !iszero(stride(x, 1)) BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - Anew, ta = _in(tA_, ('S', 's', 'H', 'h')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) + Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -501,16 +501,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) - tA_ = _getwrapperchar(tA) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) && _in(tA_, ('N', 'T', 'C')) + !iszero(stride(x, 1)) && _in(tA_uc, ('N', 'T', 'C')) xfl = reinterpret(reshape, T, x) # Use reshape here. yfl = reinterpret(reshape, T, y) BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y - elseif _in(tA, ('S', 's', 'H', 'h')) + elseif _in(tA_uc, ('S', 'H')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) @@ -562,10 +562,10 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA nC = checksquare(C) if tA == 'C' (nA, mA) = size(A,1), size(A,2) - tAt = 'N' + tAt = oftype(tA, 'N') else (mA, nA) = size(A,1), size(A,2) - tAt = 'C' + tAt = oftype(tA, 'C') end if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) @@ -596,6 +596,8 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA return gemm_wrapper!(C, tA, tAt, A, A, _add) end +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) where {T<:BlasFloat} @@ -604,13 +606,16 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract C = similar(B, T, mA, nB) # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # The check is only on the wrapper type, so we may extract that from a WrapperChar - if all_in(('N', 'T', 'C'), map(_getwrapperchar, (tA, tB))) + # map uppercase to potentially convert a WrapperChar to a Char + if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end end +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} @@ -651,7 +656,9 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end - Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasReal} mA, nA = lapack_size(tA, A) @@ -681,7 +688,7 @@ end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) - tA_ = _getwrapperchar(tA) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char # Make-sure reinterpret-based optimization is BLAS-compatible. if (alpha isa Union{Bool,T} && @@ -689,7 +696,7 @@ end stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1) && tA_ == 'N') + stride(C, 2) >= size(C, 1) && tA_uc == 'N') BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end @@ -722,7 +729,7 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). """ function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - tM_ = _getwrapperchar(tM) + tM_ = uppercase(tM) # potentially convert a WrapperChar to a Char if tM_ == 'N' copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src) elseif tM_ == 'T' @@ -754,7 +761,7 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). """ function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - tM_ = _getwrapperchar(tM) + tM_ = uppercase(tM) # potentially convert a WrapperChar to a Char if tM_ == 'N' copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src) else @@ -772,8 +779,8 @@ end @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) - tA_ = _getwrapperchar(tA) - Anew, ta = _in(tA_, ('S', 's', 'H', 'h')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(C, ta, Anew, B, _add) end From 1e06a1e3a1e2cd6e9ef22db6f3a20534cfde7c03 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Apr 2024 23:01:46 +0530 Subject: [PATCH 08/15] Consistent variable name --- stdlib/LinearAlgebra/src/matmul.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 6ddb3595e5511..76b02319c6533 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -446,19 +446,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) - tA_ = uppercase(tA) # potentially convert a WrapperChar to a Char + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && !iszero(stride(x, 1)) && # We only check input's stride here. - if _in(tA_, ('N', 'T', 'C')) + if _in(tA_uc, ('N', 'T', 'C')) return BLAS.gemv!(tA, alpha, A, x, beta, y) - elseif tA_ == 'S' + elseif tA_uc == 'S' return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) - elseif tA_ == 'H' + elseif tA_uc == 'H' return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) end end - if _in(tA_, ('S', 'H')) + if _in(tA_uc, ('S', 'H')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) From cefa1b1ffc223a4a555194c64a3f4d65e27fd34c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 1 May 2024 12:38:54 +0530 Subject: [PATCH 09/15] Constprop in syrk_wrapper/herk_wrapper --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 11 ++-- stdlib/LinearAlgebra/src/matmul.jl | 65 +++++++++++++---------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index e64d5f315f0e2..edec97e3d0206 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -519,6 +519,7 @@ export ⋅, × # Separate the char corresponding to the wrapper from that corresponding to the uplo # In most cases, the former may be constant-propagated, while the latter usually can't be. # This improves type-inference in wrap for Symmetric/Hermitian matrices +# A WrapperChar is equivalent to `isuppertri ? uppercase(wrapperchar) : lowercase(wrapperchar)` struct WrapperChar <: AbstractChar wrapperchar :: Char isuppertri :: Bool @@ -533,16 +534,18 @@ function Base.Char(w::WrapperChar) end Base.codepoint(w::WrapperChar) = codepoint(Char(w)) WrapperChar(n::UInt32) = WrapperChar(Char(n), true) +# We extract the wrapperchar so that the result may be constant-propagated +# This doesn't return a value of the same type on purpose Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar) Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar) _isuppertri(w::WrapperChar) = w.isuppertri _isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation _getuplo(x) = _isuppertri(x) ? (:U) : (:L) -wrapper_char(::AbstractArray) = WrapperChar('N') -wrapper_char(::Adjoint) = WrapperChar('C') -wrapper_char(::Adjoint{<:Real}) = WrapperChar('T') -wrapper_char(::Transpose) = WrapperChar('T') +wrapper_char(::AbstractArray) = 'N' +wrapper_char(::Adjoint) = 'C' +wrapper_char(::Adjoint{<:Real}) = 'T' +wrapper_char(::Transpose) = 'T' wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U') wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U') wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U') diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 76b02319c6533..bb4b2e78cc3fc 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -373,29 +373,31 @@ all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars) Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # The check is only on the wrapper type, so we may extract that from a WrapperChar - if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) - if tA == 'T' && tB == 'N' && A === B - return syrk_wrapper!(C, oftype(tA, 'T'), A, _add) - elseif tA == 'N' && tB == 'T' && A === B - return syrk_wrapper!(C, oftype(tA, 'N'), A, _add) - elseif tA == 'C' && tB == 'N' && A === B - return herk_wrapper!(C, oftype(tA, 'C'), A, _add) - elseif tA == 'N' && tB == 'C' && A === B - return herk_wrapper!(C, oftype(tA, 'N'), A, _add) + # We convert the chars to uppercase to potentially unwrap a WraperChar, + # and extract the char corresponding to the wrapper type + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + if all_in(('N', 'T', 'C'), map(uppercase, (tA_uc, tB_uc))) + if tA_uc == 'T' && tB_uc == 'N' && A === B + return syrk_wrapper!(C, 'T', A, _add) + elseif tA_uc == 'N' && tB_uc == 'T' && A === B + return syrk_wrapper!(C, 'N', A, _add) + elseif tA_uc == 'C' && tB_uc == 'N' && A === B + return herk_wrapper!(C, 'C', A, _add) + elseif tA_uc == 'N' && tB_uc == 'C' && A === B + return herk_wrapper!(C, 'N', A, _add) else return gemm_wrapper!(C, tA, tB, A, B, _add) end end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - if uppercase(tA) == 'S' && tB == 'N' + if tA_uc == 'S' && tB_uc == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) - elseif uppercase(tB) == 'S' && tA == 'N' + elseif tB_uc == 'S' && tA_uc == 'N' return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) - elseif uppercase(tA) == 'H' && tB == 'N' + elseif tA_uc == 'H' && tB_uc == 'N' return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) - elseif uppercase(tB) == 'H' && tA == 'N' + elseif tB_uc == 'H' && tA_uc == 'N' return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end @@ -406,7 +408,8 @@ end Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # The check is only on the wrapper type, so we may extract that from a WrapperChar + # We convert the chars to uppercase to potentially unwrap a WraperChar, + # and extract the char corresponding to the wrapper type if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B, _add) else @@ -519,15 +522,17 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs end end -function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, +# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} nC = checksquare(C) - if tA == 'T' + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + if tA_uc == 'T' (nA, mA) = size(A,1), size(A,2) - tAt = 'N' + tAt = oftype(tA, 'N') else (mA, nA) = size(A,1), size(A,2) - tAt = 'T' + tAt = oftype(tA, 'T') end if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) @@ -557,10 +562,12 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat return gemm_wrapper!(C, tA, tAt, A, A, _add) end -function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, +# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, _add = MulAddMul()) where {T<:BlasReal} nC = checksquare(C) - if tA == 'C' + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + if tA_uc == 'C' (nA, mA) = size(A,1), size(A,2) tAt = oftype(tA, 'N') else @@ -605,8 +612,8 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # The check is only on the wrapper type, so we may extract that from a WrapperChar - # map uppercase to potentially convert a WrapperChar to a Char + # We convert the chars to uppercase to potentially unwrap a WraperChar, + # and extract the char corresponding to the wrapper type if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B) else @@ -729,10 +736,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). """ function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - tM_ = uppercase(tM) # potentially convert a WrapperChar to a Char - if tM_ == 'N' + tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char + if tM_uc == 'N' copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src) - elseif tM_ == 'T' + elseif tM_uc == 'T' copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src) else copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src) @@ -761,12 +768,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). """ function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - tM_ = uppercase(tM) # potentially convert a WrapperChar to a Char - if tM_ == 'N' + tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char + if tM_uc == 'N' copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src) else copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src) - tM_ == 'C' && conj!(@view B[ir_dest, jr_dest]) + tM_uc == 'C' && conj!(@view B[ir_dest, jr_dest]) end B end From 28d4fd85afdbd2d41547563ce9c65594e32bc4e4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 1 May 2024 12:44:35 +0530 Subject: [PATCH 10/15] Fix typo --- stdlib/LinearAlgebra/src/matmul.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index bb4b2e78cc3fc..a4df8b32dcd7c 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -373,7 +373,7 @@ all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars) Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # We convert the chars to uppercase to potentially unwrap a WraperChar, + # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type tA_uc, tB_uc = uppercase(tA), uppercase(tB) if all_in(('N', 'T', 'C'), map(uppercase, (tA_uc, tB_uc))) @@ -408,7 +408,7 @@ end Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # We convert the chars to uppercase to potentially unwrap a WraperChar, + # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B, _add) @@ -612,7 +612,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop - # We convert the chars to uppercase to potentially unwrap a WraperChar, + # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) gemm_wrapper!(C, tA, tB, A, B) From b24cc1ea5bbc1ce2b67dc3369a9a19023b48cad1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 1 May 2024 12:47:48 +0530 Subject: [PATCH 11/15] Update comment --- stdlib/LinearAlgebra/src/matmul.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index a4df8b32dcd7c..10dbb70aa675c 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -523,6 +523,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs end # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +# to be concretely inferred Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} nC = checksquare(C) @@ -563,6 +564,7 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst end # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +# to be concretely inferred Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, _add = MulAddMul()) where {T<:BlasReal} nC = checksquare(C) From 9e66d1a03dbe7743b8ecf1c7aa65229fd9c2e85c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 1 May 2024 22:25:27 +0530 Subject: [PATCH 12/15] Remove some unnecessary type conversions --- stdlib/LinearAlgebra/src/matmul.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 10dbb70aa675c..4edb915cc7e8b 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -464,7 +464,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar if _in(tA_uc, ('S', 'H')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -516,7 +516,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs elseif _in(tA_uc, ('S', 'H')) # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, oftype(tA, 'N'), wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end @@ -530,10 +530,10 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'T' (nA, mA) = size(A,1), size(A,2) - tAt = oftype(tA, 'N') + tAt = 'N' else (mA, nA) = size(A,1), size(A,2) - tAt = oftype(tA, 'T') + tAt = 'T' end if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) @@ -571,10 +571,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'C' (nA, mA) = size(A,1), size(A,2) - tAt = oftype(tA, 'N') + tAt = 'N' else (mA, nA) = size(A,1), size(A,2) - tAt = oftype(tA, 'C') + tAt = 'C' end if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) From 541a76d6f24e943c98869a73559eedaf11bac9d6 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 3 May 2024 00:43:09 +0530 Subject: [PATCH 13/15] Use all(map(...)) instead of all_in --- stdlib/LinearAlgebra/src/matmul.jl | 34 ++++++++++++------------------ 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 4edb915cc7e8b..c8b31aae1e26b 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -360,23 +360,15 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) -# unroll the in(a, b) computation to enable constant propagation -# This is a 2-valued in implementation that doesn't account for missing values -_in(t::AbstractChar, ::Tuple{}) = false -function _in(t::AbstractChar, chars::Tuple{Vararg{AbstractChar}}) - return t == first(chars) || _in(t, Base.tail(chars)) -end -all_in(chars, (tA, tB)) = _in(tA, chars) && _in(tB, chars) - # THE one big BLAS dispatch # aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} - # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type tA_uc, tB_uc = uppercase(tA), uppercase(tB) - if all_in(('N', 'T', 'C'), map(uppercase, (tA_uc, tB_uc))) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) if tA_uc == 'T' && tB_uc == 'N' && A === B return syrk_wrapper!(C, 'T', A, _add) elseif tA_uc == 'N' && tB_uc == 'T' && A === B @@ -407,10 +399,11 @@ end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} - # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type - if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) gemm_wrapper!(C, tA, tB, A, B, _add) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @@ -453,7 +446,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && !iszero(stride(x, 1)) && # We only check input's stride here. - if _in(tA_uc, ('N', 'T', 'C')) + if tA_uc in ('N', 'T', 'C') return BLAS.gemv!(tA, alpha, A, x, beta, y) elseif tA_uc == 'S' return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) @@ -461,7 +454,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) end end - if _in(tA_uc, ('S', 'H')) + if tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) @@ -488,7 +481,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) + Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -507,13 +500,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) && _in(tA_uc, ('N', 'T', 'C')) + !iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C') xfl = reinterpret(reshape, T, x) # Use reshape here. yfl = reinterpret(reshape, T, y) BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y - elseif _in(tA_uc, ('S', 'H')) + elseif tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) @@ -613,10 +606,11 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) - # if all(in(('N', 'T', 'C')), (tA, tB)), but we unroll the implementation to enable constprop # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type - if all_in(('N', 'T', 'C'), map(uppercase, (tA, tB))) + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) gemm_wrapper!(C, tA, tB, A, B) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @@ -789,7 +783,7 @@ end @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char - Anew, ta = _in(tA_uc, ('S', 'H')) ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) + Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(C, ta, Anew, B, _add) end From 31692d5a59ac6eace977fcdff0d7f589e1d34e3e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 3 May 2024 13:02:07 +0530 Subject: [PATCH 14/15] WrapperChar Constructor from Char --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index edec97e3d0206..8b3cbc9501e41 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -533,7 +533,8 @@ function Base.Char(w::WrapperChar) end end Base.codepoint(w::WrapperChar) = codepoint(Char(w)) -WrapperChar(n::UInt32) = WrapperChar(Char(n), true) +WrapperChar(n::UInt32) = WrapperChar(Char(n)) +WrapperChar(c::Char) = WrapperChar(c, true) # this constructor helps with assuming :nothrow # We extract the wrapperchar so that the result may be constant-propagated # This doesn't return a value of the same type on purpose Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar) From 311382b92304f77e76790c1556236514e7418124 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 6 May 2024 20:11:10 +0530 Subject: [PATCH 15/15] Preserve case in WrapperChar constructor --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 10 +++++----- stdlib/LinearAlgebra/src/matmul.jl | 4 ++-- stdlib/LinearAlgebra/test/matmul.jl | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 8b3cbc9501e41..b8412fc361d3f 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -526,7 +526,7 @@ struct WrapperChar <: AbstractChar end function Base.Char(w::WrapperChar) T = w.wrapperchar - if T ∉ ('S', 'H') + if T ∈ ('N', 'T', 'C') # known cases where isuppertri is true T else _isuppertri(w) ? uppercase(T) : lowercase(T) @@ -534,14 +534,14 @@ function Base.Char(w::WrapperChar) end Base.codepoint(w::WrapperChar) = codepoint(Char(w)) WrapperChar(n::UInt32) = WrapperChar(Char(n)) -WrapperChar(c::Char) = WrapperChar(c, true) # this constructor helps with assuming :nothrow +WrapperChar(c::Char) = WrapperChar(c, isuppercase(c)) # We extract the wrapperchar so that the result may be constant-propagated # This doesn't return a value of the same type on purpose Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar) Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar) _isuppertri(w::WrapperChar) = w.isuppertri _isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation -_getuplo(x) = _isuppertri(x) ? (:U) : (:L) +_uplosym(x) = _isuppertri(x) ? (:U) : (:L) wrapper_char(::AbstractArray) = 'N' wrapper_char(::Adjoint) = 'C' @@ -563,9 +563,9 @@ Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) elseif tA_uc == 'C' adjoint(A) elseif tA_uc == 'H' - Hermitian(A, _getuplo(tA) #= unwrap a WrapperChar =#) + Hermitian(A, _uplosym(tA)) elseif tA_uc == 'S' - Symmetric(A, _getuplo(tA) #= unwrap a WrapperChar =#) + Symmetric(A, _uplosym(tA)) end return B::AbstractVecOrMat end diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index c8b31aae1e26b..9c74addd6b69c 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -385,11 +385,11 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, if alpha isa Union{Bool,T} && beta isa Union{Bool,T} if tA_uc == 'S' && tB_uc == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) - elseif tB_uc == 'S' && tA_uc == 'N' + elseif tA_uc == 'N' && tB_uc == 'S' return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) elseif tA_uc == 'H' && tB_uc == 'N' return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) - elseif tB_uc == 'H' && tA_uc == 'N' + elseif tA_uc == 'N' && tB_uc == 'H' return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 3925151252913..db61fbe0ab45a 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -39,6 +39,21 @@ mul_wrappers = [ for S in (Symmetric(M), Hermitian(M)) @test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === S end + + @testset "WrapperChar" begin + @test LinearAlgebra.WrapperChar('c') == 'c' + @test LinearAlgebra.WrapperChar('C') == 'C' + @testset "constant propagation in uppercase/lowercase" begin + v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('C'))))() + @test v isa Val{'C'} + v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('s'))))() + @test v isa Val{'S'} + v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('C'))))() + @test v isa Val{'c'} + v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('s'))))() + @test v isa Val{'s'} + end + end end @testset "matrices with zero dimensions" begin