Skip to content

Commit a318d7d

Browse files
committed
more forgiving signatures
1 parent d78c103 commit a318d7d

File tree

2 files changed

+13
-18
lines changed

2 files changed

+13
-18
lines changed

src/projection.jl

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -367,21 +367,16 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
367367
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
368368
return transpose(project.parent(dy))
369369
end
370-
function (project::ProjectTo{Transpose})(
371-
dx::Tangent{<:Transpose, <:NamedTuple{(:parent,), <:Tuple{AbstractVector}}},
372-
)
373-
return Transpose(project.parent(dx.parent))
370+
function (project::ProjectTo{Transpose})(dx::Tangent) # structural => natural
371+
return dx.parent isa Tangent ? dx : Transpose(project.parent(dx.parent))
374372
end
375373

376374
# Diagonal
377375
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
378376
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
379377
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
380-
# structural => natural standardisation, very conservative signature:
381-
function (project::ProjectTo{Diagonal})(
382-
dx::Tangent{<:Diagonal, <:NamedTuple{(:diag,), <:Tuple{AbstractVector}}},
383-
)
384-
return Diagonal(project.diag(dx.diag))
378+
function (project::ProjectTo{Diagonal})(dx::Tangent) # structural => natural
379+
return dx.diag isa Tangent ? dx.diag : Diagonal(project.diag(dx.diag))
385380
end
386381

387382
# Symmetric
@@ -401,12 +396,8 @@ for (SymHerm, chk, fun) in
401396
dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2
402397
return $SymHerm(project.parent(dz), project.uplo)
403398
end
404-
function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm})
405-
if dx.data isa Tangent
406-
return dx
407-
else
408-
return $SymHerm(project.parent(dx.data))
409-
end
399+
function (project::ProjectTo{$SymHerm})(dx::Tangent) # structural => natural
400+
return dx.data isa Tangent ? dx : $SymHerm(project.parent(dx.data), project.uplo)
410401
end
411402
# This is an example of a subspace which is not a subtype,
412403
# not clear how broadly it's worthwhile to try to support this.
@@ -433,8 +424,9 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT
433424
)
434425
return Diagonal(sub_one(dx.diag))
435426
end
436-
# Convert "structural" `Tangent`s to array-like "natural" tangents
437-
(project::ProjectTo{$UL})(dx::Tangent{<:$UL, NamedTuple{(:data,), <:Tuple{AbstractMatrix}}}) = $UL(dx.data)
427+
function (project::ProjectTo{$UL})(dx::Tangent) # structural => natural
428+
return dx.data isa Tangent ? dx : $UL(project.parent(dx.data), project.uplo)
429+
end
438430
end
439431
end
440432

src/tangent_types/abstract_zero.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ Base.sum(z::AbstractZero; dims=:) = z
3333
Base.reshape(z::AbstractZero, size...) = z
3434

3535
# LinearAlgebra
36-
for f in (:adjoint, :transpose, :Adjoint, :Transpose, :Diagonal)
36+
for f in (
37+
:adjoint, :transpose, :Adjoint, :Transpose, :Diagonal,
38+
:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular,
39+
)
3740
@eval LinearAlgebra.$f(z::AbstractZero) = z
3841
end
3942
for f in (:Symmetric, :Hermitian)

0 commit comments

Comments
 (0)