From 8914d409a56c01487782d96da896c75d8663f0cd Mon Sep 17 00:00:00 2001 From: Alexander Seiler Date: Sat, 14 Sep 2019 04:07:59 +0200 Subject: [PATCH] Reduce code duplication for dot product of symmetric/Hermitian matrices --- stdlib/LinearAlgebra/src/symmetric.jl | 107 +++++++++----------------- 1 file changed, 36 insertions(+), 71 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index ca401f3c1c441..307a5f51893f5 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -414,82 +414,47 @@ function triu(A::Symmetric, k::Integer=0) end end -function dot(A::Symmetric, B::Symmetric) - n = size(A, 2) - if n != size(B, 2) - throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) - end - - dotprod = zero(dot(first(A), first(B))) - @inbounds if A.uplo == 'U' && B.uplo == 'U' - for j in 1:n - for i in 1:(j - 1) - dotprod += 2 * dot(A.data[i, j], B.data[i, j]) - end - dotprod += dot(A[j, j], B[j, j]) - end - elseif A.uplo == 'L' && B.uplo == 'L' - for j in 1:n - dotprod += dot(A[j, j], B[j, j]) - for i in (j + 1):n - dotprod += 2 * dot(A.data[i, j], B.data[i, j]) - end - end - elseif A.uplo == 'U' && B.uplo == 'L' - for j in 1:n - for i in 1:(j - 1) - dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i])) - end - dotprod += dot(A[j, j], B[j, j]) - end - else - for j in 1:n - dotprod += dot(A[j, j], B[j, j]) - for i in (j + 1):n - dotprod += 2 * dot(A.data[i, j], transpose(B.data[j, i])) +for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:Hermitian, :adjoint, :real)] + @eval begin + function dot(A::$T, B::$T) + n = size(A, 2) + if n != size(B, 2) + throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) end - end - end - return dotprod -end -function dot(A::Hermitian, B::Hermitian) - n = size(A, 2) - if n != size(B, 2) - throw(DimensionMismatch("A has dimensions $(size(A)) but B has dimensions $(size(B))")) - end - - dotprod = zero(dot(first(A), first(B))) - @inbounds if A.uplo == 'U' && B.uplo == 'U' - for j in 1:n - for i in 1:(j - 1) - dotprod += 2 * real(dot(A.data[i, j], B.data[i, j])) - end - dotprod += dot(A[j, j], B[j, j]) - end - elseif A.uplo == 'L' && B.uplo == 'L' - for j in 1:n - dotprod += dot(A[j, j], B[j, j]) - for i in (j + 1):n - dotprod += 2 * real(dot(A.data[i, j], B.data[i, j])) - end - end - elseif A.uplo == 'U' && B.uplo == 'L' - for j in 1:n - for i in 1:(j - 1) - dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i]))) - end - dotprod += dot(A[j, j], B[j, j]) - end - else - for j in 1:n - dotprod += dot(A[j, j], B[j, j]) - for i in (j + 1):n - dotprod += 2 * real(dot(A.data[i, j], adjoint(B.data[j, i]))) + dotprod = zero(dot(first(A), first(B))) + @inbounds if A.uplo == 'U' && B.uplo == 'U' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * $real(dot(A.data[i, j], B.data[i, j])) + end + dotprod += dot(A[j, j], B[j, j]) + end + elseif A.uplo == 'L' && B.uplo == 'L' + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * $real(dot(A.data[i, j], B.data[i, j])) + end + end + elseif A.uplo == 'U' && B.uplo == 'L' + for j in 1:n + for i in 1:(j - 1) + dotprod += 2 * $real(dot(A.data[i, j], $trans(B.data[j, i]))) + end + dotprod += dot(A[j, j], B[j, j]) + end + else + for j in 1:n + dotprod += dot(A[j, j], B[j, j]) + for i in (j + 1):n + dotprod += 2 * $real(dot(A.data[i, j], $trans(B.data[j, i]))) + end + end end + return dotprod end end - return dotprod end (-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo))