From abe1129fab6fe1c963fcef38e9cf950defa6c198 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 24 Jan 2017 14:33:47 +0000 Subject: [PATCH 01/11] type-stable inner loop for sqrtm As suggested by Ralph_Smith on [discourse](https://discourse.julialang.org/t/review-schur-pade-matrix-powers-speedup/1650/6) On my machine: speedup x15 --- base/linalg/triangular.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 61c8ffe34dc35..8c66d1ed93a2d 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1867,6 +1867,16 @@ function logm{T<:Union{Float64,Complex{Float64}}}(A0::UpperTriangular{T}) end logm(A::LowerTriangular) = logm(A.').' +function floop(x,R,i::Int,j::Int) + r = x + @inbounds begin + @simd for k = i+1:j-1 + r -= R[i,k]*R[k,j] + end + end + r +end + function sqrtm{T}(A::UpperTriangular{T}) n = checksquare(A) realmatrix = false @@ -1888,10 +1898,7 @@ function sqrtm{T}(A::UpperTriangular{T}) for j = 1:n R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j])) for i = j-1:-1:1 - r = A[i,j] - for k = i+1:j-1 - r -= R[i,k]*R[k,j] - end + r = floop(A[i,j],R,i,j) r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) end end @@ -1900,14 +1907,10 @@ end function sqrtm{T}(A::UnitUpperTriangular{T}) n = checksquare(A) TT = typeof(sqrt(zero(T))) - R = zeros(TT, n, n) + R = eye(TT, n, n) for j = 1:n - R[j,j] = one(T) for i = j-1:-1:1 - r = A[i,j] - for k = i+1:j-1 - r -= R[i,k]*R[k,j] - end + r = floop(A[i,j],R,i,j) r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) end end From a18f8daa1a21dde14d96240f88d07ed841cf7afa Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 25 Jan 2017 10:13:57 +0000 Subject: [PATCH 02/11] dispatch sqrtm on real-or-not bool As suggested by @stevengj --- base/linalg/triangular.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 8c66d1ed93a2d..aa14e5b26f463 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1876,8 +1876,7 @@ function floop(x,R,i::Int,j::Int) end r end - -function sqrtm{T}(A::UpperTriangular{T}) +function sqrtm(A::UpperTriangular) n = checksquare(A) realmatrix = false if isreal(A) @@ -1889,14 +1888,18 @@ function sqrtm{T}(A::UpperTriangular{T}) end end end + sqrtm(A::UpperTriangular,Val{realmatrix}) +end +function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) if realmatrix TT = typeof(sqrt(zero(T))) else TT = typeof(sqrt(complex(-one(T)))) end + n = checksquare(A) R = zeros(TT, n, n) for j = 1:n - R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j])) + R[j,j] = realmatrix ? sqrt(A[j,j]) : sqrt(complex(A[j,j])) for i = j-1:-1:1 r = floop(A[i,j],R,i,j) r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) From a538250d540c535e940bbcfb6246a11c6f2892ac Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 25 Jan 2017 10:30:22 +0000 Subject: [PATCH 03/11] remove call site type assertion --- base/linalg/triangular.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index aa14e5b26f463..90b96f23acc2b 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1888,7 +1888,7 @@ function sqrtm(A::UpperTriangular) end end end - sqrtm(A::UpperTriangular,Val{realmatrix}) + sqrtm(A,Val{realmatrix}) end function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) if realmatrix From 5718a1c9e78d1e451c95feadcbac2e31a42350b4 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 25 Jan 2017 18:02:00 +0000 Subject: [PATCH 04/11] more type stability --- base/linalg/triangular.jl | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 90b96f23acc2b..bfabbb6ea7852 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1867,21 +1867,11 @@ function logm{T<:Union{Float64,Complex{Float64}}}(A0::UpperTriangular{T}) end logm(A::LowerTriangular) = logm(A.').' -function floop(x,R,i::Int,j::Int) - r = x - @inbounds begin - @simd for k = i+1:j-1 - r -= R[i,k]*R[k,j] - end - end - r -end function sqrtm(A::UpperTriangular) - n = checksquare(A) realmatrix = false if isreal(A) realmatrix = true - for i = 1:n + for i = 1:Base.LinAlg.checksquare(A) if real(A[i,i]) < 0 realmatrix = false break @@ -1894,15 +1884,20 @@ function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) if realmatrix TT = typeof(sqrt(zero(T))) else - TT = typeof(sqrt(complex(-one(T)))) + TT = typeof(sqrt(complex(zero(T)))) end - n = checksquare(A) + n = Base.LinAlg.checksquare(A) R = zeros(TT, n, n) - for j = 1:n - R[j,j] = realmatrix ? sqrt(A[j,j]) : sqrt(complex(A[j,j])) - for i = j-1:-1:1 - r = floop(A[i,j],R,i,j) - r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + @inbounds begin + for j = 1:n + R[j,j] = realmatrix ? sqrt(A[j,j]) : sqrt(complex(A[j,j])) + for i = j-1:-1:1 + r = A[i,j] + zero(TT) + @simd for k = i+1:j-1 + r -= R[i,k]*R[k,j] + end + r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + end end end return UpperTriangular(R) @@ -1911,10 +1906,15 @@ function sqrtm{T}(A::UnitUpperTriangular{T}) n = checksquare(A) TT = typeof(sqrt(zero(T))) R = eye(TT, n, n) - for j = 1:n - for i = j-1:-1:1 - r = floop(A[i,j],R,i,j) - r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + @inbounds begin + for j = 1:n + for i = j-1:-1:1 + r = A[i,j] + zero(TT) + @simd for k = i+1:j-1 + r -= R[i,k]*R[k,j] + end + r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + end end end return UnitUpperTriangular(R) From 60fcb28d0c83b1bd123162c9a4e7b68dc7d37e34 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 26 Jan 2017 17:42:45 +0000 Subject: [PATCH 05/11] casting changes replace `TT` by `t` for the type of the sqrt of a variable of type `T` introduce `tt` as the type of the square of a variable of type `t` N.B. `tt` is not always the same as `T`, it could be `Complex{T}` In the `UnitUpperTriangular` case, some of the complexity should fall away; TODO? --- base/linalg/triangular.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index bfabbb6ea7852..d53acc8b71544 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1871,7 +1871,7 @@ function sqrtm(A::UpperTriangular) realmatrix = false if isreal(A) realmatrix = true - for i = 1:Base.LinAlg.checksquare(A) + for i = 1:checksquare(A) if real(A[i,i]) < 0 realmatrix = false break @@ -1881,18 +1881,15 @@ function sqrtm(A::UpperTriangular) sqrtm(A,Val{realmatrix}) end function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) - if realmatrix - TT = typeof(sqrt(zero(T))) - else - TT = typeof(sqrt(complex(zero(T)))) - end - n = Base.LinAlg.checksquare(A) - R = zeros(TT, n, n) + n = checksquare(A) + t = realmatrix ? typeof(sqrt(zero(T))) : typeof(sqrt(complex(zero(T)))) + R = zeros(t, n, n) + tt = typeof(zero(t)*zero(t)) @inbounds begin for j = 1:n R[j,j] = realmatrix ? sqrt(A[j,j]) : sqrt(complex(A[j,j])) for i = j-1:-1:1 - r = A[i,j] + zero(TT) + r::tt = A[i,j] @simd for k = i+1:j-1 r -= R[i,k]*R[k,j] end @@ -1904,16 +1901,18 @@ function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) end function sqrtm{T}(A::UnitUpperTriangular{T}) n = checksquare(A) - TT = typeof(sqrt(zero(T))) + t = typeof(sqrt(zero(T))) R = eye(TT, n, n) + tt = typeof(zero(t)*zero(t)) + one = R[1,1] @inbounds begin for j = 1:n for i = j-1:-1:1 - r = A[i,j] + zero(TT) + r::tt = A[i,j] @simd for k = i+1:j-1 r -= R[i,k]*R[k,j] end - r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + r==0 || (R[i,j] = r/one) end end end From ce830fe90dadf3597189bf09d1426a011645fd49 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 26 Jan 2017 19:49:27 +0000 Subject: [PATCH 06/11] one -> two correct previous change, thx tkelman --- base/linalg/triangular.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index d53acc8b71544..7d4fd07c92a5b 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1904,7 +1904,7 @@ function sqrtm{T}(A::UnitUpperTriangular{T}) t = typeof(sqrt(zero(T))) R = eye(TT, n, n) tt = typeof(zero(t)*zero(t)) - one = R[1,1] + two = 2*R[1,1] @inbounds begin for j = 1:n for i = j-1:-1:1 @@ -1912,7 +1912,7 @@ function sqrtm{T}(A::UnitUpperTriangular{T}) @simd for k = i+1:j-1 r -= R[i,k]*R[k,j] end - r==0 || (R[i,j] = r/one) + r==0 || (R[i,j] = r/two) end end end From a7e7ce8694cff34ff74f3137e849ba2844f95693 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 26 Jan 2017 21:02:17 +0000 Subject: [PATCH 07/11] div 2 -> mult 1/2 --- base/linalg/triangular.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 7d4fd07c92a5b..bfc8f1314ac28 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1904,7 +1904,7 @@ function sqrtm{T}(A::UnitUpperTriangular{T}) t = typeof(sqrt(zero(T))) R = eye(TT, n, n) tt = typeof(zero(t)*zero(t)) - two = 2*R[1,1] + half = 1/(2*R[1,1]) @inbounds begin for j = 1:n for i = j-1:-1:1 @@ -1912,7 +1912,7 @@ function sqrtm{T}(A::UnitUpperTriangular{T}) @simd for k = i+1:j-1 r -= R[i,k]*R[k,j] end - r==0 || (R[i,j] = r/two) + r==0 || (R[i,j] = half*r) end end end From 075fd38a0579dc017291fde070c3c7e46d9faa78 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 26 Jan 2017 21:58:51 +0000 Subject: [PATCH 08/11] typo --- base/linalg/triangular.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index bfc8f1314ac28..94574bcd43e70 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1902,7 +1902,7 @@ end function sqrtm{T}(A::UnitUpperTriangular{T}) n = checksquare(A) t = typeof(sqrt(zero(T))) - R = eye(TT, n, n) + R = eye(t, n, n) tt = typeof(zero(t)*zero(t)) half = 1/(2*R[1,1]) @inbounds begin From 03ef1b47dcb2a09916091cbb124135f232db739f Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 27 Jan 2017 18:27:10 +0000 Subject: [PATCH 09/11] algebraicness, unwrapping --- base/linalg/triangular.jl | 40 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 94574bcd43e70..46af2ed0f4973 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1880,40 +1880,40 @@ function sqrtm(A::UpperTriangular) end sqrtm(A,Val{realmatrix}) end +# solve the sylvester equation a*x + x*b + c for x when a,b,x are commutative numbers. PR#20214 +sylvester(a::Union{Real,Complex},b::Union{Real,Complex},c::Union{Real,Complex}) = -c / (a + b) function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) - n = checksquare(A) + B = A.data + n = checksquare(B) t = realmatrix ? typeof(sqrt(zero(T))) : typeof(sqrt(complex(zero(T)))) R = zeros(t, n, n) tt = typeof(zero(t)*zero(t)) - @inbounds begin - for j = 1:n - R[j,j] = realmatrix ? sqrt(A[j,j]) : sqrt(complex(A[j,j])) - for i = j-1:-1:1 - r::tt = A[i,j] - @simd for k = i+1:j-1 - r -= R[i,k]*R[k,j] - end - r==0 || (R[i,j] = r / (R[i,i] + R[j,j])) + @inbounds for j = 1:n + R[j,j] = realmatrix ? sqrt(B[j,j]) : sqrt(complex(B[j,j])) + for i = j-1:-1:1 + r::tt = B[i,j] + @simd for k = i+1:j-1 + r -= R[i,k]*R[k,j] end + r==0 || (R[i,j] = sylvester(R[i,i],R[j,j],-r)) end end return UpperTriangular(R) end function sqrtm{T}(A::UnitUpperTriangular{T}) - n = checksquare(A) + B = A.data + n = checksquare(B) t = typeof(sqrt(zero(T))) R = eye(t, n, n) tt = typeof(zero(t)*zero(t)) - half = 1/(2*R[1,1]) - @inbounds begin - for j = 1:n - for i = j-1:-1:1 - r::tt = A[i,j] - @simd for k = i+1:j-1 - r -= R[i,k]*R[k,j] - end - r==0 || (R[i,j] = half*r) + half = inv(R[1,1]+R[1,1]) # for general, algebraic cases. PR#20214 + @inbounds for j = 1:n + for i = j-1:-1:1 + r::tt = B[i,j] + @simd for k = i+1:j-1 + r -= R[i,k]*R[k,j] end + r==0 || (R[i,j] = half*r) end end return UnitUpperTriangular(R) From 7b990f1cc40018e556b4e39bc754fbcc9a9fe9a4 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 31 Jan 2017 11:08:53 +0000 Subject: [PATCH 10/11] sylvester for numbers [ci skip] moved from triangular.jl --- base/linalg/dense.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index 757de7c8d51cb..3c02b3395e0ae 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -874,6 +874,8 @@ function sylvester{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T},C::Stri end sylvester{T<:Integer}(A::StridedMatrix{T},B::StridedMatrix{T},C::StridedMatrix{T}) = sylvester(float(A), float(B), float(C)) +sylvester(a::Union{Real,Complex},b::Union{Real,Complex},c::Union{Real,Complex}) = -c / (a + b) + # AX + XA' + C = 0 """ From e3cf0a9f3b3397ef05f6e9798aca97f5255e84a3 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 31 Jan 2017 11:09:05 +0000 Subject: [PATCH 11/11] move sylvester --- base/linalg/triangular.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 46af2ed0f4973..71a238fbdde2d 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1880,8 +1880,6 @@ function sqrtm(A::UpperTriangular) end sqrtm(A,Val{realmatrix}) end -# solve the sylvester equation a*x + x*b + c for x when a,b,x are commutative numbers. PR#20214 -sylvester(a::Union{Real,Complex},b::Union{Real,Complex},c::Union{Real,Complex}) = -c / (a + b) function sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) B = A.data n = checksquare(B)