diff --git a/src/components/trajectories/reservoir_trajectory.jl b/src/components/trajectories/reservoir_trajectory.jl new file mode 100644 index 0000000..0d3e345 --- /dev/null +++ b/src/components/trajectories/reservoir_trajectory.jl @@ -0,0 +1,34 @@ +export ReservoirTrajectory + +using MacroTools: @forward +using Random + +mutable struct ReservoirTrajectory{B,R<:AbstractRNG} <: AbstractTrajectory + buffer::B + n::Int + capacity::Int + rng::R +end + +@forward ReservoirTrajectory.buffer Base.keys, Base.haskey, Base.getindex + +Base.length(x::ReservoirTrajectory) = length(x.buffer[1]) + +function ReservoirTrajectory(capacity, kw::Pair{Symbol, DataType}...;n=0, rng=Random.GLOBAL_RNG) + buffer = Trajectory(;(s => Vector{t}() for (s,t) in kw)...) + ReservoirTrajectory(buffer, n, capacity, rng) +end + +function Base.push!(b::ReservoirTrajectory; kw...) + b.n += 1 + if b.n <= b.capacity + push!(b.buffer; kw...) + else + i = rand(b.rng, 1:b.n) + if i <= b.capacity + for (k,v) in kw + b.buffer[k][i] = v + end + end + end +end \ No newline at end of file diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl index b7fb270..614c74f 100644 --- a/src/components/trajectories/trajectories.jl +++ b/src/components/trajectories/trajectories.jl @@ -1,2 +1,3 @@ include("abstract_trajectory.jl") include("trajectory.jl") +include("reservoir_trajectory.jl") diff --git a/src/utils/device.jl b/src/utils/device.jl index bef7ac9..ef2fdc2 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -15,20 +15,20 @@ send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x) send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x) const KnownArrayVariants = Union{ - SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}}, + SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, Base.ReshapedArray{ <:Any, <:Any, <:SubArray{ <:Any, <:Any, - <:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}, + <:Union{CircularArrayBuffer,ElasticArray}, }, }, Base.ReshapedArray{ <:Any, <:Any, - <:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}, + <:Union{CircularArrayBuffer,ElasticArray}, }, SubArray{ <:Any, @@ -39,7 +39,7 @@ const KnownArrayVariants = Union{ <:SubArray{ <:Any, <:Any, - <:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}, + <:Union{CircularArrayBuffer,ElasticArray}, }, }, }, diff --git a/src/utils/reservoir_array_buffer.jl b/src/utils/reservoir_array_buffer.jl deleted file mode 100644 index 16d6f02..0000000 --- a/src/utils/reservoir_array_buffer.jl +++ /dev/null @@ -1,37 +0,0 @@ -export ReservoirArrayBuffer - -using Random -using ElasticArrays -using MacroTools: @forward - -mutable struct ReservoirArrayBuffer{T,N,B<:ElasticArray{T,N},R<:AbstractRNG} <: - AbstractArray{T,N} - buffer::B - n::Int - capacity::Int - rng::R -end - -ReservoirArrayBuffer{T}(dims::Int...; rng = Random.GLOBAL_RNG) where {T} = - ReservoirArrayBuffer(ElasticArray{T}(undef, dims[1:end-1]..., 0), 0, dims[end], rng) - -@forward ReservoirArrayBuffer.buffer Base.size, -Base.getindex, -Base.length, -Base.sizeof, -Base.IndexStyle - -# TODO: rename all push! to append! - -function Base.push!(b::ReservoirArrayBuffer{T,N}, x) where {T,N} - b.n += 1 - if b.n <= b.capacity - push!(b.buffer, x) - else - i = rand(b.rng, 1:b.n) - if i <= b.capacity - stride = b.buffer.kernel_length.divisor - b.buffer.data[(stride*(i-1)+1):stride*i] .= x - end - end -end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 9ba0032..be1556d 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,6 +1,5 @@ include("printing.jl") include("base.jl") include("circular_array_buffer.jl") -include("reservoir_array_buffer.jl") include("device.jl") include("sum_tree.jl") diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl index c7e39a2..8ec8574 100644 --- a/test/components/trajectories.jl +++ b/test/components/trajectories.jl @@ -207,4 +207,41 @@ empty!(t) @test length(t[:state]) == 0 end + + @testset "ReservoirTrajectory" begin + # test length + t = ReservoirTrajectory(3, :a=>Array{Float64,2}, :b=>Bool) + push!(t;a=rand(2,3),b=rand(Bool)) + @test length(t) == 1 + push!(t;a=rand(2,3),b=rand(Bool)) + @test length(t) == 2 + push!(t;a=rand(2,3),b=rand(Bool)) + @test length(t) == 3 + + for _ in 1:100 + push!(t;a=rand(2,3),b=rand(Bool)) + end + + @test length(t) == 3 + + # test distribution + + Random.seed!(110) + k, n, N = 3, 10, 10000 + stats = Dict(i => 0 for i in 1:n) + for _ in 1:N + t = ReservoirTrajectory(k, :a=>Array{Int, 2}, :b=>Int) + for i in 1:n + push!(t;a=i .* ones(Int, 2, 3), b=i) + end + + for i in 1:length(t) + stats[t[:b][i]] += 1 + end + end + + for v in values(stats) + @test isapprox(v/N, k/n;atol=0.03) + end + end end diff --git a/test/utils/reservoir_array_buffer.jl b/test/utils/reservoir_array_buffer.jl deleted file mode 100644 index 321a5dc..0000000 --- a/test/utils/reservoir_array_buffer.jl +++ /dev/null @@ -1,16 +0,0 @@ -@testset "ReservoirArrayBuffer" begin - b = ReservoirArrayBuffer{Int}(3, 2) - @assert size(b) == (3, 0) - - push!(b, [1, 1, 1]) - @assert size(b) == (3, 1) - @test all(b .== [1; 1; 1]) - - push!(b, [2, 2, 2]) - @assert size(b) == (3, 2) - @test all(b .== [1 2; 1 2; 1 2]) - - push!(b, [0, 0, 0]) - - @test size(b) == (3, 2) -end diff --git a/test/utils/utils.jl b/test/utils/utils.jl index bca9bdb..7a7715c 100644 --- a/test/utils/utils.jl +++ b/test/utils/utils.jl @@ -1,4 +1,3 @@ include("base.jl") include("circular_array_buffer.jl") -include("reservoir_array_buffer.jl") include("device.jl")