Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/extensions/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ end

"""
    orthogonal(dims...)

Forward to `orthogonal(Random.GLOBAL_RNG, dims...)`, i.e. call the RNG-taking method
(defined outside this chunk) with the global random number generator.
"""
function orthogonal(dims...)
    return orthogonal(Random.GLOBAL_RNG, dims...)
end

"""
    orthogonal(rng::AbstractRNG)

Return a closure that accepts `dims...` and forwards them, together with the captured
`rng`, to `orthogonal(rng, dims...)`.
"""
function orthogonal(rng::AbstractRNG)
    init(dims...) = orthogonal(rng, dims...)
    return init
end

"""
    batch!(data, xs)

Store every element of `xs` into `data` in place and return `data`.

The `i`-th element of `xs` is assigned to the location addressed by
`Flux.batchindex(data, i)`.
"""
function batch!(data, xs)
    foreach(enumerate(xs)) do (slot, item)
        data[Flux.batchindex(data, slot)...] = item
    end
    return data
end
102 changes: 75 additions & 27 deletions src/policies/agents/trajectories/trajectory_extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ end

"""
    AbstractSampler{traces}

Abstract supertype of the samplers in this file (`BatchSampler`, `NStepBatchSampler`).

The `traces` type parameter is what the `fetch!` methods dispatch on and use as the
field names of the `NamedTuple` batch they build.
"""
abstract type AbstractSampler{traces} end

# TODO: deprecate this method in favor of `(s::AbstractSampler)(traj)`

"""
sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=Val(keys(trajectory))])

Expand All @@ -46,52 +48,86 @@ function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler)
sample(Random.GLOBAL_RNG, t, sampler)
end

# TODO: add an async batch sampler to pre-fetch next batch

#####
## BatchSampler
#####

"""
    BatchSampler{traces}

Sampler that draws `batch_size` random indices from a trajectory and gathers the
corresponding traces into a reusable buffer.

The type is mutable because `fetch!` assigns `cache` the first time a batch is
gathered and then overwrites that buffer in place on later calls.

# Fields
- `batch_size::Int`: number of indices drawn per sample.
- `cache`: `nothing` until the first `fetch!`, afterwards the batch buffer.
- `rng`: random number generator used when the sampler is called as a function.
"""
mutable struct BatchSampler{traces} <: AbstractSampler{traces}
    batch_size::Int
    cache::Any
    rng::Any
end

"""
    BatchSampler(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG)
    BatchSampler{T}(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG)

Construct a `BatchSampler`. Without an explicit trace parameter `T`, `SARTSA` is used.

# Keywords
- `cache`: initial batch buffer; `nothing` lets the first `fetch!` allocate it.
- `rng`: random number generator stored in the sampler.
"""
function BatchSampler(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG)
    return BatchSampler{SARTSA}(batch_size, cache, rng)
end

function BatchSampler{T}(batch_size::Int; cache=nothing, rng=Random.GLOBAL_RNG) where {T}
    return BatchSampler{T}(batch_size, cache, rng)
end

# Calling the sampler on a trajectory draws a batch using the sampler's own RNG.
function (s::BatchSampler)(t::AbstractTrajectory)
    return sample(s.rng, t, s)
end

# TODO: deprecate
# Draw `s.batch_size` indices uniformly (with replacement) from `1:length(t)`, gather
# the matching traces into `s.cache` via `fetch!`, and return `(inds, s.cache)`.
#
# NOTE: the returned batch is the sampler's cache itself, not a copy; the next call
# overwrites its contents in place.
function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler)
    inds = rand(rng, 1:length(t), s.batch_size)
    fetch!(s, t, inds)
    return inds, s.cache
end

"""
    fetch!(s::BatchSampler{traces}, t::CircularVectorSARTSATrajectory, inds::Vector{Int})

Gather the entries at `inds` of every trace named in `traces` from `t` into `s.cache`.

On the first call (`s.cache` is `nothing`) each trace is batched with `Flux.batch` and
the result is stored as the cache. On later calls the existing cache arrays are
overwritten in place with `batch!`. Returns the `NamedTuple` of batched arrays.
"""
function fetch!(
    s::BatchSampler{traces},
    t::CircularVectorSARTSATrajectory,
    inds::Vector{Int},
) where {traces}
    # Views avoid copying the selected elements before they are batched.
    batch = NamedTuple{traces}(view(t[x], inds) for x in traces)
    if isnothing(s.cache)
        s.cache = map(Flux.batch, batch)
        return s.cache
    end
    return map(s.cache, batch) do dest, src
        batch!(dest, src)
    end
end

"""
    fetch!(s::BatchSampler{SARTS}, t::CircularArraySARTTrajectory, inds::Vector{Int})

Gather a `SARTS` batch from `t` at `inds` into `s.cache`.

