diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 3b217658fbc52..ab39b6139fb29 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -204,13 +204,16 @@ end *(u::AdjointAbsVec, v::AdjointAbsVec) = throw(MethodError(*, (u, v))) *(u::TransposeAbsVec, v::TransposeAbsVec) = throw(MethodError(*, (u, v))) -# Adjoint/Transpose-vector * matrix -*(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) * u.parent) -*(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) * u.parent) -# Adjoint/Transpose-vector * Adjoint/Transpose-matrix -*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(A.parent * u.parent) -*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = transpose(A.parent * u.parent) - +# AdjOrTransAbsVec{<:Any,<:AdjOrTransAbsVec} is a lazy conj vectors +# We need to expand the combinations to avoid ambiguities +(*)(u::TransposeAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = + sum(uu*vv for (uu, vv) in zip(u, v)) +(*)(u::AdjointAbsVec, v::AdjointAbsVec{<:Any,<:TransposeAbsVec}) = + sum(uu*vv for (uu, vv) in zip(u, v)) +(*)(u::TransposeAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = + sum(uu*vv for (uu, vv) in zip(u, v)) +(*)(u::AdjointAbsVec, v::TransposeAbsVec{<:Any,<:AdjointAbsVec}) = + sum(uu*vv for (uu, vv) in zip(u, v)) ## pseudoinversion pinv(v::AdjointAbsVec, tol::Real = 0) = pinv(v.parent, tol).parent @@ -226,16 +229,3 @@ pinv(v::TransposeAbsVec, tol::Real = 0) = pinv(conj(v.parent)).parent /(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) \ u.parent) /(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = adjoint(conj(A.parent) \ u.parent) # technically should be adjoint(copy(adjoint(copy(A))) \ u.parent) /(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = transpose(conj(A.parent) \ u.parent) # technically should be transpose(copy(transpose(copy(A))) \ u.parent) - -# dismabiguation methods -*(A::AdjointAbsVec, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B) -*(A::TransposeAbsVec, B::Adjoint{<:Any,<:AbstractMatrix}) = A * copy(B) -*(A::Transpose{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractMatrix}) = copy(A) * B -*(A::Adjoint{<:Any,<:AbstractMatrix}, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B) -# Adj/Trans-vector * Trans/Adj-vector, shouldn't exist, here for ambiguity resolution? TODO: test removal -*(A::Adjoint{<:Any,<:AbstractVector}, B::Transpose{<:Any,<:AbstractVector}) = throw(MethodError(*, (A, B))) -*(A::Transpose{<:Any,<:AbstractVector}, B::Adjoint{<:Any,<:AbstractVector}) = throw(MethodError(*, (A, B))) -# Adj/Trans-matrix * Trans/Adj-vector, shouldn't exist, here for ambiguity resolution? TODO: test removal -*(A::Adjoint{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractVector}) = throw(MethodError(*, (A, B))) -*(A::Adjoint{<:Any,<:AbstractMatrix}, B::Transpose{<:Any,<:AbstractVector}) = throw(MethodError(*, (A, B))) -*(A::Transpose{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractVector}) = throw(MethodError(*, (A, B))) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 1c5bcbf027322..b87024bbffeb2 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -35,8 +35,10 @@ function dot(x::Vector{T}, rx::Union{UnitRange{TI},AbstractRange{TI}}, y::Vector GC.@preserve x y BLAS.dotc(length(rx), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry)) end -*(transx::Transpose{<:Any,<:StridedVector{T}}, y::StridedVector{T}) where {T<:BlasComplex} = - (x = transx.parent; BLAS.dotu(x, y)) +function *(transx::Transpose{<:Any,<:StridedVector{T}}, y::StridedVector{T}) where {T<:BlasComplex} + x = transx.parent + return BLAS.dotu(x, y) +end # Matrix-vector multiplication function (*)(A::StridedMatrix{T}, x::StridedVector{S}) where {T<:BlasFloat,S} @@ -49,10 +51,14 @@ function (*)(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T,S} end # these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case): -*(a::AbstractVector, transB::Transpose{<:Any,<:AbstractMatrix}) = - (B = transB.parent; *(reshape(a,length(a),1), transpose(B))) -*(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix}) = - (B = adjB.parent; *(reshape(a,length(a),1), adjoint(B))) +function *(a::AbstractVector, transB::Transpose{<:Any,<:AbstractMatrix}) + B = transB.parent + reshape(a,length(a),1)*transpose(B) +end +function *(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix}) + B = adjB.parent + reshape(a,length(a),1)*adjoint(B) +end (*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where {T<:BlasFloat} = gemv!(y, 'N', A, x) @@ -78,10 +84,14 @@ function *(transA::Transpose{<:Any,<:AbstractMatrix{T}}, x::AbstractVector{S}) w TS = promote_op(matprod, T, S) mul!(similar(x,TS,size(A,2)), transpose(A), x) end -mul!(y::StridedVector{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasFloat} = - (A = transA.parent; gemv!(y, 'T', A, x)) -mul!(y::AbstractVector, transA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector) = - (A = transA.parent; generic_matvecmul!(y, 'T', A, x)) +function mul!(y::StridedVector{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasFloat} + A = transA.parent + return gemv!(y, 'T', A, x) +end +function mul!(y::AbstractVector, transA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector) + A = transA.parent + return generic_matvecmul!(y, 'T', A, x) +end function *(adjA::Adjoint{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S} A = adjA.parent @@ -94,12 +104,22 @@ function *(adjA::Adjoint{<:Any,<:AbstractMatrix{T}}, x::AbstractVector{S}) where mul!(similar(x,TS,size(A,2)), adjoint(A), x) end -mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasReal} = - (A = adjA.parent; mul!(y, transpose(A), x)) -mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasComplex} = - (A = adjA.parent; gemv!(y, 'C', A, x)) -mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector) = - (A = adjA.parent; generic_matvecmul!(y, 'C', A, x)) +function mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasReal} + A = adjA.parent + return mul!(y, transpose(A), x) +end +function mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}) where {T<:BlasComplex} + A = adjA.parent + return gemv!(y, 'C', A, x) +end +function mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector) + A = adjA.parent + return generic_matvecmul!(y, 'C', A, x) +end + +# Vector-Matrix multiplication +(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')' +(*)(x::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A)*transpose(x)) # Matrix-matrix multiplication @@ -165,23 +185,27 @@ Calculate the matrix-matrix product ``AB``, overwriting `B`, and return the resu """ lmul!(A, B) -function *(transA::Transpose{<:Any,<:AbstractMatrix}, B::AbstractMatrix) +function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasFloat} A = transA.parent - TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(B, TS, (size(A,2), size(B,2))), transpose(A), B) + if A===B + return syrk_wrapper!(C, 'T', A) + else + return gemm_wrapper!(C, 'T', 'N', A, B) + end +end +function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat) + A = transA.parent + return generic_matmatmul!(C, 'T', 'N', A, B) end -mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasFloat} = - (A = transA.parent; A===B ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)) -mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat) = - (A = transA.parent; generic_matmatmul!(C, 'T', 'N', A, B)) -function *(A::AbstractMatrix, transB::Transpose{<:Any,<:AbstractMatrix}) +function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} B = transB.parent - TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(B, TS, (size(A,1), size(B,1))), A, transpose(B)) + if A===B + return syrk_wrapper!(C, 'N', A) + else + return gemm_wrapper!(C, 'N', 'T', A, B) + end end -mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} = - (B = transB.parent; A===B ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)) for elty in (Float32,Float64) @eval begin function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}}) @@ -195,64 +219,81 @@ for elty in (Float32,Float64) end # collapsing the following two defs with C::AbstractVecOrMat yields ambiguities mul!(C::AbstractVector, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat}) = - _disambigmul!(C, A, transB) + generic_matmatmul!(C, 'N', 'T', A, transB.parent) mul!(C::AbstractMatrix, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat}) = - _disambigmul!(C, A, transB) -_disambigmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat}) = - (B = transB.parent; generic_matmatmul!(C, 'N', 'T', A, B)) - -# collapsing the following two defs with transB::Transpose{<:Any,<:AbstractVecOrMat{S}} yields ambiguities -*(transA::Transpose{<:Any,<:AbstractMatrix}, transB::Transpose{<:Any,<:AbstractMatrix}) = - _disambigmul(transA, transB) -*(transA::Transpose{<:Any,<:AbstractMatrix}, transB::Transpose{<:Any,<:AbstractVector}) = - _disambigmul(transA, transB) -function _disambigmul(transA::Transpose{<:Any,<:AbstractMatrix{T}}, transB::Transpose{<:Any,<:AbstractVecOrMat{S}}) where {T,S} - A, B = transA.parent, transB.parent - TS = promote_op(matprod, T, S) - mul!(similar(B, TS, (size(A,2), size(B,1))), transpose(A), transpose(B)) -end -mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Transpose{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} = - (A = transA.parent; B = transB.parent; gemm_wrapper!(C, 'T', 'T', A, B)) -mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat}) = - (A = transA.parent; B = transB.parent; generic_matmatmul!(C, 'T', 'T', A, B)) -mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) - -*(adjA::Adjoint{<:Any,<:StridedMatrix{T}}, B::StridedMatrix{T}) where {T<:BlasReal} = - (A = adjA.parent; *(transpose(A), B)) -mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasReal} = - (A = adjA.parent; mul!(C, transpose(A), B)) -function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, B::AbstractMatrix) + generic_matmatmul!(C, 'N', 'T', A, transB.parent) + +function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Transpose{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} + A = transA.parent + B = transB.parent + return gemm_wrapper!(C, 'T', 'T', A, B) +end +function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat}) + A = transA.parent + B = transB.parent + return generic_matmatmul!(C, 'T', 'T', A, B) +end + +function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Adjoint{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} + A = transA.parent + B = transB.parent + return gemm_wrapper!(C, 'T', 'C', A, B) +end +function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat}) + A = transA.parent + B = transB.parent + return generic_matmatmul!(C, 'T', 'C', A, B) +end + +function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasReal} A = adjA.parent - TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(B, TS, (size(A,2), size(B,2))), adjoint(A), B) -end -mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasComplex} = - (A = adjA.parent; A===B ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)) -mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat) = - (A = adjA.parent; generic_matmatmul!(C, 'C', 'N', A, B)) - -*(A::StridedMatrix{<:BlasFloat}, adjB::Adjoint{<:Any,<:StridedMatrix{<:BlasReal}}) = - (B = adjB.parent; *(A, transpose(B))) -mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{<:BlasReal}}) where {T<:BlasFloat} = - (B = adjB.parent; mul!(C, A, transpose(B))) -function *(A::AbstractMatrix, adjB::Adjoint{<:Any,<:AbstractMatrix}) + return mul!(C, transpose(A), B) +end +function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}) where {T<:BlasComplex} + A = adjA.parent + if A===B + return herk_wrapper!(C,'C',A) + else + return gemm_wrapper!(C,'C', 'N', A, B) + end +end +function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat) + A = adjA.parent + return generic_matmatmul!(C, 'C', 'N', A, B) +end + +function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{<:BlasReal}}) where {T<:BlasFloat} B = adjB.parent - TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(B,TS,(size(A,1),size(B,1))), A, adjoint(B)) -end -mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasComplex} = - (B = adjB.parent; A===B ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)) -mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat}) = - (B = adjB.parent; generic_matmatmul!(C, 'N', 'C', A, B)) - -*(adjA::Adjoint{<:Any,<:AbstractMatrix}, adjB::Adjoint{<:Any,<:AbstractMatrix}) = - (A = adjA.parent; B = adjB.parent; mul!(similar(B, promote_op(matprod, eltype(A), eltype(B)), (size(A,2), size(B,1))), adjoint(A), adjoint(B))) -mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} = - (A = adjA.parent; B = adjB.parent; gemm_wrapper!(C, 'C', 'C', A, B)) -mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat}) = - (A = adjA.parent; B = adjB.parent; generic_matmatmul!(C, 'C', 'C', A, B)) -mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat}) = - (A = adjA.parent; B = transB.parent; generic_matmatmul!(C, 'C', 'T', A, B)) + return mul!(C, A, transpose(B)) +end +function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasComplex} + B = adjB.parent + if A === B + return herk_wrapper!(C, 'N', A) + else + return gemm_wrapper!(C, 'N', 'C', A, B) + end +end +function mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat}) + B = adjB.parent + return generic_matmatmul!(C, 'N', 'C', A, B) +end + +function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}) where {T<:BlasFloat} + A = adjA.parent + B = adjB.parent + return gemm_wrapper!(C, 'C', 'C', A, B) +end +function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat}) + A = adjA.parent + B = adjB.parent + return generic_matmatmul!(C, 'C', 'C', A, B) +end +function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat}) + A = adjA.parent + B = transB.parent + return generic_matmatmul!(C, 'C', 'T', A, B) +end # Supporting functions for matrix multiplication function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false) @@ -582,7 +623,6 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end else # Multiplication for non-plain-data uses the naive algorithm - if tA == 'N' if tB == 'N' for i = 1:mA, j = 1:nB @@ -595,7 +635,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end elseif tB == 'T' for i = 1:mA, j = 1:nB - z2 = zero(A[i, 1]*B[j, 1] + A[i, 1]*B[j, 1]) + z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1])) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += A[i, k] * transpose(B[j, k]) @@ -604,7 +644,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end else for i = 1:mA, j = 1:nB - z2 = zero(A[i, 1]*B[j, 1] + A[i, 1]*B[j, 1]) + z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]') Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += A[i, k]*B[j, k]' @@ -615,7 +655,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat elseif tA == 'T' if tB == 'N' for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[1, j] + A[1, i]*B[1, j]) + z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j]) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += transpose(A[k, i]) * B[k, j] @@ -624,7 +664,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end elseif tB == 'T' for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[j, 1] + A[1, i]*B[j, 1]) + z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1])) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += transpose(A[k, i]) * transpose(B[j, k]) @@ -633,7 +673,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end else for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[j, 1] + A[1, i]*B[j, 1]) + z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]') Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += transpose(A[k, i]) * adjoint(B[j, k]) @@ -644,7 +684,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat else if tB == 'N' for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[1, j] + A[1, i]*B[1, j]) + z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j]) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += A[k, i]'B[k, j] @@ -653,7 +693,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end elseif tB == 'T' for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[j, 1] + A[1, i]*B[j, 1]) + z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1])) Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += adjoint(A[k, i]) * transpose(B[j, k]) @@ -662,7 +702,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat end else for i = 1:mA, j = 1:nB - z2 = zero(A[1, i]*B[j, 1] + A[1, i]*B[j, 1]) + z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]') Ctmp = convert(promote_type(R, typeof(z2)), z2) for k = 1:nA Ctmp += A[k, i]'B[j, k]' diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 5f3d09561bcbc..171265dfbd04c 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -402,4 +402,44 @@ module TestPR18218 @test d == TypeC[5, 11] end +@testset "VecOrMat of Vectors" begin + X = rand(ComplexF64, 3, 3) + Xv1 = [X[:,j] for i in 1:1, j in 1:3] + Xv2 = [transpose(X[i,:]) for i in 1:3] + Xv3 = [transpose(X[i,:]) for i in 1:3, j in 1:1] + + XX = X*X + XtX = transpose(X)*X + XcX = X'*X + XXt = X*transpose(X) + XtXt = transpose(XX) + XcXt = X'*transpose(X) + XXc = X*X' + XtXc = transpose(X)*X' + XcXc = X'*X' + + @test (Xv1*Xv2)[1] ≈ XX + @test (Xv1*Xv3)[1] ≈ XX + @test transpose(Xv1)*Xv1 ≈ XtX + @test transpose(Xv2)*Xv2 ≈ XtX + @test (transpose(Xv3)*Xv3)[1] ≈ XtX + @test Xv1'*Xv1 ≈ XcX + @test Xv2'*Xv2 ≈ XcX + @test (Xv3'*Xv3)[1] ≈ XcX + @test (Xv1*transpose(Xv1))[1] ≈ XXt + @test Xv2*transpose(Xv2) ≈ XXt + @test Xv3*transpose(Xv3) ≈ XXt + @test transpose(Xv1)*transpose(Xv2) ≈ XtXt + @test transpose(Xv1)*transpose(Xv3) ≈ XtXt + @test Xv1'*transpose(Xv2) ≈ XcXt + @test Xv1'*transpose(Xv3) ≈ XcXt + @test (Xv1*Xv1')[1] ≈ XXc + @test Xv2*Xv2' ≈ XXc + @test Xv3*Xv3' ≈ XXc + @test transpose(Xv1)*Xv2' ≈ XtXc + @test transpose(Xv1)*Xv3' ≈ XtXc + @test Xv1'*Xv2' ≈ XcXc + @test Xv1'*Xv3' ≈ XcXc +end + end # module TestMatmul