From 71e657dc6a506a03259f871f4bdb0734de0dce3d Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 28 Mar 2023 20:14:51 +0100 Subject: [PATCH 1/8] Add Zeros(T, n...) and Ones(T, n...) constructors (#94( (#233) * Add Zeros(T, n...) and Ones(T, n...) constructors (#94( * increase coverage --- src/FillArrays.jl | 1 + src/fillalgebra.jl | 1 - test/runtests.jl | 8 +++++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index bf090490..f07d2a56 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -262,6 +262,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one)) @inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A)) @inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A)) @inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A) + @inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...) @inline axes(Z::$Typ) = Z.axes @inline size(Z::$Typ) = length.(Z.axes) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 2dec1b61..800e803a 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -86,7 +86,6 @@ end *(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b) *(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b) *(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b) -*(a::ZerosVector, b::AbstractVector) = mult_zeros(a, b) *(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b) *(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b) diff --git a/test/runtests.jl b/test/runtests.jl index 81ba61c5..35b647e2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,7 @@ include("infinitearrays.jl") for T in (Int, Float64) Z = $Typ{T}(5) + @test $Typ(T, 5) ≡ Z @test eltype(Z) == T @test Array(Z) == $funcs(T,5) @test Array{T}(Z) == $funcs(T,5) @@ -34,6 +35,7 @@ include("infinitearrays.jl") @test $Typ(2ones(T,5)) == Z Z = $Typ{T}(5, 5) + @test $Typ(T, 5, 5) ≡ Z @test eltype(Z) == T @test Array(Z) == $funcs(T,5,5) @test Array{T}(Z) == $funcs(T,5,5) @@ -508,9 +510,9 @@ end @test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either @testset "Check multiplication by Adjoint vectors works as expected." begin - @test randn(4, 3)' * Zeros(4) === Zeros(3) - @test randn(4)' * Zeros(4) === zero(Float64) - @test [1, 2, 3]' * Zeros{Int}(3) === zero(Int) + @test randn(4, 3)' * Zeros(4) ≡ Zeros(3) + @test randn(4)' * Zeros(4) ≡ transpose(randn(4)) * Zeros(4) ≡ zero(Float64) + @test [1, 2, 3]' * Zeros{Int}(3) ≡ zero(Int) @test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0) @test_throws DimensionMismatch randn(4)' * Zeros(3) @test Zeros(5)' * randn(5,3) ≡ Zeros(5)'*Zeros(5,3) ≡ Zeros(5)'*Ones(5,3) ≡ Zeros(3)' From 4cfd7c92953e12fe3286d36e639931c513ce3a66 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 28 Mar 2023 20:15:57 +0100 Subject: [PATCH 2/8] Update README.md --- README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 997b3fec..86d7137e 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types: The primary purpose of this package is to present a unified way of constructing -matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use -```julia -julia> CLArray(Zeros(5,5)) -``` -Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer. -Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use +matrices. +For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use ```julia julia> BandedMatrix(Zeros(5,5), (1, 2)) ``` From 1076df16f628df8f3bb1743000d371e166a0f7a0 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 28 Mar 2023 21:27:43 +0100 Subject: [PATCH 3/8] Move over OneElement from Zygote --- src/FillArrays.jl | 4 +++- src/oneelement.jl | 16 ++++++++++++++++ test/runtests.jl | 4 ++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 src/oneelement.jl diff --git a/src/FillArrays.jl b/src/FillArrays.jl index f07d2a56..d6fd54d6 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape import Statistics: mean, std, var, cov, cor -export Zeros, Ones, Fill, Eye, Trues, Falses +export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement import Base: oneto @@ -718,4 +718,6 @@ Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real, fillsimilar(A) end +include("oneelement.jl") + end # module diff --git a/src/oneelement.jl b/src/oneelement.jl new file mode 100644 index 00000000..e69224cb --- /dev/null +++ b/src/oneelement.jl @@ -0,0 +1,16 @@ +""" + OneElement(val, ind, axes) <: AbstractArray +Extremely simple `struct` used for the gradient of scalar `getindex`. +""" +struct OneElement{T,N,I,A} <: AbstractArray{T,N} + val::T + ind::I + axes::A + OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) +end + +# OneElement(val, inds::Int...) = OneElement(val, inds) + +Base.size(A::OneElement) = map(length, A.axes) +Base.axes(A::OneElement) = A.axes +Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 35b647e2..139f2170 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1471,3 +1471,7 @@ end @test cor(Fill(3, 4, 5)) ≈ cor(fill(3, 4, 5)) nans=true @test cor(Fill(3, 4, 5), dims=2) ≈ cor(fill(3, 4, 5), dims=2) nans=true end + +@testset "OneElement" begin + e₁ = OneElement(5) +end \ No newline at end of file From a319e1a0edd96442eec8325ba32e1b84d8ebee6e Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 28 Mar 2023 21:37:27 +0100 Subject: [PATCH 4/8] Add tests --- src/oneelement.jl | 15 +++++++++++++-- test/runtests.jl | 10 +++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index e69224cb..7bf462d3 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -9,8 +9,19 @@ struct OneElement{T,N,I,A} <: AbstractArray{T,N} OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) end -# OneElement(val, inds::Int...) = OneElement(val, inds) +OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz)) +OneElement(val, inds::Int, sz::Int) where N = OneElement(val, (inds,), (sz,)) +OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz) +OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes -Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) \ No newline at end of file +Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) + +Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) = + o.ind == (k,j) ? s : Base.replace_with_centered_mark(s) + +function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N} + @boundscheck checkbounds(A, kj...) + OneElement(convert(T, v), kj, axes(A)) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 139f2170..f429caa8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1473,5 +1473,13 @@ end end @testset "OneElement" begin - e₁ = OneElement(5) + e₁ = OneElement(2, 5) + @test e₁ == [0,1,0,0,0] + + e₁ = OneElement{Float64}(2, 5) + @test e₁ == [0,1,0,0,0] + + @test Base.setindex(Zeros(5), 2, 2) ≡ OneElement(2.0, 2, 5) + @test Base.setindex(Zeros(5,3), 2, 2, 3) ≡ OneElement(2.0, (2,3), (5,3)) + @test_throws BoundsError Base.setindex(Zeros(5), 2, 6) end \ No newline at end of file From 652b41e9d0e1edb074206f68ded3b1cc96991a0c Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 28 Mar 2023 21:43:23 +0100 Subject: [PATCH 5/8] Update oneelement.jl --- src/oneelement.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index 7bf462d3..9af6d4d2 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -10,7 +10,7 @@ struct OneElement{T,N,I,A} <: AbstractArray{T,N} end OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz)) -OneElement(val, inds::Int, sz::Int) where N = OneElement(val, (inds,), (sz,)) +OneElement(val, inds::Int, sz::Int) = OneElement(val, (inds,), (sz,)) OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz) OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) From f73e2d72dec0f65533594d3437750c2c4a5faef5 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 29 Mar 2023 11:56:03 +0100 Subject: [PATCH 6/8] add tests --- src/oneelement.jl | 7 ++++++- test/runtests.jl | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index 9af6d4d2..ac3e7683 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -12,11 +12,16 @@ end OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz)) OneElement(val, inds::Int, sz::Int) = OneElement(val, (inds,), (sz,)) OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz) +OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz)) +OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,)) OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes -Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N} + @boundscheck checkbounds(A, kj...) + ifelse(kj == A.ind, A.val, zero(T)) +end Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) = o.ind == (k,j) ? s : Base.replace_with_centered_mark(s) diff --git a/test/runtests.jl b/test/runtests.jl index a85950fd..6532076a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1493,10 +1493,19 @@ end @testset "OneElement" begin e₁ = OneElement(2, 5) @test e₁ == [0,1,0,0,0] + @test_throws BoundsError e₁[6] e₁ = OneElement{Float64}(2, 5) @test e₁ == [0,1,0,0,0] + v = OneElement{Float64}(2, 3, 4) + @test v == [0,0,2,0] + + V = OneElement(2, (2,3), (3,4)) + @test V == [0 0 0 0; 0 0 2 0; 0 0 0 0] + + @test stringmime("text/plain", V) == "3×4 OneElement{$Int, 2, Tuple{$Int, $Int}, Tuple{OneTo{$Int}, OneTo{$Int}}}:\n ⋅ ⋅ ⋅ ⋅\n ⋅ ⋅ 2 ⋅\n ⋅ ⋅ ⋅ ⋅" + @test Base.setindex(Zeros(5), 2, 2) ≡ OneElement(2.0, 2, 5) @test Base.setindex(Zeros(5,3), 2, 2, 3) ≡ OneElement(2.0, (2,3), (5,3)) @test_throws BoundsError Base.setindex(Zeros(5), 2, 6) From 897ff21e717219aec6ae0e7b6de42d65ef1f4a42 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 29 Mar 2023 15:04:59 +0100 Subject: [PATCH 7/8] Update runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6532076a..62cc433e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1504,7 +1504,7 @@ end V = OneElement(2, (2,3), (3,4)) @test V == [0 0 0 0; 0 0 2 0; 0 0 0 0] - @test stringmime("text/plain", V) == "3×4 OneElement{$Int, 2, Tuple{$Int, $Int}, Tuple{OneTo{$Int}, OneTo{$Int}}}:\n ⋅ ⋅ ⋅ ⋅\n ⋅ ⋅ 2 ⋅\n ⋅ ⋅ ⋅ ⋅" + @test stringmime("text/plain", V) == "3×4 OneElement{$Int, 2, Tuple{$Int, $Int}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n ⋅ ⋅ ⋅ ⋅\n ⋅ ⋅ 2 ⋅\n ⋅ ⋅ ⋅ ⋅" @test Base.setindex(Zeros(5), 2, 2) ≡ OneElement(2.0, 2, 5) @test Base.setindex(Zeros(5,3), 2, 2, 3) ≡ OneElement(2.0, (2,3), (5,3)) From c522b6161d8b4146c055f7f9dd6f0ead4bd64de6 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 29 Mar 2023 15:15:59 +0100 Subject: [PATCH 8/8] add docs --- src/oneelement.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index ac3e7683..abd37f5e 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -1,6 +1,9 @@ """ - OneElement(val, ind, axes) <: AbstractArray -Extremely simple `struct` used for the gradient of scalar `getindex`. + OneElement(val, ind, axesorsize) <: AbstractArray + +Represents an array with the specified axes (if its a tuple of `AbstractUnitRange`s) +or size (if its a tuple of `Integer`s), with a single entry set to `val` and all others equal to zero, +specified by `ind``. """ struct OneElement{T,N,I,A} <: AbstractArray{T,N} val::T @@ -10,10 +13,26 @@ struct OneElement{T,N,I,A} <: AbstractArray{T,N} end OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz)) -OneElement(val, inds::Int, sz::Int) = OneElement(val, (inds,), (sz,)) +""" + OneElement(val, ind::Int, n::Int) + +Creates a length `n` vector where the `ind` entry is equal to `val`, and all other entries are zero. +""" +OneElement(val, ind::Int, len::Int) = OneElement(val, (ind,), (len,)) +""" + OneElement(ind::Int, n::Int) + +Creates a length `n` vector where the `ind` entry is equal to `1`, and all other entries are zero. +""" OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz) OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz)) OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,)) + +""" + OneElement{T}(val, ind::Int, n::Int) + +Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero. +""" OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) Base.size(A::OneElement) = map(length, A.axes)