diff --git a/src/triangular.jl b/src/triangular.jl index a040e80b..e3570cc6 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -17,9 +17,12 @@ @inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVecOrMat}) = transpose(transpose(B) * transpose(A)) +const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}} + @inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::StaticVecOrMat) = _A_mul_B(Size(A), Size(B), A, B) @inline Base.:*(A::StaticVecOrMat, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = _A_mul_B(Size(A), Size(B), A, B) -@inline Base.:\(A::Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B) +@inline Base.:*(A::StaticULT, B::StaticULT) = _A_mul_B(Size(A), Size(B), A, B) +@inline Base.:\(A::StaticULT, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B) @generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} @@ -31,7 +34,7 @@ X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = 1:m ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])]) @@ -59,7 +62,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = m:-1:1 ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])]) @@ -87,7 +90,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = m:-1:1 ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])]) @@ -115,7 +118,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = m:-1:1 ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])]) @@ -143,7 +146,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = 1:m ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])]) @@ -171,7 +174,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for j = 1:n for i = 1:m ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])]) @@ -203,7 +206,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = n:-1:1 ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]) @@ -235,7 +238,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = 1:n ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]') @@ -262,7 +265,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = 1:n ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])])) @@ -294,7 +297,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = 1:n ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]) @@ -326,7 +329,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = n:-1:1 ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]') @@ -353,7 +356,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for i = 1:m for j = n:-1:1 ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])])) @@ -382,7 +385,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = m:-1:1 if k == 1 @@ -414,7 +417,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = 1:m if k == 1 @@ -445,7 +448,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = 1:m ex = :(B[$(LinearIndices(sb)[j, k])]) @@ -476,7 +479,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = 1:m ex = :(B[$(LinearIndices(sb)[j, k])]) @@ -507,7 +510,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = m:-1:1 ex = :(B[$(LinearIndices(sb)[j, k])]) @@ -538,7 +541,7 @@ end X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - code = quote end + code = Expr(:block) for k = 1:n for j = m:-1:1 ex = :(B[$(LinearIndices(sb)[j, k])]) @@ -559,3 +562,129 @@ end @inbounds return similar_type(B, TAB)(tuple($(X...))) end end + +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} + n = sa[1] + if n != sb[1] + throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) + end + + X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] + + TAB = promote_op(*, eltype(TA), eltype(TB)) + z = zero(TAB) + + code = Expr(:block) + for j = 1:n + for i = 1:n + if i > j + push!(code.args, :($(X[i,j]) = $z)) + else + ex = :(A.data[$(LinearIndices(sa)[i,i])] * B.data[$(LinearIndices(sb)[i,j])]) + for k = i+1:j + ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) + end + push!(code.args, :($(X[i,j]) = $ex)) + end + end + end + + return quote + @_inline_meta + @inbounds $code + return UpperTriangular(similar_type(B.data, $TAB)(tuple($(X...)))) + end + +end + +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} + n = sa[1] + if n != sb[1] + throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) + end + + X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] + + TAB = promote_op(*, eltype(TA), eltype(TB)) + z = zero(TAB) + + code = Expr(:block) + for j = 1:n + for i = 1:n + if i < j + push!(code.args, :($(X[i,j]) = $z)) + else + ex = :(A.data[$(LinearIndices(sa)[i,j])] * B.data[$(LinearIndices(sb)[j,j])]) + for k = j+1:i + ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) + end + push!(code.args, :($(X[i,j]) = $ex)) + end + end + end + + return quote + @_inline_meta + @inbounds $code + return LowerTriangular(similar_type(B.data, $TAB)(tuple($(X...)))) + end + +end + + +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} + n = sa[1] + if n != sb[1] + throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) + end + + X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] + + code = Expr(:block) + for j = 1:n + for i = 1:n + k1 = max(i,j) + ex = :(A.data[$(LinearIndices(sa)[i,k1])] * B.data[$(LinearIndices(sb)[k1,j])]) + for k = k1+1:n + ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) + end + push!(code.args, :($(X[i,j]) = $ex)) + end + end + + return quote + @_inline_meta + @inbounds $code + TAB = promote_op(*, eltype(TA), eltype(TB)) + return similar_type(B.data, TAB)(tuple($(X...))) + end + +end + +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} + n = sa[1] + if n != sb[1] + throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) + end + + X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] + + code = Expr(:block) + for j = 1:n + for i = 1:n + ex = :(A.data[$(LinearIndices(sa)[i,1])] * B.data[$(LinearIndices(sb)[1,j])]) + for k = 2:min(i,j) + ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) + end + push!(code.args, :($(X[i,j]) = $ex)) + end + end + + return quote + @_inline_meta + @inbounds $code + TAB = promote_op(*, eltype(TA), eltype(TB)) + return similar_type(B.data, TAB)(tuple($(X...))) + end + +end diff --git a/test/triangular.jl b/test/triangular.jl index 7afe78a7..2992161b 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -82,6 +82,29 @@ end end end +@testset "Triangular-triangular multiplication" begin + for n in (1, 2, 3, 4), + eltyA in (Float64, ComplexF64, Int), + eltyB in (Float64, ComplexF64, Int), + (ta, uploa) in ((UpperTriangular, :U), (LowerTriangular, :L)), + (tb, uplob) in ((UpperTriangular, :U), (LowerTriangular, :L)) + + A = ta(eltyA == Int ? rand(1:7, n, n) : rand(eltyA, n, n)) + B = tb(eltyB == Int ? rand(1:7, n, n) : rand(eltyB, n, n)) + + SA = ta(SMatrix{n,n}(A.data)) + SB = tb(SMatrix{n,n}(B.data)) + + eltyAB = Base.promote_op(*, eltyA, eltyB) + + @test SA*SB ≈ A*B + @test eltype(SA*SB) == eltyAB + @test SA*SB isa (ta===tb ? ta : SMatrix) + + end + +end + @testset "Triangular-matrix division" begin for n in (1, 2, 3, 4), eltyA in (Float64, ComplexF64, Int),