diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 0475ba1..7375274 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -22,3 +22,10 @@ end orthogonal(dims...) = orthogonal(Random.GLOBAL_RNG, dims...) orthogonal(rng::AbstractRNG) = (dims...) -> orthogonal(rng, dims...) + +function batch!(data, xs) + for (i, x) in enumerate(xs) + data[Flux.batchindex(data, i)...] = x + end + data +end \ No newline at end of file diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 8f59a01..7fcbf2b 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -34,6 +34,8 @@ end abstract type AbstractSampler{traces} end +# TODO: deprecate this method with `(s::AbstractSampler)(traj)` instead + """ sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=Val(keys(trajectory))]) @@ -46,52 +48,86 @@ function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) sample(Random.GLOBAL_RNG, t, sampler) end +# TODO: add an async batch sampler to pre-fetch next batch + ##### ## BatchSampler ##### -struct BatchSampler{traces} <: AbstractSampler{traces} +mutable struct BatchSampler{traces} <: AbstractSampler{traces} batch_size::Int + cache::Any + rng::Any end -BatchSampler(batch_size::Int) = BatchSampler{SARTSA}(batch_size) +BatchSampler(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG) = BatchSampler{SARTSA}(batch_size, cache, rng) +BatchSampler{T}(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG) where T = BatchSampler{T}(batch_size, cache, rng) + +(s::BatchSampler)(t::AbstractTrajectory) = sample(s.rng, t, s) +# TODO: deprecate function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler) inds = rand(rng, 1:length(t), s.batch_size) - inds, select(inds, t, s) + fetch!(s, t, inds) + inds, s.cache end -function select( - inds::Vector{Int}, - t::CircularVectorSARTSATrajectory, +function fetch!( s::BatchSampler{traces}, + t::CircularVectorSARTSATrajectory, + inds::Vector{Int}, ) where {traces} - NamedTuple{SARTSA}(Flux.batch(view(t[x], inds)) for x in traces) + batch = NamedTuple{traces}(view(t[x], inds) for x in traces) + if isnothing(s.cache) + s.cache = map(Flux.batch, batch) + else + map(s.cache, batch) do dest, src + batch!(dest, src) + end + end end -function select(inds::Vector{Int}, t::CircularArraySARTTrajectory, s::BatchSampler{SARTS}) - NamedTuple{SARTS}(( - (convert(Array, consecutive_view(t[x], inds)) for x in SART)..., - convert(Array, consecutive_view(t[:state], inds .+ 1)), - )) +function fetch!( + s::BatchSampler{SARTS}, + t::CircularArraySARTTrajectory, + inds::Vector{Int} +) + batch = NamedTuple{SARTS}( + ( + (consecutive_view(t[x], inds) for x in SART)..., + consecutive_view(t[:state], inds .+ 1), + ) + ) + if isnothing(s.cache) + s.cache = map(batch) do x + convert(Array, x) + end + else + map(s.cache, batch) do dest, src + copyto!(dest, src) + end + end end ##### ## NStepBatchSampler ##### -Base.@kwdef struct NStepBatchSampler{traces} <: AbstractSampler{traces} +Base.@kwdef mutable struct NStepBatchSampler{traces} <: AbstractSampler{traces} γ::Float32 n::Int = 1 batch_size::Int = 32 stack_size::Union{Nothing,Int} = nothing + rng::Any = Random.GLOBAL_RNG + cache::Any = nothing end +# TODO:deprecate function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler) valid_range = isnothing(s.stack_size) ? (1:(length(t)-s.n+1)) : (s.stack_size:(length(t)-s.n+1)) inds = rand(rng, valid_range, s.batch_size) - inds, select(inds, t, s) + inds, fetch!(s, t, inds) end function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory, s::NStepBatchSampler) @@ -107,24 +143,26 @@ function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory, s::NStepBa inds[i] = ind priorities[i] = p end - inds, (priority = priorities, select(inds, t.traj, s)...) + inds, (priority = priorities, fetch!(s, t.traj, inds)...) end -function select( - inds::Vector{Int}, +function fetch!( + sampler::NStepBatchSampler{traces}, traj::CircularArraySARTTrajectory, - s::NStepBatchSampler{traces}, + inds::Vector{Int}, ) where {traces} - γ, n, bz, sz = s.γ, s.n, s.batch_size, s.stack_size + γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size + cache = sampler.cache next_inds = inds .+ n - s = convert(Array, consecutive_view(traj[:state], inds; n_stack = sz)) - a = convert(Array, consecutive_view(traj[:action], inds)) - s′ = convert(Array, consecutive_view(traj[:state], next_inds; n_stack = sz)) + s = consecutive_view(traj[:state], inds; n_stack = sz) + a = consecutive_view(traj[:action], inds) + s′ = consecutive_view(traj[:state], next_inds; n_stack = sz) consecutive_rewards = consecutive_view(traj[:reward], inds; n_horizon = n) consecutive_terminals = consecutive_view(traj[:terminal], inds; n_horizon = n) - r, t = zeros(Float32, bz), fill(false, bz) + r = isnothing(cache) ? zeros(Float32, bz) : cache.reward + t = isnothing(cache) ? fill(false, bz) : cache.terminal # make sure that we only consider experiences in current episode for i in 1:bz @@ -139,12 +177,22 @@ function select( end if traces == SARTS - NamedTuple{SARTS}((s, a, r, t, s′)) + batch = NamedTuple{SARTS}((s, a, r, t, s′)) elseif traces == SLARTSL - l = convert(Array, consecutive_view(traj[:legal_actions_mask], inds)) - l′ = convert(Array, consecutive_view(traj[:next_legal_actions_mask], next_inds)) - NamedTuple{SLARTSL}((s, l, a, r, t, s′, l′)) + l = consecutive_view(traj[:legal_actions_mask], inds) + l′ = consecutive_view(traj[:next_legal_actions_mask], next_inds) + batch = NamedTuple{SLARTSL}((s, l, a, r, t, s′, l′)) else @error "unsupported traces $traces" end + + if isnothing(sampler.cache) + sampler.cache = map(batch) do x + convert(Array, x) + end + else + map(sampler.cache, batch) do dest, src + copyto!(dest, src) + end + end end