From af0d7d4ae056de8f629b926cd607e4e39e7820b8 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sun, 5 Nov 2023 22:37:23 +0100 Subject: [PATCH 01/16] Reduce compile time for generic matmatmul --- stdlib/LinearAlgebra/src/matmul.jl | 229 +++++----------------------- stdlib/LinearAlgebra/test/matmul.jl | 4 +- 2 files changed, 38 insertions(+), 195 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 018ad20e538c8..2c36e1cc62bb4 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -337,7 +337,7 @@ julia> lmul!(F.Q, B) lmul!(A, B) # THE one big BLAS dispatch -@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, +Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} if all(in(('N', 'T', 'C')), (tA, tB)) if tA == 'T' && tB == 'N' && A === B @@ -364,16 +364,16 @@ lmul!(A, B) return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end - return _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. -@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, +Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} if all(in(('N', 'T', 'C')), (tA, tB)) gemm_wrapper!(C, tA, tB, A, B, _add) else - _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end end @@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, if all(in(('N', 'T', 'C')), (tA, tB)) gemm_wrapper!(C, tA, tB, A, B) else - _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end end -function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) @@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - _generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end -function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasReal} mA, nA = lapack_size(tA, A) @@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - _generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end # blas.jl defines matmul for floats; other integer and mixed precision @@ -764,197 +764,40 @@ end const tilebufsize = 10800 # Approximately 32k/3 -function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) - mA, nA = lapack_size(tA, A) - mB, nB = lapack_size(tB, B) - mC, nC = size(C) - - if iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) - end - if mA == nA == mB == nB == mC == nC == 2 - return matmul2x2!(C, tA, tB, A, B, _add) - end - if mA == nA == mB == nB == mC == nC == 3 - return matmul3x3!(C, tA, tB, A, B, _add) - end - A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) - _generic_matmatmul!(C, tA, tB, A, B, _add) -end +Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) = + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) -function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, +@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} - @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') - require_one_based_indexing(C, A, B) - - mA, nA = lapack_size(tA, A) - mB, nB = lapack_size(tB, B) - if mB != nA - throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)")) + AxM = axes(A, 1) + AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` + BxK = axes(B, 1) + BxN = axes(B, 2) + CxM = axes(C, 1) + CxN = axes(C, 2) + if AxM != CxM + throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)")) end - if size(C,1) != mA || size(C,2) != nB - throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)")) + if AxK != BxK + throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)")) + end + if BxN != CxN + throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - if iszero(_add.alpha) || isempty(A) || isempty(B) return _rmul_or_fill!(C, _add.beta) end - - tile_size = 0 - if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N') - tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1))) - end - @inbounds begin - if tile_size > 0 - sz = (tile_size, tile_size) - Atile = Array{T}(undef, sz) - Btile = Array{S}(undef, sz) - - z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1]) - z = convert(promote_type(typeof(z1), R), z1) - - if mA < tile_size && nA < tile_size && nB < tile_size - copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA) - copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB) - for j = 1:nB - boff = (j-1)*tile_size - for i = 1:mA - aoff = (i-1)*tile_size - s = z - for k = 1:nA - s += Atile[aoff+k] * Btile[boff+k] - end - _modify!(_add, s, C, (i,j)) - end - end - else - Ctile = Array{R}(undef, sz) - for jb = 1:tile_size:nB - jlim = min(jb+tile_size-1,nB) - jlen = jlim-jb+1 - for ib = 1:tile_size:mA - ilim = min(ib+tile_size-1,mA) - ilen = ilim-ib+1 - fill!(Ctile, z) - for kb = 1:tile_size:nA - klim = min(kb+tile_size-1,mB) - klen = klim-kb+1 - copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim) - copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim) - for j=1:jlen - bcoff = (j-1)*tile_size - for i = 1:ilen - aoff = (i-1)*tile_size - s = z - for k = 1:klen - s += Atile[aoff+k] * Btile[bcoff+k] - end - Ctile[bcoff+i] += s - end - end - end - if isone(_add.alpha) && iszero(_add.beta) - copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen) - else - C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim]) - end - end - end - end - else - # Multiplication for non-plain-data uses the naive algorithm - if tA == 'N' - if tB == 'N' - for i = 1:mA, j = 1:nB - z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += A[i, k]*B[k, j] - end - _modify!(_add, Ctmp, C, (i,j)) - end - elseif tB == 'T' - for i = 1:mA, j = 1:nB - z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1])) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += A[i, k] * transpose(B[j, k]) - end - _modify!(_add, Ctmp, C, (i,j)) - end - else - for i = 1:mA, j = 1:nB - z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]') - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += A[i, k]*B[j, k]' - end - _modify!(_add, Ctmp, C, (i,j)) - end - end - elseif tA == 'T' - if tB == 'N' - for i = 1:mA, j = 1:nB - z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j]) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += transpose(A[k, i]) * B[k, j] - end - _modify!(_add, Ctmp, C, (i,j)) - end - elseif tB == 'T' - for i = 1:mA, j = 1:nB - z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1])) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += transpose(A[k, i]) * transpose(B[j, k]) - end - _modify!(_add, Ctmp, C, (i,j)) - end - else - for i = 1:mA, j = 1:nB - z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]') - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += transpose(A[k, i]) * adjoint(B[j, k]) - end - _modify!(_add, Ctmp, C, (i,j)) - end - end - else - if tB == 'N' - for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j]) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += A[k, i]'B[k, j] - end - _modify!(_add, Ctmp, C, (i,j)) - end - elseif tB == 'T' - for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1])) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += adjoint(A[k, i]) * transpose(B[j, k]) - end - _modify!(_add, Ctmp, C, (i,j)) - end - else - for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]') - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k = 1:nA - Ctmp += A[k, i]'B[j, k]' - end - _modify!(_add, Ctmp, C, (i,j)) - end - end + a1 = first(AxK) + b1 = first(BxK) + @inbounds for i in AxM, j in BxN + z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k in AxK + Ctmp += A[i, k]*B[k, j] end + _modify!(_add, Ctmp, C, (i,j)) end - end # @inbounds - C + return C end @@ -963,7 +806,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T, matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B) end -function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, +Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) if !(size(A) == size(B) == size(C) == (2,2)) @@ -1030,7 +873,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T, matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B) end -function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, +Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) if !(size(A) == size(B) == size(C) == (3,3)) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 86606654e911a..c982e2a6be835 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -227,7 +227,7 @@ end @test C == AB mul!(C, A, B, 2, -1) @test C == AB - LinearAlgebra._generic_matmatmul!(C, 'N', 'N', A, B, LinearAlgebra.MulAddMul(2, -1)) + LinearAlgebra.generic_matmatmul!(C, 'N', 'N', A, B, LinearAlgebra.MulAddMul(2, -1)) @test C == AB end @@ -871,7 +871,7 @@ end # Just in case dispatching on the surface API `mul!` is changed in the future, # let's test the function where the tiled multiplication is defined. fill!(C, 0) - LinearAlgebra._generic_matmatmul!(C, 'N', 'N', A, B, LinearAlgebra.MulAddMul(-1, 0)) + LinearAlgebra.generic_matmatmul!(C, 'N', 'N', A, B, LinearAlgebra.MulAddMul(-1, 0)) @test D ≈ C end From d50fbc53a37c2775b15c8bb012d5443c1cb7d459 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 6 Nov 2023 08:33:14 +0100 Subject: [PATCH 02/16] Update stdlib/LinearAlgebra/src/matmul.jl Co-authored-by: Chris Elrod --- stdlib/LinearAlgebra/src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 2c36e1cc62bb4..558b529297c98 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -793,7 +793,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k in AxK - Ctmp += A[i, k]*B[k, j] + Ctmp = muladd(A[i, k], B[k, j], Ctmp) end _modify!(_add, Ctmp, C, (i,j)) end From ef4b837ef1587ca4a4843af0b992abba9d957afe Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 6 Nov 2023 19:46:24 +0100 Subject: [PATCH 03/16] reduce repeated transposition --- stdlib/LinearAlgebra/src/adjtrans.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index c8b02f7a10a76..a6fb924ace60f 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -281,6 +281,8 @@ adjoint(A::Adjoint) = A.parent transpose(A::Transpose) = A.parent adjoint(A::Transpose{<:Real}) = A.parent transpose(A::Adjoint{<:Real}) = A.parent +adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent) +transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent) # printing function Base.showarg(io::IO, v::Adjoint, toplevel) From 19d78b636b664a66c90e3850e232fe701bd3a2fd Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 6 Nov 2023 20:37:21 +0100 Subject: [PATCH 04/16] fix some tests --- stdlib/LinearAlgebra/src/adjtrans.jl | 3 +++ test/testhelpers/Furlongs.jl | 1 + 2 files changed, 4 insertions(+) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index a6fb924ace60f..9d3441db77253 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -283,6 +283,9 @@ adjoint(A::Transpose{<:Real}) = A.parent transpose(A::Adjoint{<:Real}) = A.parent adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent) transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent) +# disambiguation +adjoint(A::Transpose{<:Real,<:Adjoint}) = transpose(A.parent.parent) +transpose(A::Adjoint{<:Real,<:Transpose}) = adjoint(A.parent.parent) # printing function Base.showarg(io::IO, v::Adjoint, toplevel) diff --git a/test/testhelpers/Furlongs.jl b/test/testhelpers/Furlongs.jl index f63b5460c7c16..6d52260bb20fd 100644 --- a/test/testhelpers/Furlongs.jl +++ b/test/testhelpers/Furlongs.jl @@ -99,5 +99,6 @@ for op in (:rem, :mod) end end Base.sqrt(x::Furlong) = _div(sqrt(x.val), x, Val(2)) +Base.muladd(x::Furlong, y::Furlong, z::Furlong) = x*y + z end From c8e9a0bb1caffd8ffcb00cf40c38041052e345ed Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 7 Nov 2023 10:36:27 +0100 Subject: [PATCH 05/16] fix tests --- stdlib/LinearAlgebra/test/matmul.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index c982e2a6be835..c3070e0c75330 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -692,8 +692,8 @@ end function test_mul(C, A, B) mul!(C, A, B) - @test Array(A) * Array(B) ≈ C - @test A * B ≈ C + @test Array{Float64}(A) * Array{Float64}(B) ≈ C + @test Float64.(A) * Float64.(B) ≈ C # This is similar to how `isapprox` choose `rtol` (when `atol=0`) # but consider all number types involved: @@ -706,8 +706,8 @@ function test_mul(C, A, B) βArrayC = β * Array(C) βC = β * C mul!(C, A, B, α, β) - @test α * Array(A) * Array(B) .+ βArrayC ≈ C rtol = rtol - @test α * A * B .+ βC ≈ C rtol = rtol + @test α * Float64.(Array(A)) * Float64.(Array(B)) .+ βArrayC ≈ C rtol = rtol + @test α * Float64.(A) * Float64.(B) .+ βC ≈ C rtol = rtol end @testset "mul! vs * for special types" begin From 5d5592abc2b39f2a2f2365b2d59243b4a702677a Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 7 Nov 2023 10:50:32 +0100 Subject: [PATCH 06/16] minor detail --- stdlib/LinearAlgebra/src/adjtrans.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 9d3441db77253..346ed29785a36 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -285,7 +285,7 @@ adjoint(A::Transpose{<:Any,<:Adjoint}) = transpose(A.parent.parent) transpose(A::Adjoint{<:Any,<:Transpose}) = adjoint(A.parent.parent) # disambiguation adjoint(A::Transpose{<:Real,<:Adjoint}) = transpose(A.parent.parent) -transpose(A::Adjoint{<:Real,<:Transpose}) = adjoint(A.parent.parent) +transpose(A::Adjoint{<:Real,<:Transpose}) = A.parent # printing function Base.showarg(io::IO, v::Adjoint, toplevel) From 416fa1cba726c104070d0828ee93ec0c4cd1c2b1 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 7 Nov 2023 12:27:34 +0100 Subject: [PATCH 07/16] avoid overflow in integer mul --- stdlib/LinearAlgebra/test/matmul.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index c3070e0c75330..af3709304bc87 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -690,28 +690,28 @@ Transpose(x::RootInt) = x @test A * a == [56] end -function test_mul(C, A, B) +function test_mul(C, A, B, S) mul!(C, A, B) - @test Array{Float64}(A) * Array{Float64}(B) ≈ C - @test Float64.(A) * Float64.(B) ≈ C + @test Array(A) * Array(B) ≈ C + @test A * B ≈ C # This is similar to how `isapprox` choose `rtol` (when `atol=0`) # but consider all number types involved: rtol = max(rtoldefault.(real.(eltype.((C, A, B))))...) - rand!(C) + rand!(C, S) T = promote_type(eltype.((A, B))...) - α = rand(T) - β = rand(T) + α = T <: AbstractFloat ? rand(T) : rand(T(-10):T(10)) + β = T <: AbstractFloat ? rand(T) : rand(T(-10):T(10)) βArrayC = β * Array(C) βC = β * C mul!(C, A, B, α, β) - @test α * Float64.(Array(A)) * Float64.(Array(B)) .+ βArrayC ≈ C rtol = rtol - @test α * Float64.(A) * Float64.(B) .+ βC ≈ C rtol = rtol + @test α * Array(A) * Array(B) .+ βArrayC ≈ C rtol = rtol + @test α * A * B .+ βC ≈ C rtol = rtol end @testset "mul! vs * for special types" begin - eltypes = [Float32, Float64, Int64] + eltypes = [Float32, Float64, Int64(-100):Int64(100)] for k in [3, 4, 10] T = rand(eltypes) bi1 = Bidiagonal(rand(T, k), rand(T, k - 1), rand([:U, :L])) @@ -724,26 +724,26 @@ end specialmatrices = (bi1, bi2, tri1, tri2, stri1, stri2) for A in specialmatrices B = specialmatrices[rand(1:length(specialmatrices))] - test_mul(C, A, B) + test_mul(C, A, B, T) end for S in specialmatrices l = rand(1:6) B = randn(k, l) C = randn(k, l) - test_mul(C, S, B) + test_mul(C, S, B, T) A = randn(l, k) C = randn(l, k) - test_mul(C, A, S) + test_mul(C, A, S, T) end end for T in eltypes A = Bidiagonal(rand(T, 2), rand(T, 1), rand([:U, :L])) B = Bidiagonal(rand(T, 2), rand(T, 1), rand([:U, :L])) C = randn(2, 2) - test_mul(C, A, B) + test_mul(C, A, B, T) B = randn(2, 9) C = randn(2, 9) - test_mul(C, A, B) + test_mul(C, A, B, T) end let tri44 = Tridiagonal(randn(3), randn(4), randn(3)) From df79cdb6c8d56033ae7bed7420ab06e771d24268 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Nov 2023 10:54:30 +0100 Subject: [PATCH 08/16] branch over sizeof target eltype --- stdlib/LinearAlgebra/src/matmul.jl | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 558b529297c98..3a799149f67dc 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -784,18 +784,26 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - if iszero(_add.alpha) || isempty(A) || isempty(B) - return _rmul_or_fill!(C, _add.beta) - end - a1 = first(AxK) - b1 = first(BxK) - @inbounds for i in AxM, j in BxN - z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) - Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k in AxK - Ctmp = muladd(A[i, k], B[k, j], Ctmp) + if sizeof(R) ≤ 16 + _rmul_or_fill!(C, _add.beta) + (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C + @inbounds for n in BxN, k in BxK, m in AxM + C[m,n] = muladd(A[m,k], B[k,n]*_add.alpha, C[m,n]) + end + else + if iszero(_add.alpha) || isempty(A) || isempty(B) + return _rmul_or_fill!(C, _add.beta) + end + a1 = first(AxK) + b1 = first(BxK) + @inbounds for i in AxM, j in BxN + z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k in AxK + Ctmp = muladd(A[i, k], B[k, j], Ctmp) + end + _modify!(_add, Ctmp, C, (i,j)) end - _modify!(_add, Ctmp, C, (i,j)) end return C end From c6ec74a0693661911c84f9c093184a1a924c4c23 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Nov 2023 12:01:55 +0100 Subject: [PATCH 09/16] hoist out one factor --- stdlib/LinearAlgebra/src/matmul.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 3a799149f67dc..0d30e45b4661b 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -787,8 +787,11 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if sizeof(R) ≤ 16 _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C - @inbounds for n in BxN, k in BxK, m in AxM - C[m,n] = muladd(A[m,k], B[k,n]*_add.alpha, C[m,n]) + @inbounds for n in BxN, k in BxK + Balpha = B[k,n]*_add.alpha + for m in AxM + C[m,n] = muladd(A[m,k], Balpha, C[m,n]) + end end else if iszero(_add.alpha) || isempty(A) || isempty(B) From 9467795f6369b5b45588d6d4c85e88cba4f1f944 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Nov 2023 12:58:24 +0100 Subject: [PATCH 10/16] optimizations --- stdlib/LinearAlgebra/src/matmul.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 0d30e45b4661b..7cc15a5f31451 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -784,12 +784,12 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - if sizeof(R) ≤ 16 + if sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK Balpha = B[k,n]*_add.alpha - for m in AxM + @simd for m in AxM C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end @@ -802,7 +802,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A @inbounds for i in AxM, j in BxN z2 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j]) Ctmp = convert(promote_type(R, typeof(z2)), z2) - for k in AxK + @simd for k in AxK Ctmp = muladd(A[i, k], B[k, j], Ctmp) end _modify!(_add, Ctmp, C, (i,j)) From 5a30f30aba33fd5d6982bc7b1566504209b927b5 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Nov 2023 21:00:27 +0100 Subject: [PATCH 11/16] optimize double adjortrans case --- stdlib/LinearAlgebra/src/matmul.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 7cc15a5f31451..4a462e66f8cbb 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -793,6 +793,22 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end + elseif sizeof(R) ≤ 16 && (A isa Adjoint && B isa Adjoint) || (A isa Transpose & B isa Transpose) + _rmul_or_fill!(C, _add.beta) + (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C + t = wrapperop(A) + _rmul_or_fill!(C, _add.beta) + (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C + pB = parent(B) + pA = parent(A) + tmp = similar(C, CxN) + ci = first(CxM) + ta = t(_add.alpha) + for i in AxM + mul!(tmp, pB, view(pA, :, i)) + C[ci,:] .+= t.(ta .* tmp) + ci += 1 + end else if iszero(_add.alpha) || isempty(A) || isempty(B) return _rmul_or_fill!(C, _add.beta) From c100f20661dcd63c7e0ba48427c1a6265788a076 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Nov 2023 22:33:14 +0100 Subject: [PATCH 12/16] fix typo --- stdlib/LinearAlgebra/src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 4a462e66f8cbb..26b63787910ef 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -793,7 +793,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end - elseif sizeof(R) ≤ 16 && (A isa Adjoint && B isa Adjoint) || (A isa Transpose & B isa Transpose) + elseif sizeof(R) ≤ 16 && (A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C t = wrapperop(A) From 59f8b7084fa23f2942a39dd4fbae93303e019769 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 9 Nov 2023 14:45:39 +0100 Subject: [PATCH 13/16] fix typos --- stdlib/LinearAlgebra/src/matmul.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 26b63787910ef..67275370b2a4d 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -793,12 +793,10 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end - elseif sizeof(R) ≤ 16 && (A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose) + elseif sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose)) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C t = wrapperop(A) - _rmul_or_fill!(C, _add.beta) - (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C pB = parent(B) pA = parent(A) tmp = similar(C, CxN) From 4c09e5203806f711de53cb984bcc0ea78584fa3b Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 9 Nov 2023 23:31:13 +0100 Subject: [PATCH 14/16] check isbitstype, fix some tests --- stdlib/LinearAlgebra/src/matmul.jl | 4 ++-- stdlib/LinearAlgebra/test/matmul.jl | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 67275370b2a4d..0ca0527785d61 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -784,7 +784,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end - if sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) + if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK @@ -793,7 +793,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end end - elseif sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose)) + elseif isbitstype(R) && sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose)) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C t = wrapperop(A) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index af3709304bc87..a04c435a42edc 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -667,12 +667,9 @@ end import Base: *, adjoint, transpose import LinearAlgebra: Adjoint, Transpose (*)(x::RootInt, y::RootInt) = x.i * y.i +(*)(x::RootInt, y::Integer) = x.i * y adjoint(x::RootInt) = x transpose(x::RootInt) = x -Adjoint(x::RootInt) = x -Transpose(x::RootInt) = x -# TODO once Adjoint/Transpose constructors call adjoint/transpose recursively -# rather than Adjoint/Transpose, the additional definitions should become unnecessary @test Base.promote_op(*, RootInt, RootInt) === Int From 3c75cd5b6926a767a3b6d5553c504d3959c66418 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 10 Nov 2023 22:27:46 +0100 Subject: [PATCH 15/16] fix lazy conj case --- stdlib/LinearAlgebra/src/adjtrans.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 346ed29785a36..3df9e3e151f25 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -400,11 +400,16 @@ map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)... map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...)) quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below +quasiparentc(x) = parent(parent(x)); quasiparentc(x::Number) = conj(x) # to handle numbers in the defs below broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...)) broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) # Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...)) Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) +Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Transpose{<:Any,<:AdjointAbsVec}}...) = + transpose(adjoint(broadcast((xs...) -> adjoint(transpose(f(conj.(xs)...))), quasiparentc.(tvs)...))) +Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,Adjoint{<:Any,<:TransposeAbsVec}}...) = + adjoint(transpose(broadcast((xs...) -> transpose(adjoint(f(conj.(xs)...))), quasiparentc.(tvs)...))) # TODO unify and allow mixed combinations with a broadcast style From 1a60c0f6df5d97df0a4961d876379df8b8d73d33 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 11 Nov 2023 08:25:20 +0100 Subject: [PATCH 16/16] simplify test slightly --- stdlib/LinearAlgebra/test/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index a04c435a42edc..30cc74694b3f4 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -255,7 +255,7 @@ end @testset "mixed Blas-non-Blas matmul" begin AA = rand(-10:10, 6, 6) - BB = rand(Float64, 6, 6) + BB = ones(Float64, 6, 6) CC = zeros(Float64, 6, 6) for A in (copy(AA), view(AA, 1:6, 1:6)), B in (copy(BB), view(BB, 1:6, 1:6)), C in (copy(CC), view(CC, 1:6, 1:6)) @test LinearAlgebra.mul!(C, A, B) == A * B