From 24120f1033c00a9b8b69290d39befbae78e323e3 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 1 Apr 2025 16:01:51 +0200 Subject: [PATCH 1/2] Clean up `herk_wrapper!` and add 5-arg tests --- src/matmul.jl | 12 +++++------- test/matmul.jl | 10 ++++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index a7838df8..202ef763 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it # to be concretely inferred -Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, - α::Number, β::Number) where {T<:BlasReal} +Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC}, + α::Number, β::Number) where {TC<:BlasComplex} + T = real(TC) nC = checksquare(C) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'C' @@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end - # Result array does not need to be initialized as long as beta==0 - # C = Matrix{T}(undef, mA, mA) - + # BLAS.herk! only updates hermitian C, alpha and beta need to be real if iszero(β) || ishermitian(C) alpha, beta = promote(α, β, zero(T)) - if (alpha isa Union{Bool,T} && - beta isa Union{Bool,T} && + if (alpha isa T && beta isa T && stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) diff --git a/test/matmul.jl b/test/matmul.jl index 938fc3b7..197fc216 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -542,6 +542,16 @@ end @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) end +@testset "5-arg syrk! & herk!" begin + for T in (Float32, Float64, ComplexF32, ComplexF64), A in (randn(T, 5), randn(T, 5, 5)) + B = A' * A + C = B isa Number ? [B;;] : Matrix(Hermitian(B)) + @test mul!(copy(C), A', A, true, 2) ≈ 3C + D = Matrix(Hermitian(A * A')) + @test mul!(copy(D), A, A', true, 3) ≈ 4D + end +end + @testset "matmul for types w/o sizeof (issue #1282)" begin AA = fill(complex(1, 1), 10, 10) for A in (copy(AA), view(AA, 1:10, 1:10)) From 39588e08ea26e945916ba3d6aaf79b21aa1bb1f8 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 1 Apr 2025 16:56:23 +0200 Subject: [PATCH 2/2] fix dimensionmismatch test --- test/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/matmul.jl b/test/matmul.jl index 197fc216..c78e7c73 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -539,7 +539,7 @@ end A5x5, A6x5 = Matrix{Float64}.(undef, ((5, 5), (6, 5))) @test_throws DimensionMismatch LinearAlgebra.syrk_wrapper!(A5x5, 'N', A6x5) - @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) + @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(complex(A5x5), 'N', complex(A6x5)) end @testset "5-arg syrk! & herk!" begin