diff --git a/src/utils/device.jl b/src/utils/device.jl
index 627095b..ca3f6c1 100644
--- a/src/utils/device.jl
+++ b/src/utils/device.jl
@@ -13,28 +13,29 @@ send_to_device(::Val{:cpu}, x) = x # cpu(x) is not very efficient! So by defaul
 
 send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x)
 send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x)
-send_to_device(
-    ::Val{:gpu},
-    x::Union{
-        SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
-        Base.ReshapedArray{
-            <:Any,
-            <:Any,
-            <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
-        },
-        Base.ReshapedArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
-        SubArray{
+
+const KnownArrayVariants = Union{
+    SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}},
+    Base.ReshapedArray{
+        <:Any,
+        <:Any,
+        <:SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}},
+    },
+    Base.ReshapedArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}},
+    SubArray{
+        <:Any,
+        <:Any,
+        <:Base.ReshapedArray{
             <:Any,
             <:Any,
-            <:Base.ReshapedArray{
-                <:Any,
-                <:Any,
-                <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
-            },
+            <:SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}},
         },
-        ElasticArray,
     },
-) = CuArray(x)
+}
+
+# https://github.com/JuliaReinforcementLearning/ReinforcementLearningCore.jl/issues/130
+send_to_device(::Val{:cpu}, x::KnownArrayVariants) = Array(x)
+send_to_device(::Val{:gpu}, x::Union{KnownArrayVariants, ElasticArray}) = CuArray(x)
 
 """
     device(model)
diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl
new file mode 100644
index 0000000..717ee51
--- /dev/null
+++ b/src/utils/reservoir_array_buffer.jl
@@ -0,0 +1,46 @@
+export ReservoirArrayBuffer
+
+using Random
+using ElasticArrays
+using MacroTools:@forward
+
+"""
+    ReservoirArrayBuffer{T}(dims::Int...; rng=Random.GLOBAL_RNG)
+
+Array buffer that stores at most `dims[end]` entries along its last dimension.
+Once full, `push!` keeps or discards new entries by reservoir sampling using `rng`.
+"""
+mutable struct ReservoirArrayBuffer{T, N, B<:ElasticArray{T, N}, R<:AbstractRNG} <: AbstractArray{T, N}
+    buffer::B
+    n::Int
+    capacity::Int
+    rng::R
+end
+
+ReservoirArrayBuffer{T}(dims::Int...;rng=Random.GLOBAL_RNG) where {T} = ReservoirArrayBuffer(ElasticArray{T}(undef, dims[1:end-1]..., 0), 0, dims[end], rng)
+
+@forward ReservoirArrayBuffer.buffer Base.size, Base.getindex, Base.length, Base.sizeof, Base.IndexStyle
+
+# TODO: rename all push! to append!
+
+"""
+    push!(b::ReservoirArrayBuffer, x)
+
+Count `x` as one more seen entry. While at most `capacity` entries have been seen,
+append `x`; afterwards draw `i` from `1:n` and overwrite entry `i` only if
+`i <= capacity`. Return `b`.
+"""
+function Base.push!(b::ReservoirArrayBuffer{T, N}, x) where {T, N}
+    b.n += 1
+    if b.n <= b.capacity
+        append!(b.buffer, x)
+    else
+        i = rand(b.rng, 1:b.n)
+        if i <= b.capacity
+            # NOTE(review): relies on ElasticArray internals; `stride` is presumably the element count of one entry
+            stride = b.buffer.kernel_length.divisor
+            b.buffer.data[(stride*(i-1)+1): stride*i] .= x
+        end
+    end
+    return b
+end
diff --git a/src/utils/utils.jl b/src/utils/utils.jl
index be1556d..9ba0032 100644
--- a/src/utils/utils.jl
+++ b/src/utils/utils.jl
@@ -1,5 +1,6 @@
 include("printing.jl")
 include("base.jl")
 include("circular_array_buffer.jl")
+include("reservoir_array_buffer.jl")
 include("device.jl")
 include("sum_tree.jl")
diff --git a/test/utils/reservoir_array_buffer.jl b/test/utils/reservoir_array_buffer.jl
new file mode 100644
index 0000000..c6958fe
--- /dev/null
+++ b/test/utils/reservoir_array_buffer.jl
@@ -0,0 +1,16 @@
+@testset "ReservoirArrayBuffer" begin
+    b = ReservoirArrayBuffer{Int}(3, 2)
+    @test size(b) == (3, 0)
+
+    push!(b, [1,1,1])
+    @test size(b) == (3, 1)
+    @test all(b .== [1;1;1])
+
+    push!(b, [2,2,2])
+    @test size(b) == (3, 2)
+    @test all(b .== [1 2;1 2;1 2])
+
+    @test push!(b, [0,0,0]) === b
+
+    @test size(b) == (3, 2)
+end
diff --git a/test/utils/utils.jl b/test/utils/utils.jl
index 7a7715c..bca9bdb 100644
--- a/test/utils/utils.jl
+++ b/test/utils/utils.jl
@@ -1,3 +1,4 @@
 include("base.jl")
 include("circular_array_buffer.jl")
+include("reservoir_array_buffer.jl")
 include("device.jl")