From 2f04b9ea1ea914c1c4b3d9a70eebe8970501f2e2 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 25 Sep 2020 12:08:18 +0800 Subject: [PATCH 1/5] add reservoir_array_buffer --- src/utils/device.jl | 37 ++++++++++++++-------------- src/utils/reservoir_array_buffer.jl | 27 ++++++++++++++++++++ src/utils/utils.jl | 1 + test/utils/reservoir_array_buffer.jl | 16 ++++++++++++ test/utils/utils.jl | 1 + 5 files changed, 64 insertions(+), 18 deletions(-) create mode 100644 src/utils/reservoir_array_buffer.jl create mode 100644 test/utils/reservoir_array_buffer.jl diff --git a/src/utils/device.jl b/src/utils/device.jl index 627095b..ca3f6c1 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -13,28 +13,29 @@ send_to_device(::Val{:cpu}, x) = x # cpu(x) is not very efficient! So by defaul send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x) send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x) -send_to_device( - ::Val{:gpu}, - x::Union{ - SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - Base.ReshapedArray{ - <:Any, - <:Any, - <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - }, - Base.ReshapedArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - SubArray{ + +const KnownArrayVariants = Union{ + SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}}, + Base.ReshapedArray{ + <:Any, + <:Any, + <:SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}}, + }, + Base.ReshapedArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}}, + SubArray{ + <:Any, + <:Any, + <:Base.ReshapedArray{ <:Any, <:Any, - <:Base.ReshapedArray{ - <:Any, - <:Any, - <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - }, + <:SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}}, }, - ElasticArray, }, -) = CuArray(x) +} + +# https://github.com/JuliaReinforcementLearning/ReinforcementLearningCore.jl/issues/130 +send_to_device(::Val{:cpu}, x::KnownArrayVariants) = Array(x) +send_to_device(::Val{:gpu}, x::Union{KnownArrayVariants, ElasticArray}) = CuArray(x) """ device(model) diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl new file mode 100644 index 0000000..81dd8d1 --- /dev/null +++ b/src/utils/reservoir_array_buffer.jl @@ -0,0 +1,27 @@ +export ReservoirArrayBuffer + +using Random +using ElasticArrays +using MacroTools:@forward + +mutable struct ReservoirArrayBuffer{T, N, B<:ElasticArray{T, N}, R<:AbstractRNG} <: AbstractArray{T, N} + buffer::B + capacity::Int + rng::R +end + +ReservoirArrayBuffer{T}(dims::Int...;rng=Random.GLOBAL_RNG) where {T} = ReservoirArrayBuffer(ElasticArray{T}(undef, dims[1:end-1]..., 0), dims[end], rng) + +@forward ReservoirArrayBuffer.buffer Base.size, Base.getindex, Base.length, Base.sizeof, Base.IndexStyle + +# TODO: rename all push! to append! + +function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N} + if size(b, N) < b.capacity + append!(b.buffer, x) + else + i = rand(b.rng, 1:size(b, N)) + stride = b.buffer.kernel_length.divisor + b.buffer[(stride*(i-1)+1): stride*i] .= x + end +end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index be1556d..9ba0032 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,5 +1,6 @@ include("printing.jl") include("base.jl") include("circular_array_buffer.jl") +include("reservoir_array_buffer.jl") include("device.jl") include("sum_tree.jl") diff --git a/test/utils/reservoir_array_buffer.jl b/test/utils/reservoir_array_buffer.jl new file mode 100644 index 0000000..c6958fe --- /dev/null +++ b/test/utils/reservoir_array_buffer.jl @@ -0,0 +1,16 @@ +@testset "ReservoirArrayBuffer" begin + b = ReservoirArrayBuffer{Int}(3, 2) + @assert size(b) == (3, 0) + + push!(b, [1,1,1]) + @assert size(b) == (3, 1) + @test all(b .== [1;1;1]) + + push!(b, [2,2,2]) + @assert size(b) == (3, 2) + @test all(b .== [1 2;1 2;1 2]) + + push!(b, [0,0,0]) + + @test size(b) == (3, 2) +end \ No newline at end of file diff --git a/test/utils/utils.jl b/test/utils/utils.jl index 7a7715c..bca9bdb 100644 --- a/test/utils/utils.jl +++ b/test/utils/utils.jl @@ -1,3 +1,4 @@ include("base.jl") include("circular_array_buffer.jl") +include("reservoir_array_buffer.jl") include("device.jl") From ff9482dbbbeb12f69ec7c538b4da08a3935cd9d5 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 25 Sep 2020 22:02:33 +0800 Subject: [PATCH 2/5] bugfix with ReservoirArrayBuffer --- src/utils/reservoir_array_buffer.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl index 81dd8d1..0ecad17 100644 --- a/src/utils/reservoir_array_buffer.jl +++ b/src/utils/reservoir_array_buffer.jl @@ -6,22 +6,26 @@ using MacroTools:@forward mutable struct ReservoirArrayBuffer{T, N, B<:ElasticArray{T, N}, R<:AbstractRNG} <: AbstractArray{T, N} buffer::B + n::Int capacity::Int rng::R end -ReservoirArrayBuffer{T}(dims::Int...;rng=Random.GLOBAL_RNG) where {T} = ReservoirArrayBuffer(ElasticArray{T}(undef, dims[1:end-1]..., 0), dims[end], rng) +ReservoirArrayBuffer{T}(dims::Int...;rng=Random.GLOBAL_RNG) where {T} = ReservoirArrayBuffer(ElasticArray{T}(undef, dims[1:end-1]..., 0), 0, dims[end], rng) @forward ReservoirArrayBuffer.buffer Base.size, Base.getindex, Base.length, Base.sizeof, Base.IndexStyle # TODO: rename all push! to append! function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N} + b.n += 1 if size(b, N) < b.capacity append!(b.buffer, x) else - i = rand(b.rng, 1:size(b, N)) - stride = b.buffer.kernel_length.divisor - b.buffer[(stride*(i-1)+1): stride*i] .= x + i = rand(b.rng, 1:b.n) + if i <= b.capacity + stride = b.buffer.kernel_length.divisor + b.buffer[(stride*(i-1)+1): stride*i] .= x + end end end From d2b23f40c77b2a9870de105c7fba487ea4aca612 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 25 Sep 2020 22:09:44 +0800 Subject: [PATCH 3/5] minor improvement --- src/utils/reservoir_array_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl index 0ecad17..37c831e 100644 --- a/src/utils/reservoir_array_buffer.jl +++ b/src/utils/reservoir_array_buffer.jl @@ -19,7 +19,7 @@ ReservoirArrayBuffer{T}(dims::Int...;rng=Random.GLOBAL_RNG) where {T} = Reservoi function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N} b.n += 1 - if size(b, N) < b.capacity + if b.n <= b.capacity append!(b.buffer, x) else i = rand(b.rng, 1:b.n) From 9cb873c35b2e9a222315209653abdc904bbcc92d Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 25 Sep 2020 22:18:24 +0800 Subject: [PATCH 4/5] minor improvement --- src/utils/reservoir_array_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl index 37c831e..554b5c3 100644 --- a/src/utils/reservoir_array_buffer.jl +++ b/src/utils/reservoir_array_buffer.jl @@ -25,7 +25,7 @@ function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N} i = rand(b.rng, 1:b.n) if i <= b.capacity stride = b.buffer.kernel_length.divisor - b.buffer[(stride*(i-1)+1): stride*i] .= x + @inbounds b.buffer.data[(stride*(i-1)+1): stride*i] .= x end end end From 53d995b99506e9f38861c0ef5ae02b1b874fd3d1 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 25 Sep 2020 22:19:24 +0800 Subject: [PATCH 5/5] minor improvement --- src/utils/reservoir_array_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl index 554b5c3..717ee51 100644 --- a/src/utils/reservoir_array_buffer.jl +++ b/src/utils/reservoir_array_buffer.jl @@ -25,7 +25,7 @@ function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N} i = rand(b.rng, 1:b.n) if i <= b.capacity stride = b.buffer.kernel_length.divisor - @inbounds b.buffer.data[(stride*(i-1)+1): stride*i] .= x + b.buffer.data[(stride*(i-1)+1): stride*i] .= x end end end