From d052a621cd2eb0f467f5d88536860959caeddfef Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 17 Feb 2025 20:16:55 +0530 Subject: [PATCH 1/2] Indirection in matrix multiplication to avoid ambiguities --- src/diagonal.jl | 12 ++++++------ src/matmul.jl | 4 +++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diagonal.jl b/src/diagonal.jl index 9f8d54e5..8da03516 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -671,21 +671,21 @@ end for Tri in (:UpperTriangular, :LowerTriangular) UTri = Symbol(:Unit, Tri) # 2 args - for (fun, f) in zip((:*, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv)) + for (fun, f) in zip((:mul, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv)) @eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D)) @eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag)) end - @eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = + @eval mul(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = @invoke *(A::AbstractMatrix, D::Diagonal) - @eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = + @eval mul(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = @invoke *(A::AbstractMatrix, D::Diagonal) - for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv)) + for (fun, f) in zip((:mul, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv)) @eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data)) @eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag)) end - @eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) = + @eval mul(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) = @invoke *(D::Diagonal, A::AbstractMatrix) - @eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) = + @eval mul(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) = @invoke *(D::Diagonal, A::AbstractMatrix) # 3-arg ldiv! @eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data)) diff --git a/src/matmul.jl b/src/matmul.jl index 21c6e299..84a83d3a 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -111,7 +111,9 @@ julia> [1 1; 0 1] * [1 0; 1 1] 1 1 ``` """ -function (*)(A::AbstractMatrix, B::AbstractMatrix) +(*)(A::AbstractMatrix, B::AbstractMatrix) = mul(A, B) +# we add an extra level of indirection to avoid ambiguities in * +function mul(A::AbstractMatrix, B::AbstractMatrix) TS = promote_op(matprod, eltype(A), eltype(B)) mul!(matprod_dest(A, B, TS), A, B) end From 3e9afb896c6c4e12ce688d4c954dbbc6840eab24 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 18 Feb 2025 12:50:44 +0530 Subject: [PATCH 2/2] Invoke `mul` in strided triangular * Diagonal --- src/diagonal.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diagonal.jl b/src/diagonal.jl index 8da03516..aec54708 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -672,21 +672,23 @@ for Tri in (:UpperTriangular, :LowerTriangular) UTri = Symbol(:Unit, Tri) # 2 args for (fun, f) in zip((:mul, :rmul!, :rdiv!, :/), (:identity, :identity, :inv, :inv)) - @eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D)) - @eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag)) + g = fun == :mul ? :* : fun + @eval $fun(A::$Tri, D::Diagonal) = $Tri($g(A.data, D)) + @eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($g(A.data, D), $f, D.diag)) end @eval mul(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = - @invoke *(A::AbstractMatrix, D::Diagonal) + @invoke mul(A::AbstractMatrix, D::Diagonal) @eval mul(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = - @invoke *(A::AbstractMatrix, D::Diagonal) + @invoke mul(A::AbstractMatrix, D::Diagonal) for (fun, f) in zip((:mul, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv)) - @eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data)) - @eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag)) + g = fun == :mul ? :* : fun + @eval $fun(D::Diagonal, A::$Tri) = $Tri($g(D, A.data)) + @eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($g(D, A.data), $f, D.diag)) end @eval mul(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) = - @invoke *(D::Diagonal, A::AbstractMatrix) + @invoke mul(D::Diagonal, A::AbstractMatrix) @eval mul(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) = - @invoke *(D::Diagonal, A::AbstractMatrix) + @invoke mul(D::Diagonal, A::AbstractMatrix) # 3-arg ldiv! @eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data)) @eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))