Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/components/trajectories/reservoir_trajectory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
export ReservoirTrajectory

using MacroTools: @forward
using Random

mutable struct ReservoirTrajectory{B,R<:AbstractRNG} <: AbstractTrajectory
buffer::B
n::Int
capacity::Int
rng::R
end

@forward ReservoirTrajectory.buffer Base.keys, Base.haskey, Base.getindex

Base.length(x::ReservoirTrajectory) = length(x.buffer[1])

function ReservoirTrajectory(capacity, kw::Pair{Symbol, DataType}...;n=0, rng=Random.GLOBAL_RNG)
buffer = Trajectory(;(s => Vector{t}() for (s,t) in kw)...)
ReservoirTrajectory(buffer, n, capacity, rng)
end

function Base.push!(b::ReservoirTrajectory; kw...)
b.n += 1
if b.n <= b.capacity
push!(b.buffer; kw...)
else
i = rand(b.rng, 1:b.n)
if i <= b.capacity
for (k,v) in kw
b.buffer[k][i] = v
end
end
end
end
1 change: 1 addition & 0 deletions src/components/trajectories/trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("abstract_trajectory.jl")
include("trajectory.jl")
include("reservoir_trajectory.jl")
8 changes: 4 additions & 4 deletions src/utils/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x)
send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x)

const KnownArrayVariants = Union{
SubArray{<:Any,<:Any,<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray}},
SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}},
Base.ReshapedArray{
<:Any,
<:Any,
<:SubArray{
<:Any,
<:Any,
<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray},
<:Union{CircularArrayBuffer,ElasticArray},
},
},
Base.ReshapedArray{
<:Any,
<:Any,
<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray},
<:Union{CircularArrayBuffer,ElasticArray},
},
SubArray{
<:Any,
Expand All @@ -39,7 +39,7 @@ const KnownArrayVariants = Union{
<:SubArray{
<:Any,
<:Any,
<:Union{ReservoirArrayBuffer,CircularArrayBuffer,ElasticArray},
<:Union{CircularArrayBuffer,ElasticArray},
},
},
},
Expand Down
37 changes: 0 additions & 37 deletions src/utils/reservoir_array_buffer.jl

This file was deleted.

1 change: 0 additions & 1 deletion src/utils/utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
include("printing.jl")
include("base.jl")
include("circular_array_buffer.jl")
include("reservoir_array_buffer.jl")
include("device.jl")
include("sum_tree.jl")
37 changes: 37 additions & 0 deletions test/components/trajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,41 @@
empty!(t)
@test length(t[:state]) == 0
end

@testset "ReservoirTrajectory" begin
# test length
t = ReservoirTrajectory(3, :a=>Array{Float64,2}, :b=>Bool)
push!(t;a=rand(2,3),b=rand(Bool))
@test length(t) == 1
push!(t;a=rand(2,3),b=rand(Bool))
@test length(t) == 2
push!(t;a=rand(2,3),b=rand(Bool))
@test length(t) == 3

for _ in 1:100
push!(t;a=rand(2,3),b=rand(Bool))
end

@test length(t) == 3

# test distribution

Random.seed!(110)
k, n, N = 3, 10, 10000
stats = Dict(i => 0 for i in 1:n)
for _ in 1:N
t = ReservoirTrajectory(k, :a=>Array{Int, 2}, :b=>Int)
for i in 1:n
push!(t;a=i .* ones(Int, 2, 3), b=i)
end

for i in 1:length(t)
stats[t[:b][i]] += 1
end
end

for v in values(stats)
@test isapprox(v/N, k/n;atol=0.03)
end
end
end
16 changes: 0 additions & 16 deletions test/utils/reservoir_array_buffer.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/utils/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
include("base.jl")
include("circular_array_buffer.jl")
include("reservoir_array_buffer.jl")
include("device.jl")