Skip to content

Commit c646efc

Browse files
committed
Merge pull request #16615 from JuliaLang/anj/nomul
Avoid some unintended calls to generic_matmatmul! for special matrices
2 parents 3f59431 + 410471f commit c646efc

File tree

3 files changed

+116
-23
lines changed

3 files changed

+116
-23
lines changed

base/linalg/diagonal.jl

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,31 +104,85 @@ end
104104
/{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag / x)
105105
*(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
106106
*(D::Diagonal, V::AbstractVector) = D.diag .* V
107-
# To avoid ambiguity in the definitions below
108-
for uplo in (:LowerTriangular, :UpperTriangular)
109-
@eval begin
110-
(*)(A::$uplo, D::Diagonal) = $uplo(A.data * D)
111-
112-
function (*)(A::$(Symbol(:Unit, uplo)), D::Diagonal)
113-
B = A.data * D
114-
for i = 1:size(A, 1)
115-
B[i,i] = D.diag[i]
116-
end
117-
return $uplo(B)
118-
end
119-
end
120-
end
121-
(*)(A::AbstractTriangular, D::Diagonal) = error("this method should never be reached")
122-
(*)(D::Diagonal, A::AbstractTriangular) = error("this method should never be reached")
107+
108+
(*)(A::AbstractTriangular, D::Diagonal) = A_mul_B!(copy(A), D)
109+
(*)(D::Diagonal, B::AbstractTriangular) = A_mul_B!(D, copy(B))
123110

124111
(*)(A::AbstractMatrix, D::Diagonal) =
125112
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D.diag)
126113
(*)(D::Diagonal, A::AbstractMatrix) =
127114
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), D.diag, A)
128115

129-
A_mul_B!(A::Diagonal,B::AbstractMatrix) = scale!(A.diag,B)
130-
At_mul_B!(A::Diagonal,B::AbstractMatrix)= scale!(A.diag,B)
131-
Ac_mul_B!(A::Diagonal,B::AbstractMatrix)= scale!(conj(A.diag),B)
116+
A_mul_B!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) =
117+
typeof(A)(A_mul_B!(A.data, D))
118+
function A_mul_B!(A::UnitLowerTriangular, D::Diagonal)
119+
A_mul_B!(A.data, D)
120+
for i = 1:size(A, 1)
121+
A.data[i,i] = D.diag[i]
122+
end
123+
LowerTriangular(A.data)
124+
end
125+
function A_mul_B!(A::UnitUpperTriangular, D::Diagonal)
126+
A_mul_B!(A.data, D)
127+
for i = 1:size(A, 1)
128+
A.data[i,i] = D.diag[i]
129+
end
130+
UpperTriangular(A.data)
131+
end
132+
function A_mul_B!(D::Diagonal, B::UnitLowerTriangular)
133+
A_mul_B!(D, B.data)
134+
for i = 1:size(B, 1)
135+
B.data[i,i] = D.diag[i]
136+
end
137+
LowerTriangular(B.data)
138+
end
139+
function A_mul_B!(D::Diagonal, B::UnitUpperTriangular)
140+
A_mul_B!(D, B.data)
141+
for i = 1:size(B, 1)
142+
B.data[i,i] = D.diag[i]
143+
end
144+
UpperTriangular(B.data)
145+
end
146+
147+
Ac_mul_B(A::AbstractTriangular, D::Diagonal) = A_mul_B!(ctranspose(A), D)
148+
function Ac_mul_B(A::AbstractMatrix, D::Diagonal)
149+
Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1)))
150+
ctranspose!(Ac, A)
151+
A_mul_B!(Ac, D)
152+
end
153+
154+
At_mul_B(A::AbstractTriangular, D::Diagonal) = A_mul_B!(transpose(A), D)
155+
function At_mul_B(A::AbstractMatrix, D::Diagonal)
156+
Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1)))
157+
transpose!(Ac, A)
158+
A_mul_B!(Ac, D)
159+
end
160+
161+
A_mul_Bc(D::Diagonal, B::AbstractTriangular) = A_mul_B!(D, ctranspose(B))
162+
A_mul_Bc(D::Diagonal, Q::Union{Base.LinAlg.QRCompactWYQ,Base.LinAlg.QRPackedQ}) = A_mul_Bc!(Array(D), Q)
163+
function A_mul_Bc(D::Diagonal, A::AbstractMatrix)
164+
Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1)))
165+
ctranspose!(Ac, A)
166+
A_mul_B!(D, Ac)
167+
end
168+
169+
A_mul_Bt(D::Diagonal, B::AbstractTriangular) = A_mul_B!(D, transpose(B))
170+
function A_mul_Bt(D::Diagonal, A::AbstractMatrix)
171+
Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1)))
172+
ctranspose!(Ac, A)
173+
A_mul_B!(D, Ac)
174+
end
175+
176+
A_mul_B!(A::Diagonal,B::Diagonal) = throw(MethodError(A_mul_B!, Tuple{Diagonal,Diagonal}))
177+
At_mul_B!(A::Diagonal,B::Diagonal) = throw(MethodError(At_mul_B!, Tuple{Diagonal,Diagonal}))
178+
Ac_mul_B!(A::Diagonal,B::Diagonal) = throw(MethodError(Ac_mul_B!, Tuple{Diagonal,Diagonal}))
179+
A_mul_B!(A::Base.LinAlg.QRPackedQ, D::Diagonal) = throw(MethodError(A_mul_B!, Tuple{Diagonal,Diagonal}))
180+
A_mul_B!(A::Diagonal,B::AbstractMatrix) = scale!(A.diag,B)
181+
At_mul_B!(A::Diagonal,B::AbstractMatrix) = scale!(A.diag,B)
182+
Ac_mul_B!(A::Diagonal,B::AbstractMatrix) = scale!(conj(A.diag),B)
183+
A_mul_B!(A::AbstractMatrix,B::Diagonal) = scale!(A,B.diag)
184+
A_mul_Bt!(A::AbstractMatrix,B::Diagonal) = scale!(A,B.diag)
185+
A_mul_Bc!(A::AbstractMatrix,B::Diagonal) = scale!(A,conj(B.diag))
132186

