From d37fed36e09ec7b9b22c967e52221058b4c8b9a4 Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 15:25:17 +0100 Subject: [PATCH 01/20] add generic syrk/herk --- src/matmul.jl | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/matmul.jl | 16 ++++++++++ 2 files changed, 99 insertions(+) diff --git a/src/matmul.jl b/src/matmul.jl index 6eb9e9df..a4da0aeb 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -570,6 +570,89 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} + mA, nA = lapack_size(tA, A) + mB, nB = lapack_size(tB, B) + if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) + matmul_size_check(size(C), (mA, nA), (mB, nB)) + return _rmul_or_fill!(C, β) + end + matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C + + if A === B + tA_uc = uppercase(tA) # potentially strip a WrapperChar + aat = (tA_uc == 'N') + blasfn = _valtypeparam(val) + if blasfn == BlasFlag.SYRK && T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) + return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') + elseif blasfn == BlasFlag.HERK && (iszero(β) || ishermitian(C)) + return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) + end + end + + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) +end + +function generic_syrk!(C::StridedMatrix{T}, A::StridedMatrix{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} + nC = checksquare(C) + m, n = size(A) + mA = aat ? 
m : n + if nC != mA + throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) + end + + if iszero(β) + fill!(C, T(0)) + elseif !isone(β) + C .*= β + end + @inbounds if !conjugate + if aat + for k ∈ 1:n, j ∈ 1:m + αA_jk = α * A[j, k] + for i ∈ 1:j + C[i, j] += A[i, k] * αA_jk + end + end + else + for j ∈ 1:n, i ∈ 1:j + temp = A[1, i] * A[1, j] + for k ∈ 2:m + temp += A[k, i] * A[k, j] + end + C[i, j] += α * temp + end + end + else + if aat + for k ∈ 1:n, j ∈ 1:m + αA_jk_bar = α * conj(A[j, k]) + for i ∈ 1:j-1 + C[i, j] += A[i, k] * αA_jk_bar + end + C[j, j] += α * abs2(A[j, k]) + end + else + for j ∈ 1:n + for i ∈ 1:j-1 + temp = conj(A[1, i]) * A[1, j] + for k ∈ 2:m + temp += conj(A[k, i]) * A[k, j] + end + C[i, j] += α * temp + end + temp = abs2(A[1, j]) + for k ∈ 2:m + temp += abs2(A[k, j]) + end + C[j, j] += α * temp + end + end + end + return C +end + # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = diff --git a/test/matmul.jl b/test/matmul.jl index 86c75ae5..17846ac9 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -537,6 +537,22 @@ end @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) end +@testset "generic syrk & herk" + for T ∈ (BigFloat, Complex{BigFloat}) + a = randn(T, 3, 4) + csmall = similar(a, 3, 3) + cbig = similar(a, 4, 4) + _generic_matmatmul!(csmall, a, a', true, false) + @test csmall ≈ a * a' + _generic_matmatmul!(csmall, a, transpose(a), true, false) + @test csmall ≈ a * transpose(a) + _generic_matmatmul!(cbig, a', a, true, false) + @test cbig ≈ a' * a + _generic_matmatmul!(cbig, transpose(a), a, true, false) + @test cbig ≈ tranpose(a) * a + end +end + @testset "matmul for types w/o sizeof (issue #1282)" begin AA = fill(complex(1, 1), 10, 10) for A in (copy(AA), view(AA, 1:10, 1:10)) From 89105e42f08c5c79555f0602013d28c8c6d9563c Mon 
Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 16:29:28 +0100 Subject: [PATCH 02/20] add vector --- src/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index a4da0aeb..e9b0d7f5 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -594,9 +594,9 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) end -function generic_syrk!(C::StridedMatrix{T}, A::StridedMatrix{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} +function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} nC = checksquare(C) - m, n = size(A) + m, n = size(A, 1), size(A, 2) mA = aat ? m : n if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) From f630431a38934cc84a16bbcabcbe28ad579c3953 Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 17:08:21 +0100 Subject: [PATCH 03/20] typo --- test/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/matmul.jl b/test/matmul.jl index 17846ac9..3811a299 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -537,7 +537,7 @@ end @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) end -@testset "generic syrk & herk" +@testset "generic syrk & herk" begin for T ∈ (BigFloat, Complex{BigFloat}) a = randn(T, 3, 4) csmall = similar(a, 3, 3) @@ -549,7 +549,7 @@ end _generic_matmatmul!(cbig, a', a, true, false) @test cbig ≈ a' * a _generic_matmatmul!(cbig, transpose(a), a, true, false) - @test cbig ≈ tranpose(a) * a + @test cbig ≈ transpose(a) * a end end From 68a327ebb557b9576339a2f07552b2d87a0b6ffb Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 17:20:44 +0100 Subject: [PATCH 04/20] typo --- test/matmul.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/matmul.jl b/test/matmul.jl 
index 3811a299..7ab4fe9b 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -542,13 +542,13 @@ end a = randn(T, 3, 4) csmall = similar(a, 3, 3) cbig = similar(a, 4, 4) - _generic_matmatmul!(csmall, a, a', true, false) + LinearAlgebra._generic_matmatmul!(csmall, a, a', true, false) @test csmall ≈ a * a' - _generic_matmatmul!(csmall, a, transpose(a), true, false) + LinearAlgebra._generic_matmatmul!(csmall, a, transpose(a), true, false) @test csmall ≈ a * transpose(a) - _generic_matmatmul!(cbig, a', a, true, false) + LinearAlgebra._generic_matmatmul!(cbig, a', a, true, false) @test cbig ≈ a' * a - _generic_matmatmul!(cbig, transpose(a), a, true, false) + LinearAlgebra._generic_matmatmul!(cbig, transpose(a), a, true, false) @test cbig ≈ transpose(a) * a end end From 7e151eafdc6b134edfab146a2aaa19a7146f7243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Thu, 27 Mar 2025 22:21:34 +0100 Subject: [PATCH 05/20] _rmul_or_fill! Co-authored-by: Daniel Karrasch --- src/matmul.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index e9b0d7f5..c9d4a353 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -602,11 +602,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end - if iszero(β) - fill!(C, T(0)) - elseif !isone(β) - C .*= β - end + _rmul_or_fill!(C, β) @inbounds if !conjugate if aat for k ∈ 1:n, j ∈ 1:m From 5eff17ebcb07d4e0d5ece2391a6dbf1203c7a74c Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 23:01:10 +0100 Subject: [PATCH 06/20] require one based indexing --- src/matmul.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matmul.jl b/src/matmul.jl index c9d4a353..35e16d4f 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -595,6 +595,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix end function 
generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} + require_one_based_indexing(C, A) nC = checksquare(C) m, n = size(A, 1), size(A, 2) mA = aat ? m : n From 47bde1e2055591e766789aa4942c7c3d2748960a Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 27 Mar 2025 23:02:11 +0100 Subject: [PATCH 07/20] simple optimization --- src/generic.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/generic.jl b/src/generic.jl index 2b03b249..49195585 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -280,6 +280,7 @@ julia> rmul!([NaN], 0.0) ``` """ function rmul!(X::AbstractArray, s::Number) + isone(s) && return X @simd for I in eachindex(X) @inbounds X[I] *= s end @@ -318,6 +319,7 @@ julia> lmul!(0.0, [Inf]) ``` """ function lmul!(s::Number, X::AbstractArray) + isone(s) && return X @simd for I in eachindex(X) @inbounds X[I] = s*X[I] end From b236f42e1c07cf6395d30c74f9243037f905e47a Mon Sep 17 00:00:00 2001 From: araujoms Date: Fri, 28 Mar 2025 10:22:38 +0100 Subject: [PATCH 08/20] still better than the fallback --- src/matmul.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 35e16d4f..f1d8264e 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -772,7 +772,7 @@ end # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it # to be concretely inferred Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, - alpha::Number, beta::Number) where {T<:BlasFloat} + α::Number, β::Number) where {T<:BlasFloat} nC = checksquare(C) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'T' @@ -788,16 +788,18 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst # BLAS.syrk! only updates symmetric C # alternatively, make non-zero β a show-stopper for BLAS.syrk! 
- if iszero(beta) || issymmetric(C) - α, β = promote(alpha, beta, zero(T)) + if iszero(β) || issymmetric(C) + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') + else + return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U') end end - return gemm_wrapper!(C, tA, tAt, A, A, alpha, beta) + return gemm_wrapper!(C, tA, tAt, A, A, α, β) end # legacy method syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = @@ -830,6 +832,8 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) + else + return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true) end end return gemm_wrapper!(C, tA, tAt, A, A, α, β) From b1e1c192fc91547577392358fc6bc184e6729172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Mon, 31 Mar 2025 11:38:54 +0200 Subject: [PATCH 09/20] multiply from the right Co-authored-by: Daniel Karrasch --- src/matmul.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index f1d8264e..9ea0a412 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -607,7 +607,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo @inbounds if !conjugate if aat for k ∈ 1:n, j ∈ 1:m - αA_jk = α * A[j, k] + αA_jk = A[j, k] * α for i ∈ 1:j C[i, j] += A[i, k] * αA_jk end @@ -618,17 +618,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo for k ∈ 2:m temp += A[k, i] * A[k, j] end - C[i, j] += α * temp + C[i, j] += temp * α end end else if aat for k ∈ 1:n, j ∈ 1:m - αA_jk_bar = α * conj(A[j, k]) + 
αA_jk_bar = conj(A[j, k]) * α for i ∈ 1:j-1 C[i, j] += A[i, k] * αA_jk_bar end - C[j, j] += α * abs2(A[j, k]) + C[j, j] += abs2(A[j, k]) * α end else for j ∈ 1:n @@ -637,13 +637,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo for k ∈ 2:m temp += conj(A[k, i]) * A[k, j] end - C[i, j] += α * temp + C[i, j] += temp * α end temp = abs2(A[1, j]) for k ∈ 2:m temp += abs2(A[k, j]) end - C[j, j] += α * temp + C[j, j] += temp * α end end end From 47e0a03f059098658cdd5f02a5255284f7715b7b Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 31 Mar 2025 12:00:41 +0200 Subject: [PATCH 10/20] no benefit to calling the generic version here --- src/matmul.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index 9ea0a412..6f52148d 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -578,7 +578,6 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix matmul_size_check(size(C), (mA, nA), (mB, nB)) return _rmul_or_fill!(C, β) end - matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C if A === B tA_uc = uppercase(tA) # potentially strip a WrapperChar From 40e97cb89a2a202f383580a5137640ba3bad41b7 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 31 Mar 2025 12:19:49 +0200 Subject: [PATCH 11/20] add test for Quaternions --- test/generic.jl | 25 +++++++++++++++++++++++++ test/matmul.jl | 16 ---------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/test/generic.jl b/test/generic.jl index 6d11ec82..c729e4be 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -123,6 +123,31 @@ end @test_throws DimensionMismatch axpy!(α, x, Vector(1:3), y, Vector(1:5)) end +@testset "generic syrk & herk" begin + for T ∈ (BigFloat, Complex{BigFloat}, Quaternion{Float64}) + α = randn(T) + a = randn(T, 3, 4) + csmall = similar(a, 3, 3) + csmall_fallback = similar(a, 3, 3) + cbig = similar(a, 4, 4) + cbig_fallback = similar(a, 4, 4) + mul!(csmall, a, a', real(α), false) + 
LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', real(α), false) + @test ishermitian(csmall) + @test csmall ≈ csmall_fallback + mul!(cbig, a', a, real(α), false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, real(α), false) + @test ishermitian(cbig) + @test cbig ≈ cbig_fallback + mul!(csmall, a, transpose(a), α, false) + LinearAlgebra._generic_matmatmul!(csmall_fallback, a, transpose(a), α, false) + @test csmall ≈ csmall_fallback + mul!(cbig, transpose(a), a, α, false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, transpose(a), a, α, false) + @test cbig ≈ cbig_fallback + end +end + @test !issymmetric(fill(1,5,3)) @test !ishermitian(fill(1,5,3)) @test (x = fill(1,3); cross(x,x) == zeros(3)) diff --git a/test/matmul.jl b/test/matmul.jl index 7ab4fe9b..86c75ae5 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -537,22 +537,6 @@ end @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) end -@testset "generic syrk & herk" begin - for T ∈ (BigFloat, Complex{BigFloat}) - a = randn(T, 3, 4) - csmall = similar(a, 3, 3) - cbig = similar(a, 4, 4) - LinearAlgebra._generic_matmatmul!(csmall, a, a', true, false) - @test csmall ≈ a * a' - LinearAlgebra._generic_matmatmul!(csmall, a, transpose(a), true, false) - @test csmall ≈ a * transpose(a) - LinearAlgebra._generic_matmatmul!(cbig, a', a, true, false) - @test cbig ≈ a' * a - LinearAlgebra._generic_matmatmul!(cbig, transpose(a), a, true, false) - @test cbig ≈ transpose(a) * a - end -end - @testset "matmul for types w/o sizeof (issue #1282)" begin AA = fill(complex(1, 1), 10, 10) for A in (copy(AA), view(AA, 1:10, 1:10)) From 38f61d8d73b1330d03a5ba61e7cecdc45fb56dc2 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 31 Mar 2025 12:24:07 +0200 Subject: [PATCH 12/20] src/matmul.jl --- src/matmul.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/matmul.jl b/src/matmul.jl index 6f52148d..d90ac81c 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -593,6 +593,8 @@ 
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) end +#if conjugate is true, computes A A' α + C β if aat is true, and A' A α + C β otherwise +#if conjugate is false, computes A transpose(A) α + C β if aat is true, and tranpose(A) A α + C β otherwise function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} require_one_based_indexing(C, A) nC = checksquare(C) From ec06c723af615c26e30836694581e7805ec67754 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 31 Mar 2025 13:26:40 +0200 Subject: [PATCH 13/20] add docstring --- src/matmul.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index d90ac81c..8226b6ee 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -593,8 +593,14 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) end -#if conjugate is true, computes A A' α + C β if aat is true, and A' A α + C β otherwise -#if conjugate is false, computes A transpose(A) α + C β if aat is true, and tranpose(A) A α + C β otherwise +""" + generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} + +Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e., +``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise. +If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and +``A' A α + C β`` otherwise. 
+""" function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} require_one_based_indexing(C, A) nC = checksquare(C) From 9ec3005a8abf53a73f07b0e238b3848389bdf408 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 31 Mar 2025 15:39:55 +0200 Subject: [PATCH 14/20] typo --- src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index 8226b6ee..04f99a3e 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -832,7 +832,7 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St # Result array does not need to be initialized as long as beta==0 # C = Matrix{T}(undef, mA, mA) - if iszero(β) || issymmetric(C) + if iszero(β) || ishermitian(C) alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && From e91c1c41565c100830f52862693be584f16f8aa3 Mon Sep 17 00:00:00 2001 From: araujoms Date: Wed, 2 Apr 2025 11:46:54 +0200 Subject: [PATCH 15/20] =?UTF-8?q?fix=20non-real=20=CE=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/matmul.jl | 2 +- test/generic.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index 04f99a3e..42710d29 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -585,7 +585,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix blasfn = _valtypeparam(val) if blasfn == BlasFlag.SYRK && T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') - elseif blasfn == BlasFlag.HERK && (iszero(β) || ishermitian(C)) + elseif blasfn == BlasFlag.HERK && isreal(α) && isreal(β) && (iszero(β) || ishermitian(C)) return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) end end diff --git a/test/generic.jl b/test/generic.jl index c729e4be..dc9eabc8 100644 --- a/test/generic.jl +++ b/test/generic.jl 
@@ -145,6 +145,17 @@ end mul!(cbig, transpose(a), a, α, false) LinearAlgebra._generic_matmatmul!(cbig_fallback, transpose(a), a, α, false) @test cbig ≈ cbig_fallback + if T <: Union{Real, Complex} + @test issymmetric(csmall) + @test issymmetric(cbig) + end + #make sure generic herk is not called for non-real α + mul!(csmall, a, a', α, false) + LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', α, false) + @test csmall ≈ csmall_fallback + mul!(cbig, a', a, α, false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, α, false) + @test cbig ≈ cbig_fallback end end From c1630b6d7c13a1c67add1c190b1d4f3edfbbb425 Mon Sep 17 00:00:00 2001 From: araujoms Date: Wed, 2 Apr 2025 20:50:06 +0200 Subject: [PATCH 16/20] hang generic code lower in the dispatch hierarchy --- src/matmul.jl | 60 ++++++++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index e026ac53..5674e941 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -488,14 +488,13 @@ end # THE one big BLAS dispatch. 
This is split into two methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat} + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) matmul_size_check(size(C), (mA, nA), (mB, nB)) return _rmul_or_fill!(C, β) end - matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) return C end @@ -570,29 +569,6 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end -Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} - mA, nA = lapack_size(tA, A) - mB, nB = lapack_size(tB, B) - if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) - matmul_size_check(size(C), (mA, nA), (mB, nB)) - return _rmul_or_fill!(C, β) - end - - if A === B - tA_uc = uppercase(tA) # potentially strip a WrapperChar - aat = (tA_uc == 'N') - blasfn = _valtypeparam(val) - if blasfn == BlasFlag.SYRK && T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) - return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') - elseif blasfn == BlasFlag.HERK && isreal(α) && isreal(β) && (iszero(β) || ishermitian(C)) - return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) - end - end - - return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) -end - """ generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} @@ -800,7 +776,8 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst if (alpha isa 
Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && - _fullstride2(A) && _fullstride2(C)) + _fullstride2(A) && _fullstride2(C)) && + max(nA, mA) ≥ 4 return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') else return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U') @@ -808,6 +785,16 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end +Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + + tA_uc = uppercase(tA) # potentially strip a WrapperChar + aat = (tA_uc == 'N') + if T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) + return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') + end + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) +end # legacy method syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = syrk_wrapper!(C, tA, A, _add.alpha, _add.beta) @@ -837,7 +824,8 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && - _fullstride2(A) && _fullstride2(C)) + _fullstride2(A) && _fullstride2(C)) && + max(nA, mA) ≥ 4 return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) else return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true) @@ -845,6 +833,17 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end +Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + + tA_uc = uppercase(tA) # potentially strip a WrapperChar + aat = (tA_uc == 'N') + + if isreal(α) && isreal(β) && (iszero(β) || 
ishermitian(C)) + return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) + end + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) +end # legacy method herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = @@ -896,6 +895,13 @@ end gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta) +# fallback for generic types +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) +end # Aggressive constprop helps propagate the values of tA and tB into wrap, which # makes the calls concretely inferred From b979b01640af39a646c578401f7cb5b7dd57e976 Mon Sep 17 00:00:00 2001 From: araujoms Date: Wed, 2 Apr 2025 22:00:49 +0200 Subject: [PATCH 17/20] typos --- src/matmul.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 5674e941..d30cb90b 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -793,7 +793,8 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst if T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') end - return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) + tAt = aat ? 
'T' : 'N' + return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β) end # legacy method syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = @@ -838,11 +839,11 @@ Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::Abst tA_uc = uppercase(tA) # potentially strip a WrapperChar aat = (tA_uc == 'N') - if isreal(α) && isreal(β) && (iszero(β) || ishermitian(C)) return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) end - return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) + tAt = aat ? 'C' : 'N' + return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β) end # legacy method herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, @@ -877,6 +878,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab mB, nB = lapack_size(tB, B) matmul_size_check(size(C), (mA, nA), (mB, nB)) + matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) From a82a03df78696ad7dace3bab8b3ee5aa3c7b5b94 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 14 Apr 2025 15:52:13 +0200 Subject: [PATCH 18/20] update docstring --- src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index d30cb90b..1d29381c 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -575,7 +575,7 @@ end Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e., ``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise. If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and -``A' A α + C β`` otherwise. +``A' A α + C β`` otherwise. Only the upper triangular is computed. 
""" function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} require_one_based_indexing(C, A) From fb95a041eccc27421288dbad69110f2eb8b11334 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Mon, 14 Apr 2025 20:12:09 +0200 Subject: [PATCH 19/20] Update src/matmul.jl Co-authored-by: Jishnu Bhattacharya --- src/matmul.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 1d29381c..0e8b3cbc 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -778,10 +778,12 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) && max(nA, mA) ≥ 4 - return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') + BLAS.syrk!('U', tA, alpha, A, beta, C) else - return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U') + generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta) end + copytri!(C, 'U') + return C end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end From ada85984a681c521d4b5d7c6dfbe6038e59f8743 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 14 Apr 2025 20:18:29 +0200 Subject: [PATCH 20/20] update herk for consistency --- src/matmul.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index 0e8b3cbc..58a1609e 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -782,8 +782,7 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst else generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta) end - copytri!(C, 'U') - return C + return copytri!(C, 'U') end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end @@ -829,10 +828,11 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) && max(nA, mA) ≥ 4 - return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 
'U', true) + BLAS.herk!('U', tA, alpha, A, beta, C) else - return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true) + generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta) end + return copytri!(C, 'U', true) end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end