diff --git a/src/matmul.jl b/src/matmul.jl index a7838df8..202ef763 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it # to be concretely inferred -Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, - α::Number, β::Number) where {T<:BlasReal} +Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC}, + α::Number, β::Number) where {TC<:BlasComplex} + T = real(TC) nC = checksquare(C) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'C' @@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end - # Result array does not need to be initialized as long as beta==0 - # C = Matrix{T}(undef, mA, mA) - + # BLAS.herk! only updates hermitian C, alpha and beta need to be real if iszero(β) || ishermitian(C) alpha, beta = promote(α, β, zero(T)) - if (alpha isa Union{Bool,T} && - beta isa Union{Bool,T} && + if (alpha isa T && beta isa T && stride(A, 1) == stride(C, 1) == 1 && _fullstride2(A) && _fullstride2(C)) return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) diff --git a/test/matmul.jl b/test/matmul.jl index 938fc3b7..c78e7c73 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -539,7 +539,17 @@ end A5x5, A6x5 = Matrix{Float64}.(undef, ((5, 5), (6, 5))) @test_throws DimensionMismatch LinearAlgebra.syrk_wrapper!(A5x5, 'N', A6x5) - @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5) + @test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(complex(A5x5), 'N', complex(A6x5)) +end + +@testset "5-arg syrk! & herk!" begin + for T in (Float32, Float64, ComplexF32, ComplexF64), A in (randn(T, 5), randn(T, 5, 5)) + B = A' * A + C = B isa Number ? [B;;] : Matrix(Hermitian(B)) + @test mul!(copy(C), A', A, true, 2) ≈ 3C + D = Matrix(Hermitian(A * A')) + @test mul!(copy(D), A, A', true, 3) ≈ 4D + end end @testset "matmul for types w/o sizeof (issue #1282)" begin