Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/components/policies/policies.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include("V_based_policy.jl")
include("Q_based_policy.jl")
include("off_policy.jl")
include("static_policy.jl")
24 changes: 24 additions & 0 deletions src/components/policies/static_policy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
export StaticPolicy

using MacroTools: @forward

"""
    StaticPolicy(policy)

Wrap `policy` in a thin shell that still acts like the wrapped policy when called
with an environment, but turns the generic `update!(::StaticPolicy, args...)` call
into a no-op. The method `update!(::StaticPolicy, ::Params)` is the exception: it is
passed through to the wrapped policy.

Usually used in the distributed mode as a worker.
"""
struct StaticPolicy{P<:AbstractPolicy} <: AbstractPolicy
    p::P
end

# Acting with a `StaticPolicy` simply calls the wrapped policy on `env`.
function (π::StaticPolicy)(env)
    wrapped = π.p
    return wrapped(env)
end

# Delegate `RLBase.get_priority` and `RLBase.get_prob` calls made on a
# `StaticPolicy` to the wrapped policy stored in field `p` (MacroTools' `@forward`).
@forward StaticPolicy.p RLBase.get_priority, RLBase.get_prob

# Generic updates are deliberately ignored: a `StaticPolicy` does not learn.
function RLBase.update!(p::StaticPolicy, args...)
    return nothing
end

# Updating with a `Params` collection is the one case that is forwarded to the
# wrapped policy instead of being ignored.
function RLBase.update!(p::StaticPolicy, ps::Params)
    wrapped = p.p
    return update!(wrapped, ps)
end

# Register `StaticPolicy` with Flux's functor machinery so that Flux utilities can
# recurse into the wrapped policy held in its field.
Flux.@functor StaticPolicy
1 change: 1 addition & 0 deletions src/components/trajectories/trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include("abstract_trajectory.jl")
include("trajectory.jl")
include("trajectory_extension.jl")
include("reservoir_trajectory.jl")
65 changes: 63 additions & 2 deletions src/components/trajectories/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ export Trajectory,
ElasticCompactSARTSATrajectory,
CircularCompactPSARTSATrajectory,
CircularCompactSALRTSALTrajectory,
CircularCompactPSALRTSALTrajectory
CircularCompactPSALRTSALTrajectory,
VectSARTSATrajectory,
CircularSARTSATrajectory

using MacroTools: @forward
using ElasticArrays
Expand Down Expand Up @@ -300,6 +302,66 @@ function VectCompactSARTSATrajectory(; reward_type = Float32, terminal_type = Bo
)
end

#####
# VectSARTSATrajectory
#####

# Type alias: any `Trajectory` parameterized by a `NamedTuple` whose fields are, in
# order, `state`, `action`, `reward`, `terminal`, `next_state` and `next_action`,
# each of them stored in some `Vector`.
const VectSARTSATrajectory = Trajectory{
<:NamedTuple{
(:state, :action, :reward, :terminal, :next_state, :next_action),
<:Tuple{<:Vector, <:Vector, <:Vector, <:Vector, <:Vector, <:Vector}}}

# Keyword constructor: builds a `Trajectory` with six empty `Vector` traces.
# `next_state_type` and `next_action_type` default to `state_type` and
# `action_type` respectively.
function VectSARTSATrajectory(;
    state_type = Int,
    action_type = Int,
    reward_type = Float32,
    terminal_type = Bool,
    next_state_type = state_type,
    next_action_type = action_type,
)
    # One empty, growable trace per element type.
    empty_trace(T) = Vector{T}()
    return Trajectory(;
        state = empty_trace(state_type),
        action = empty_trace(action_type),
        reward = empty_trace(reward_type),
        terminal = empty_trace(terminal_type),
        next_state = empty_trace(next_state_type),
        next_action = empty_trace(next_action_type),
    )
end

# The length of the trajectory is the length of its `:state` trace.
function Base.length(t::VectSARTSATrajectory)
    states = t[:state]
    return length(states)
end

#####
# CircularSARTSATrajectory
#####

# Type alias: any `Trajectory` parameterized by a `NamedTuple` whose fields are, in
# order, `state`, `action`, `reward`, `terminal`, `next_state` and `next_action`,
# each of them stored in a `CircularArrayBuffer`.
const CircularSARTSATrajectory = Trajectory{
<:NamedTuple{
(:state, :action, :reward, :terminal, :next_state, :next_action),
<:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer}}}

