Skip to content

Commit 532aebd

Browse files
Andy Ferrisandyferris
authored andcommitted
Make strides into a generic trait
Returns `nothing` for non-strided arrays, otherwise gives the give strides in memory. Useful as an extensible trait in generic contexts, and simpler to overload for cases of "wrapped" arrays where "stridedness" can be deferred to the parent rather than a complex (and inextensible) method signature.
1 parent ad129a9 commit 532aebd

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

base/abstractarray.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,9 @@ end
511511
"""
512512
strides(A)
513513
514-
Return a tuple of the memory strides in each dimension.
514+
Return a tuple of the memory strides in each dimension, for an `AbstractArray` with a
515+
strided memory layout. For arrays with a non-strided layout (such as sparse arrays), return
516+
`nothing`.
515517
516518
See also: [`stride`](@ref).
517519
@@ -523,7 +525,7 @@ julia> strides(A)
523525
(1, 3, 12)
524526
```
525527
"""
526-
function strides end
528+
strides(::AbstractArray) = nothing
527529

528530
"""
529531
stride(A, k::Integer)
@@ -544,9 +546,13 @@ julia> stride(A,3)
544546
```
545547
"""
546548
function stride(A::AbstractArray, k::Integer)
547-
st = strides(A)
548-
k ndims(A) && return st[k]
549-
return sum(st .* size(A))
549+
str = strides(A)
550+
if str === nothing
551+
return nothing
552+
else
553+
k ndims(A) && return st[k]
554+
return sum(st .* size(A))
555+
end
550556
end
551557

552558
@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)

base/permuteddimsarray.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ Base.pointer(A::PermutedDimsArray, i::Integer) = throw(ArgumentError("pointer(A,
6262

6363
function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm}
6464
s = strides(parent(A))
65-
ntuple(d->s[perm[d]], Val(N))
65+
if s === nothing
66+
return nothing
67+
else
68+
return ntuple(d->s[perm[d]], Val(N))
69+
end
6670
end
6771
Base.elsize(::Type{<:PermutedDimsArray{<:Any, <:Any, <:Any, <:Any, P}}) where {P} = Base.elsize(P)
6872

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,25 @@ parent(A::AdjOrTrans) = A.parent
218218
vec(v::TransposeAbsVec{<:Number}) = parent(v)
219219
vec(v::AdjointAbsVec{<:Real}) = parent(v)
220220

221+
# provide strides, but only for eltypes that are directly stored in memory (i.e. unaffected
222+
# by recursive `adjoint` and `transpose`, being `Real` and `Number` respectively)
223+
function Base.strides(a::Union{Adjoint{<:Real, <:AbstractVector}, Transpose{<:Number, <:AbstractVector}})
224+
str = strides(a.parent)
225+
if str === nothing
226+
return nothing
227+
else
228+
return (1, str[1])
229+
end
230+
end
231+
function Base.strides(a::Union{Adjoint{<:Real, <:AbstractMatrix}, Transpose{<:Number, <:AbstractMatrix}})
232+
str = strides(a.parent)
233+
if str === nothing
234+
return nothing
235+
else
236+
return (str[2], str[1])
237+
end
238+
end
239+
221240
### concatenation
222241
# preserve Adjoint/Transpose wrapper around vectors
223242
# to retain the associated semantics post-concatenation

0 commit comments

Comments
 (0)