The `SART` traces are read at `inds`; the trailing state entry is read from
`t[:state]` at `inds .+ 1`. On the first call (`s.cache` is `nothing`) every view is
materialised with `convert(Array, x)` and stored as the cache. On later calls the
cache arrays are overwritten with `copyto!`. Returns the `NamedTuple` of batch arrays.
"""
function fetch!(
    s::BatchSampler{SARTS},
    t::CircularArraySARTTrajectory,
    inds::Vector{Int},
)
    batch = NamedTuple{SARTS}((
        (consecutive_view(t[x], inds) for x in SART)...,
        consecutive_view(t[:state], inds .+ 1),
    ))
    if isnothing(s.cache)
        s.cache = map(x -> convert(Array, x), batch)
        return s.cache
    end
    return map(copyto!, s.cache, batch)
end

#####
## NStepBatchSampler
#####

# Keyword-constructed sampler for n-step batches. Mutable so that `fetch!` can assign
# `cache` on first use and reuse it afterwards.
Base.@kwdef mutable struct NStepBatchSampler{traces} <: AbstractSampler{traces}
    # Required (no default); presumably the reward discount factor. TODO confirm.
    γ::Float32
    # Offset between a sampled index and its "next" index; also passed as the
    # `n_horizon` of the reward/terminal views.
    n::Int = 1
    # Number of indices drawn per sample.
    batch_size::Int = 32
    # Passed as `n_stack` for state views; when set it is also the smallest index
    # that may be sampled.
    stack_size::Union{Nothing,Int} = nothing
    # Stored RNG; the `sample` methods visible in this file take their RNG explicitly.
    rng::Any = Random.GLOBAL_RNG
    # `nothing` until the first `fetch!`, afterwards the reusable batch buffer.
    cache::Any = nothing
end

# TODO: deprecate
# Draw `s.batch_size` start indices (with replacement) and gather the n-step batch.
#
# Valid start indices end at `length(t) - s.n + 1`. They begin at `s.stack_size` when
# frame stacking is enabled, otherwise at 1. Returns `(inds, batch)` where `batch` is
# whatever `fetch!` returns.
function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler)
    first_start = isnothing(s.stack_size) ? 1 : s.stack_size
    last_start = length(t) - s.n + 1
    inds = rand(rng, first_start:last_start, s.batch_size)
    return inds, fetch!(s, t, inds)
end

function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory, s::NStepBatchSampler)
Expand All @@ -107,24 +143,26 @@ function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory, s::NStepBa
inds[i] = ind
priorities[i] = p
end
inds, (priority = priorities, select(inds, t.traj, s)...)
inds, (priority = priorities, fetch!(s, t.traj, inds)...)
end

function select(
inds::Vector{Int},
function fetch!(
sampler::NStepBatchSampler{traces},
traj::CircularArraySARTTrajectory,
s::NStepBatchSampler{traces},
inds::Vector{Int},
) where {traces}
γ, n, bz, sz = s.γ, s.n, s.batch_size, s.stack_size
γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size
cache = sampler.cache
next_inds = inds .+ n

s = convert(Array, consecutive_view(traj[:state], inds; n_stack = sz))
a = convert(Array, consecutive_view(traj[:action], inds))
s′ = convert(Array, consecutive_view(traj[:state], next_inds; n_stack = sz))
s = consecutive_view(traj[:state], inds; n_stack = sz)
a = consecutive_view(traj[:action], inds)
s′ = consecutive_view(traj[:state], next_inds; n_stack = sz)

consecutive_rewards = consecutive_view(traj[:reward], inds; n_horizon = n)
consecutive_terminals = consecutive_view(traj[:terminal], inds; n_horizon = n)
r, t = zeros(Float32, bz), fill(false, bz)
r = isnothing(cache) ? zeros(Float32, bz) : cache.reward
t = isnothing(cache) ? fill(false, bz) : cache.terminal

# make sure that we only consider experiences in current episode
for i in 1:bz
Expand All @@ -139,12 +177,22 @@ function select(
end

if traces == SARTS
NamedTuple{SARTS}((s, a, r, t, s′))
batch = NamedTuple{SARTS}((s, a, r, t, s′))
elseif traces == SLARTSL
l = convert(Array, consecutive_view(traj[:legal_actions_mask], inds))
l′ = convert(Array, consecutive_view(traj[:next_legal_actions_mask], next_inds))
NamedTuple{SLARTSL}((s, l, a, r, t, s′, l′))
l = consecutive_view(traj[:legal_actions_mask], inds)
l′ = consecutive_view(traj[:next_legal_actions_mask], next_inds)
batch = NamedTuple{SLARTSL}((s, l, a, r, t, s′, l′))
else
@error "unsupported traces $traces"
end

if isnothing(sampler.cache)
sampler.cache = map(batch) do x
convert(Array, x)
end
else
map(sampler.cache, batch) do dest, src
copyto!(dest, src)
end
end
end