# Keyword constructor: builds a `Trajectory` with six `CircularArrayBuffer` traces.
# Every buffer is created with its `*_size` dimensions followed by `capacity`.
# `next_state` / `next_action` reuse the element type and size of `state` / `action`.
function CircularSARTSATrajectory(;
    capacity,
    state_type = Float32,
    state_size = (),
    action_type = Int,
    action_size = (),
    reward_type = Float32,
    reward_size = (),
    terminal_type = Bool,
    terminal_size = (),
)
    # All traces share the same trailing `capacity` dimension.
    ring(T, dims) = CircularArrayBuffer{T}(dims..., capacity)
    return Trajectory(
        state = ring(state_type, state_size),
        action = ring(action_type, action_size),
        reward = ring(reward_type, reward_size),
        terminal = ring(terminal_type, terminal_size),
        next_state = ring(state_type, state_size),
        next_action = ring(action_type, action_size),
    )
end

# The length of the trajectory is the length of its `:state` trace.
function Base.length(t::CircularSARTSATrajectory)
    states = t[:state]
    return length(states)
end

#####
# CircularCompactSARTSATrajectory
#####
Expand Down Expand Up @@ -356,7 +418,6 @@ function ElasticCompactSARTSATrajectory(;
)
end


#####
# CircularCompactSALRTSALTrajectory
#####
Expand Down
52 changes: 52 additions & 0 deletions src/components/trajectories/trajectory_extension.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
export NStepInserter, UniformBatchSampler

using Random

#####
# Inserters
#####

# Supertype of all inserters (strategies for pushing data into a trajectory).
abstract type AbstractInserter end

"""
    NStepInserter(; n = 1)

Inserter configured by a step count `n` (defaults to `1`).
"""
Base.@kwdef struct NStepInserter <: AbstractInserter
    n::Int = 1
end

"""
    push!(t::CircularSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, inserter::NStepInserter)

Copy transitions from `𝕥` into `t`.

With `N = length(𝕥[:terminal])` and `n = inserter.n`, one transition is pushed for
every start index `i` in `1:(N - n + 1)`: `state`, `action`, `reward` and `terminal`
are read at index `i`, while `next_state` and `next_action` are read at index
`i + n - 1`. Nothing is pushed when that range is empty.

Return `t`, following the convention of `Base.push!`.
"""
function Base.push!(
    t::CircularSARTSATrajectory,
    𝕥::CircularCompactSARTSATrajectory,
    inserter::NStepInserter,
)
    N = length(𝕥[:terminal])
    n = inserter.n
    # NOTE(review): `reward` and `terminal` come from step `i` only; they are not
    # aggregated over the `n` steps — confirm this is the intended semantics.
    for i in 1:(N - n + 1)
        j = i + n - 1  # index of the transition that supplies the "next" entries
        push!(
            t;
            state = select_last_dim(𝕥[:state], i),
            action = select_last_dim(𝕥[:action], i),
            reward = select_last_dim(𝕥[:reward], i),
            terminal = select_last_dim(𝕥[:terminal], i),
            next_state = select_last_dim(𝕥[:next_state], j),
            next_action = select_last_dim(𝕥[:next_action], j),
        )
    end
    # Previously the method returned `nothing` (the value of the `for` loop).
    return t
end

#####
# Samplers
#####

# Supertype of all samplers (strategies for drawing data from a trajectory).
abstract type AbstractSampler end

"""
    UniformBatchSampler(batch_size)

Sampler that draws `batch_size` transitions per call to `sample`.
"""
struct UniformBatchSampler <: AbstractSampler
    batch_size::Int
end

# Convenience method: sample using `Random.GLOBAL_RNG` when no RNG is supplied.
function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler)
    return sample(Random.GLOBAL_RNG, t, sampler)
end

# Draw `sampler.batch_size` indices from `1:length(t)` with `rng` (independent
# draws, so an index may repeat) and return a `NamedTuple` holding, for each trace,
# the selected entries combined with `Flux.batch`.
function StatsBase.sample(
    rng::AbstractRNG,
    t::Union{VectSARTSATrajectory,CircularSARTSATrajectory},
    sampler::UniformBatchSampler,
)
    picks = rand(rng, 1:length(t), sampler.batch_size)
    # The same index set is applied to every trace so the fields stay aligned.
    gather(trace) = Flux.batch(trace[picks])
    return (
        state = gather(t[:state]),
        action = gather(t[:action]),
        reward = gather(t[:reward]),
        terminal = gather(t[:terminal]),
        next_state = gather(t[:next_state]),
        next_action = gather(t[:next_action]),
    )
end
20 changes: 19 additions & 1 deletion src/core/hooks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ export AbstractHook,
CumulativeReward,
TimePerStep,
DoEveryNEpisode,
DoEveryNStep
DoEveryNStep,
UploadTrajectoryEveryNStep

