From 1c29d4ad504e4116a7e0d3f5a7b133069e76b1cd Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Mon, 15 Mar 2021 16:10:19 +0800 Subject: [PATCH] [WIP] ProductedArrays --- Project.toml | 3 +++ src/MappedArrays.jl | 4 ++++ src/ProductedArrays.jl | 24 ++++++++++++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 src/ProductedArrays.jl diff --git a/Project.toml b/Project.toml index 10826fa..7918169 100644 --- a/Project.toml +++ b/Project.toml @@ -2,6 +2,9 @@ name = "MappedArrays" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" version = "0.3.0" +[deps] +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" + [compat] FixedPointNumbers = "0.6.1, 0.7, 0.8" julia = "1" diff --git a/src/MappedArrays.jl b/src/MappedArrays.jl index 3e7f28f..ea08e3c 100644 --- a/src/MappedArrays.jl +++ b/src/MappedArrays.jl @@ -4,6 +4,10 @@ using Base: @propagate_inbounds export AbstractMappedArray, MappedArray, ReadonlyMappedArray, mappedarray, of_eltype +using Reexport +include("ProductedArrays.jl") +@reexport using .ProductedArrays + abstract type AbstractMappedArray{T,N} <: AbstractArray{T,N} end abstract type AbstractMultiMappedArray{T,N} <: AbstractMappedArray{T,N} end diff --git a/src/ProductedArrays.jl b/src/ProductedArrays.jl new file mode 100644 index 0000000..6787e86 --- /dev/null +++ b/src/ProductedArrays.jl @@ -0,0 +1,24 @@ +module ProductedArrays + export ProductedArray + + struct ProductedArray{T, N, AAs<:Tuple{Vararg{AbstractArray}}} <: AbstractArray{T, N} + data::AAs + end + function ProductedArray(data...) + ProductedArray{typeof(map(first, data)), mapreduce(ndims, +, data), typeof(data)}(data) + end + + @inline Base.size(A::ProductedArray) = mapreduce(size, (i,j)->(i...,j...), A.data) + + Base.@propagate_inbounds function Base.getindex(A::ProductedArray{T, N}, inds::Vararg{Int, N}) where {T, N} + map((x, i)->x[i...], A.data, _split_indices(A, inds)) + end + + # TODO: this fails to inline and thus gives about 1.5ns overhead to getindex + @inline function _split_indices(A::ProductedArray{T, N}, inds::NTuple{N, Int}) where {T, N} + # TODO: this line is repeatedly computed + pos = (firstindex(A.data)-1, accumulate(+, map(ndims, A.data))...) + + return ntuple(i->inds[pos[i]+1:pos[i+1]], length(pos)-1) + end +end