133187
/(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag )
134188
function A_ldiv_B!{T}(D::Diagonal{T}, v::AbstractVector{T})

base/linalg/lq.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,16 @@ end
106106
## Multiplication by Q
107107
### QB
108108
A_mul_B!{T<:BlasFloat}(A::LQPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormlq!('L','N',A.factors,A.τ,B)
109-
function *{TA,TB}(A::LQPackedQ{TA},B::StridedVecOrMat{TB})
110-
TAB = promote_type(TA, TB)
109+
function (*)(A::LQPackedQ,B::StridedVecOrMat)
110+
TAB = promote_type(eltype(A), eltype(B))
111111
A_mul_B!(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))
112112
end
113113

114114
### QcB
115115
Ac_mul_B!{T<:BlasReal}(A::LQPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormlq!('L','T',A.factors,A.τ,B)
116116
Ac_mul_B!{T<:BlasComplex}(A::LQPackedQ{T}, B::StridedVecOrMat{T}) = LAPACK.ormlq!('L','C',A.factors,A.τ,B)
117-
function Ac_mul_B{TA,TB}(A::LQPackedQ{TA}, B::StridedVecOrMat{TB})
118-
TAB = promote_type(TA,TB)
117+
function Ac_mul_B(A::LQPackedQ, B::StridedVecOrMat)
118+
TAB = promote_type(eltype(A), eltype(B))
119119
if size(B,1) == size(A.factors,2)
120120
Ac_mul_B!(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB))
121121
elseif size(B,1) == size(A.factors,1)
@@ -125,6 +125,19 @@ function Ac_mul_B{TA,TB}(A::LQPackedQ{TA}, B::StridedVecOrMat{TB})
125125
end
126126
end
127127

128+
### QBc/QcBc
129+
for (f1, f2) in ((:A_mul_Bc, :A_mul_B!),
130+
(:Ac_mul_Bc, :Ac_mul_B!))
131+
@eval begin
132+
function ($f1)(A::LQPackedQ, B::StridedVecOrMat)
133+
TAB = promote_type(eltype(A), eltype(B))
134+
BB = similar(B, TAB, (size(B, 2), size(B, 1)))
135+
ctranspose!(BB, B)
136+
return ($f2)(A, BB)
137+
end
138+
end
139+
end
140+
128141
### AQ
129142
A_mul_B!{T<:BlasFloat}(A::StridedMatrix{T}, B::LQPackedQ{T}) = LAPACK.ormlq!('R', 'N', B.factors, B.τ, A)
130143
function *{TA,TB}(A::StridedMatrix{TA},B::LQPackedQ{TB})
@@ -146,6 +159,19 @@ function A_mul_Bc{TA<:Number,TB<:Number}( A::StridedVecOrMat{TA}, B::LQPackedQ{T
146159
A_mul_Bc!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB},(B)))
147160
end
148161

162+
### AcQ/AcQc
163+
for (f1, f2) in ((:Ac_mul_B, :A_mul_B!),
164+
(:Ac_mul_Bc, :A_mul_Bc!))
165+
@eval begin
166+
function ($f1)(A::StridedMatrix, B::LQPackedQ)
167+
TAB = promote_type(eltype(A), eltype(B))
168+
AA = similar(A, TAB, (size(A, 2), size(A, 1)))
169+
ctranspose!(AA, A)
170+
return ($f2)(AA, B)
171+
end
172+
end
173+
end
174+
149175
function \{TA,Tb}(A::LQ{TA}, b::StridedVector{Tb})
150176
S = promote_type(TA,Tb)
151177
m = checksquare(A)

test/linalg/diagonal.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,16 @@ end
246246

247247
# Issue 15401
248248
@test eye(5) \ Diagonal(ones(5)) == eye(5)
249+
250+
# Triangular and Diagonal
251+
for T in (LowerTriangular(randn(5,5)), LinAlg.UnitLowerTriangular(randn(5,5)))
252+
D = Diagonal(randn(5))
253+
@test T'D == Array(T)'*Array(D)
254+
@test T.'D == Array(T).'*Array(D)
255+
@test D*T' == Array(D)*Array(T)'
256+
@test D*T.' == Array(D)*Array(T).'
257+
end
258+
259+
# Diagonal and Q
260+
Q = qrfact(randn(5,5))[:Q]
261+
@test D*Q' == Array(D)*Q'

0 commit comments

Comments
 (0)