"""
A hook is called at different stages during a [`run`](@ref) to allow users to inject customized runtime logic.
Expand Down Expand Up @@ -337,3 +338,20 @@ function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env)
hook.f(hook.t, agent, env)
end
end

"""
    UploadTrajectoryEveryNStep(; mailbox, n, t = -1, sealer = deepcopy)

Hook that periodically puts the agent's trajectory into `mailbox`.

# Fields
- `mailbox`: destination; the hook calls `put!(mailbox, ...)` on it.
- `n`: upload period, counted in `PostActStage` calls.
- `t`: internal call counter, starting at `-1`.
- `sealer`: function applied to `agent.trajectory` before it is put into the
  mailbox (`deepcopy` by default).
"""
Base.@kwdef mutable struct UploadTrajectoryEveryNStep{M,S} <: AbstractHook
    mailbox::M
    n::Int
    t::Int = -1
    sealer::S = deepcopy
end

# Advance the counter on every `PostActStage`; once it is positive and a multiple
# of `hook.n`, seal the agent's trajectory and put it into the mailbox. Because the
# counter starts at `-1`, the very first call (counter becomes 0) never uploads.
function (hook::UploadTrajectoryEveryNStep)(::PostActStage, agent, env)
    hook.t += 1
    is_due = hook.t > 0 && hook.t % hook.n == 0
    is_due || return nothing
    sealed = hook.sealer(agent.trajectory)
    return put!(hook.mailbox, sealed)
end
34 changes: 27 additions & 7 deletions src/core/stop_conditions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export StopAfterStep, StopAfterEpisode, StopWhenDone, ComposedStopCondition
export StopAfterStep, StopAfterEpisode, StopWhenDone, ComposedStopCondition, StopSignal

using ProgressMeter

Expand All @@ -9,18 +9,18 @@ const update! = ReinforcementLearningBase.update!
#####

"""
ComposedStopCondition(stop_conditions; reducer = any)
ComposedStopCondition(stop_conditions...; reducer = any)

The result of `stop_conditions` is reduced by `reducer`.
"""
struct ComposedStopCondition{T<:Function}
stop_conditions::Vector{Any}
struct ComposedStopCondition{S,T}
stop_conditions::S
reducer::T
function ComposedStopCondition(stop_conditions...; reducer = any)
new{typeof(stop_conditions), typeof(reducer)}(stop_conditions, reducer)
end
end

ComposedStopCondition(stop_conditions; reducer = any) =
ComposedStopCondition(stop_conditions, reducer)

# Evaluate every stored stop condition on `args...` (lazily, through a generator)
# and combine the individual results with `s.reducer`.
function (s::ComposedStopCondition)(args...)
    verdicts = (condition(args...) for condition in s.stop_conditions)
    return s.reducer(verdicts)
end
Expand Down Expand Up @@ -114,3 +114,23 @@ Return `true` if the environment is terminated.
struct StopWhenDone end

# Stop as soon as `get_terminal(env)` reports `true`; `agent` is not consulted.
function (s::StopWhenDone)(agent, env)
    return get_terminal(env)
end

#####
# StopSignal
#####

"""
    StopSignal()
    StopSignal(; is_stop = Ref(false))

Create a stop signal initialized with a value of `false`.
You can manually set it to `true` by `s[] = true` to stop
the running loop at any time.

The struct itself is immutable; the flag lives in a `Ref`, which is what makes it
possible to flip the value in place. Read the current value with `s[]`. When used
as a stop condition, `s(agent, env)` returns the current value of the flag.
"""
Base.@kwdef struct StopSignal
    # `Ref{Bool}` is an abstract type; `Base.RefValue{Bool}` is the concrete type
    # returned by `Ref(false)`, so the field stays concretely typed.
    is_stop::Base.RefValue{Bool} = Ref(false)
end

# `s[]` reads the flag.
Base.getindex(s::StopSignal) = s.is_stop[]

# `s[] = v` overwrites the flag.
Base.setindex!(s::StopSignal, v::Bool) = s.is_stop[] = v

# Stop-condition interface: both arguments are ignored, only the flag matters.
(s::StopSignal)(agent, env) = s[]
2 changes: 1 addition & 1 deletion src/utils/circular_array_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ end

`update!` the last frame of `cb` with data.
"""
function RLBase.update!(cb::CircularArrayBuffer{T,N}, data::AbstractArray) where {T,N}
function RLBase.update!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
select_last_dim(cb.buffer, _buffer_frame(cb, cb.nframes)) .= data
cb
end
Expand Down