diff --git a/src/dense.jl b/src/dense.jl index ccf6bc80..234391e5 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -972,7 +972,8 @@ sqrt(::AbstractMatrix) function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}} if checksquare(A) == 0 return copy(float(A)) - elseif isdiag(A) + elseif isdiag(A) && (T <: Complex || all(x -> x ≥ zero(x), diagview(A))) + # Real Diagonal sqrt requires each diagonal element to be positive return applydiagonal(sqrt, A) elseif ishermitian(A) sqrtHermA = sqrt(Hermitian(A)) diff --git a/test/dense.jl b/test/dense.jl index 1f953ea9..1d4b330c 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -984,6 +984,12 @@ end @testset "sqrt for diagonal" begin A = diagm(0 => [1, 2, 3]) @test sqrt(A)^2 ≈ A + + A = diagm(0 => [1.0, -1.0]) + @test sqrt(A) == diagm(0 => [1.0, 1.0im]) + @test sqrt(A)^2 ≈ A + B = im*A + @test sqrt(B)^2 ≈ B end @testset "issue #40141" begin