Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add

# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
# to be concretely inferred
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
α::Number, β::Number) where {T<:BlasReal}
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC},
α::Number, β::Number) where {TC<:BlasComplex}
T = real(TC)
nC = checksquare(C)
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'C'
Expand All @@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
end

# Result array does not need to be initialized as long as beta==0
# C = Matrix{T}(undef, mA, mA)

# BLAS.herk! only updates hermitian C, alpha and beta need to be real
if iszero(β) || ishermitian(C)
alpha, beta = promote(α, β, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
if (alpha isa T && beta isa T &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
Expand Down
12 changes: 11 additions & 1 deletion test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,17 @@ end

A5x5, A6x5 = Matrix{Float64}.(undef, ((5, 5), (6, 5)))
@test_throws DimensionMismatch LinearAlgebra.syrk_wrapper!(A5x5, 'N', A6x5)
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5)
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(complex(A5x5), 'N', complex(A6x5))
end

@testset "5-arg syrk! & herk!" begin
for T in (Float32, Float64, ComplexF32, ComplexF64), A in (randn(T, 5), randn(T, 5, 5))
B = A' * A
C = B isa Number ? [B;;] : Matrix(Hermitian(B))
@test mul!(copy(C), A', A, true, 2) ≈ 3C
D = Matrix(Hermitian(A * A'))
@test mul!(copy(D), A, A', true, 3) ≈ 4D
end
end

@testset "matmul for types w/o sizeof (issue #1282)" begin
Expand Down