From 1d9b010c2330dc58dd87947cc3f390b66c71a5cd Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 8 Nov 2020 23:42:10 +0800 Subject: [PATCH 1/4] add StaticPolicy --- src/components/policies/policies.jl | 1 + src/components/policies/static_policy.jl | 22 ++++++++ .../trajectories/distributed_trajectory.jl | 55 +++++++++++++++++++ src/components/trajectories/trajectories.jl | 1 + 4 files changed, 79 insertions(+) create mode 100644 src/components/policies/static_policy.jl create mode 100644 src/components/trajectories/distributed_trajectory.jl diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl index ab599f9..6c04293 100644 --- a/src/components/policies/policies.jl +++ b/src/components/policies/policies.jl @@ -1,3 +1,4 @@ include("V_based_policy.jl") include("Q_based_policy.jl") include("off_policy.jl") +include("static_policy.jl") diff --git a/src/components/policies/static_policy.jl b/src/components/policies/static_policy.jl new file mode 100644 index 0000000..3ae4f85 --- /dev/null +++ b/src/components/policies/static_policy.jl @@ -0,0 +1,22 @@ +export StaticPolicy + +using MacroTools: @forward + +""" + StaticPolicy(policy) + +Create a policy wrapper so that it will do nothing when calling +`update!(policy::StaticPolicy, args...)`. Usually used in the +distributed mode as a worker. +""" +struct StaticPolicy{P<:AbstractPolicy} <: AbstractPolicy + p::P +end + +(π::StaticPolicy)(env) = π.p(env) + +@forward OffPolicy.p RLBase.get_priority, RLBase.get_prob + +RLBase.update!(p::StaticPolicy, args...) = nothing + +RLBase.update!(p::StaticPolicy, ps::Params) = update!(p.p, ps) \ No newline at end of file diff --git a/src/components/trajectories/distributed_trajectory.jl b/src/components/trajectories/distributed_trajectory.jl new file mode 100644 index 0000000..2108677 --- /dev/null +++ b/src/components/trajectories/distributed_trajectory.jl @@ -0,0 +1,55 @@ +export TrajectoryClient + +using MacroTools: @forward + +##### +# Client part +##### + +struct TrajectoryClient{T<:AbstractTrajectory, S} <: AbstractTrajectory + trajectory::T + bulk_size::Int + mailbox::S +end + +@forward TrajectoryClient.trajectory Base.keys, +Base.haskey, +Base.getindex, +Base.pop!, +Base.empty!, +isfull + +function Base.push!(t::TrajectoryClient, args...;kwargs...) + push!(t.trajectory, args...;kwargs...) + _sync(t) +end + +# Given that CircularCompactSARTSATrajectory is the most common one +# We'll focus on the implementations around it for now + +function _sync(t::TrajectoryClient{CircularCompactSARTSATrajectory}) + if nframes(t.trajectory[:full_state]) >= t.bulk_size + # TODO: here we simply create an copy to avoid sharing the same data accross different tasks + # But for remote channels, this is redundant because it will always copy the data. + d = deepcopy(t.trajectory) + put!(t.mailbox, d) + end +end + +##### +# Server part +##### + +function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory) + for i in 1:length(𝕥[:terminal]) + push!( + t; + state=select_last_dim(𝕥[:state], i), + action=select_last_dim(𝕥[:action], i), + reward=select_last_dim(𝕥[:reward], i), + terminal=select_last_dim(𝕥[:terminal], i), + next_state=select_last_dim(𝕥[:next_state], i), + next_action=select_last_dim(𝕥[:next_action], i), + ) + end +end \ No newline at end of file diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl index 614c74f..cc49546 100644 --- a/src/components/trajectories/trajectories.jl +++ b/src/components/trajectories/trajectories.jl @@ -1,3 +1,4 @@ include("abstract_trajectory.jl") include("trajectory.jl") +include("distributed_trajectory.jl") include("reservoir_trajectory.jl") From 2d07418b141b8123c5172e17b8f6ae70c0d40738 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 9 Nov 2020 18:27:00 +0800 Subject: [PATCH 2/4] sync --- src/components/policies/static_policy.jl | 2 +- .../trajectories/distributed_trajectory.jl | 47 +++++++++++-------- src/components/trajectories/trajectory.jl | 30 +++++++++++- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/components/policies/static_policy.jl b/src/components/policies/static_policy.jl index 3ae4f85..ae8c894 100644 --- a/src/components/policies/static_policy.jl +++ b/src/components/policies/static_policy.jl @@ -15,7 +15,7 @@ end (π::StaticPolicy)(env) = π.p(env) -@forward OffPolicy.p RLBase.get_priority, RLBase.get_prob +@forward StaticPolicy.p RLBase.get_priority, RLBase.get_prob RLBase.update!(p::StaticPolicy, args...) = nothing diff --git a/src/components/trajectories/distributed_trajectory.jl b/src/components/trajectories/distributed_trajectory.jl index 2108677..3a2ed14 100644 --- a/src/components/trajectories/distributed_trajectory.jl +++ b/src/components/trajectories/distributed_trajectory.jl @@ -6,10 +6,12 @@ using MacroTools: @forward # Client part ##### -struct TrajectoryClient{T<:AbstractTrajectory, S} <: AbstractTrajectory +mutable struct TrajectoryClient{T<:AbstractTrajectory,A,S} <: AbstractTrajectory trajectory::T - bulk_size::Int + adder::A mailbox::S + sync_freq::Int + n::Int end @forward TrajectoryClient.trajectory Base.keys, @@ -21,35 +23,40 @@ isfull function Base.push!(t::TrajectoryClient, args...;kwargs...) push!(t.trajectory, args...;kwargs...) - _sync(t) + t.n += 1 + + if t.n % t.sync_freq == 0 + put!(t.mailbox, deepcopy(t.trajectoryj)) + end end -# Given that CircularCompactSARTSATrajectory is the most common one -# We'll focus on the implementations around it for now -function _sync(t::TrajectoryClient{CircularCompactSARTSATrajectory}) - if nframes(t.trajectory[:full_state]) >= t.bulk_size - # TODO: here we simply create an copy to avoid sharing the same data accross different tasks - # But for remote channels, this is redundant because it will always copy the data. - d = deepcopy(t.trajectory) - put!(t.mailbox, d) - end +##### +# TrajectorySampler +##### + +abstract type AbstractAdder end + +Base.@kwdef struct NStepAdder <: AbstractAdder + n::Int = 1 end ##### # Server part ##### -function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory) - for i in 1:length(𝕥[:terminal]) +function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, adder::NStepAdder) + N = length(𝕥[:terminal]) + n = adder.n + for i in 1:(N-n+1) push!( t; - state=select_last_dim(𝕥[:state], i), - action=select_last_dim(𝕥[:action], i), - reward=select_last_dim(𝕥[:reward], i), - terminal=select_last_dim(𝕥[:terminal], i), - next_state=select_last_dim(𝕥[:next_state], i), - next_action=select_last_dim(𝕥[:next_action], i), + state=select_last_dim(𝕥[:state], i:i+n-1), + action=select_last_dim(𝕥[:action], i:i+n-1), + reward=select_last_dim(𝕥[:reward], i:i+n-1), + terminal=select_last_dim(𝕥[:terminal], i:i+n-1), + next_state=select_last_dim(𝕥[:next_state], i:i+n-1), + next_action=select_last_dim(𝕥[:next_action], i:i+n-1), ) end end \ No newline at end of file diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 4531bcc..05f94e2 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -11,7 +11,8 @@ export Trajectory, ElasticCompactSARTSATrajectory, CircularCompactPSARTSATrajectory, CircularCompactSALRTSALTrajectory, - CircularCompactPSALRTSALTrajectory + CircularCompactPSALRTSALTrajectory, + VectSARTSATrajectory using MacroTools: @forward using ElasticArrays @@ -300,6 +301,32 @@ function VectCompactSARTSATrajectory(; reward_type = Float32, terminal_type = Bo ) end +##### +# VectSARTSATrajectory +##### + +const VectSARTSATrajectory = Trajectory{ + <:NamedTuple{ + (:state, :action, :reward, :terminal, :next_state, :next_action), + <:Tuple{<:Vector, <:Vector, <:Vector, <:Vector, <:Vector, <:Vector}}} + +function VectSARTSATrajectory( + ;state_type = Int, + action_type=Int, + reward_type=Float32, + terminal_type=Bool, + next_state_type=state_type, + next_action_type=action_type) + Trajectory( + ;state=Vector{state_type}(), + action=Vector{action_type}(), + reward=Vector{reward_type}(), + terminal=Vector{terminal_type}(), + next_state=Vector{next_state_type}(), + next_action=Vector{next_action_type}(), + ) +end + ##### # CircularCompactSARTSATrajectory ##### @@ -356,7 +383,6 @@ function ElasticCompactSARTSATrajectory(; ) end - ##### # CircularCompactSALRTSALTrajectory ##### From f1683523ad59f594927f438524678ae529d53452 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 12 Nov 2020 21:58:25 +0800 Subject: [PATCH 3/4] sync local changes --- .../trajectories/distributed_trajectory.jl | 62 ------------------- src/components/trajectories/trajectories.jl | 2 +- src/components/trajectories/trajectory.jl | 2 + .../trajectories/trajectory_extension.jl | 52 ++++++++++++++++ src/core/stop_conditions.jl | 22 ++++++- 5 files changed, 76 insertions(+), 64 deletions(-) delete mode 100644 src/components/trajectories/distributed_trajectory.jl create mode 100644 src/components/trajectories/trajectory_extension.jl diff --git a/src/components/trajectories/distributed_trajectory.jl b/src/components/trajectories/distributed_trajectory.jl deleted file mode 100644 index 3a2ed14..0000000 --- a/src/components/trajectories/distributed_trajectory.jl +++ /dev/null @@ -1,62 +0,0 @@ -export TrajectoryClient - -using MacroTools: @forward - -##### -# Client part -##### - -mutable struct TrajectoryClient{T<:AbstractTrajectory,A,S} <: AbstractTrajectory - trajectory::T - adder::A - mailbox::S - sync_freq::Int - n::Int -end - -@forward TrajectoryClient.trajectory Base.keys, -Base.haskey, -Base.getindex, -Base.pop!, -Base.empty!, -isfull - -function Base.push!(t::TrajectoryClient, args...;kwargs...) - push!(t.trajectory, args...;kwargs...) - t.n += 1 - - if t.n % t.sync_freq == 0 - put!(t.mailbox, deepcopy(t.trajectoryj)) - end -end - - -##### -# TrajectorySampler -##### - -abstract type AbstractAdder end - -Base.@kwdef struct NStepAdder <: AbstractAdder - n::Int = 1 -end - -##### -# Server part -##### - -function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, adder::NStepAdder) - N = length(𝕥[:terminal]) - n = adder.n - for i in 1:(N-n+1) - push!( - t; - state=select_last_dim(𝕥[:state], i:i+n-1), - action=select_last_dim(𝕥[:action], i:i+n-1), - reward=select_last_dim(𝕥[:reward], i:i+n-1), - terminal=select_last_dim(𝕥[:terminal], i:i+n-1), - next_state=select_last_dim(𝕥[:next_state], i:i+n-1), - next_action=select_last_dim(𝕥[:next_action], i:i+n-1), - ) - end -end \ No newline at end of file diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl index cc49546..80859af 100644 --- a/src/components/trajectories/trajectories.jl +++ b/src/components/trajectories/trajectories.jl @@ -1,4 +1,4 @@ include("abstract_trajectory.jl") include("trajectory.jl") -include("distributed_trajectory.jl") +include("trajectory_extension.jl") include("reservoir_trajectory.jl") diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 05f94e2..de6ba65 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -327,6 +327,8 @@ function VectSARTSATrajectory( ) end +Base.length(t::VectSARTSATrajectory) = length(t[:state]) + ##### # CircularCompactSARTSATrajectory ##### diff --git a/src/components/trajectories/trajectory_extension.jl b/src/components/trajectories/trajectory_extension.jl new file mode 100644 index 0000000..3d3462b --- /dev/null +++ b/src/components/trajectories/trajectory_extension.jl @@ -0,0 +1,52 @@ +export NStepInserter, UniformBatchSampler + +using Random + +##### +# Inserters +##### + +abstract type AbstractInserter end + +Base.@kwdef struct NStepInserter <: AbstractInserter + n::Int = 1 +end + +function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, adder::NStepInserter) + N = length(𝕥[:terminal]) + n = adder.n + for i in 1:(N-n+1) + push!(t; + state=select_last_dim(𝕥[:state], i), + action=select_last_dim(𝕥[:action], i), + reward=select_last_dim(𝕥[:reward], i), + terminal=select_last_dim(𝕥[:terminal], i), + next_state=select_last_dim(𝕥[:next_state], i+n-1), + next_action=select_last_dim(𝕥[:next_action], i+n-1), + ) + end +end + +##### +# Samplers +##### + +abstract type AbstractSampler end + +struct UniformBatchSampler <: AbstractSampler + batch_size::Int +end + +StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) = sample(Random.GLOBAL_RNG, t, sampler) + +function StatsBase.sample(rng::AbstractRNG, t::VectSARTSATrajectory, sampler::UniformBatchSampler) + inds = rand(rng, 1:length(t), sampler.batch_size) + ( + state=Flux.batch(t[:state][inds]), + action=Flux.batch(t[:action][inds]), + reward=Flux.batch(t[:reward][inds]), + terminal=Flux.batch(t[:terminal][inds]), + next_state=Flux.batch(t[:next_state][inds]), + next_action=Flux.batch(t[:next_action][inds]), + ) +end diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index d0e30a8..db85b1f 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -1,4 +1,4 @@ -export StopAfterStep, StopAfterEpisode, StopWhenDone, ComposedStopCondition +export StopAfterStep, StopAfterEpisode, StopWhenDone, ComposedStopCondition, StopSignal using ProgressMeter @@ -114,3 +114,23 @@ Return `true` if the environment is terminated. struct StopWhenDone end (s::StopWhenDone)(agent, env) = get_terminal(env) + +##### +# StopSignal +##### + +""" + StopSignal() + +Create a stop signal initialized with a value of `false`. +You can manually set it to `true` by `s[] = true` to stop +the running loop at any time. +""" +Base.@kwdef struct StopSignal + is_stop::Ref{Bool} = Ref(false) +end + +Base.getindex(s::StopSignal) = s.is_stop[] +Base.setindex!(s::StopSignal, v::Bool) = s.is_stop[] = v + +(s::StopSignal)(args...) = s[] \ No newline at end of file From 9a808b9c2aba7df31697863fcfe3884914d5e2cc Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 22 Nov 2020 01:57:17 +0800 Subject: [PATCH 4/4] sync --- src/components/policies/static_policy.jl | 4 ++- src/components/trajectories/trajectory.jl | 35 ++++++++++++++++++- .../trajectories/trajectory_extension.jl | 4 +-- src/core/hooks.jl | 20 ++++++++++- src/core/stop_conditions.jl | 14 ++++---- src/utils/circular_array_buffer.jl | 2 +- 6 files changed, 66 insertions(+), 13 deletions(-) diff --git a/src/components/policies/static_policy.jl b/src/components/policies/static_policy.jl index ae8c894..9390de2 100644 --- a/src/components/policies/static_policy.jl +++ b/src/components/policies/static_policy.jl @@ -19,4 +19,6 @@ end RLBase.update!(p::StaticPolicy, args...) = nothing -RLBase.update!(p::StaticPolicy, ps::Params) = update!(p.p, ps) \ No newline at end of file +RLBase.update!(p::StaticPolicy, ps::Params) = update!(p.p, ps) + +Flux.@functor StaticPolicy \ No newline at end of file diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index de6ba65..1bd6347 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -12,7 +12,8 @@ export Trajectory, CircularCompactPSARTSATrajectory, CircularCompactSALRTSALTrajectory, CircularCompactPSALRTSALTrajectory, - VectSARTSATrajectory + VectSARTSATrajectory, + CircularSARTSATrajectory using MacroTools: @forward using ElasticArrays @@ -329,6 +330,38 @@ end Base.length(t::VectSARTSATrajectory) = length(t[:state]) +##### +# CircularSARTSATrajectory +##### + +const CircularSARTSATrajectory = Trajectory{ + <:NamedTuple{ + (:state, :action, :reward, :terminal, :next_state, :next_action), + <:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer,<:CircularArrayBuffer}}} + +function CircularSARTSATrajectory(; + capacity, + state_type = Float32, + state_size = (), + action_type = Int, + action_size = (), + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), +) + Trajectory( + state = CircularArrayBuffer{state_type}(state_size..., capacity), + action = CircularArrayBuffer{action_type}(action_size..., capacity), + reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + next_state = CircularArrayBuffer{state_type}(state_size..., capacity), + next_action = CircularArrayBuffer{action_type}(action_size..., capacity), + ) +end + +Base.length(t::CircularSARTSATrajectory) = length(t[:state]) + ##### # CircularCompactSARTSATrajectory ##### diff --git a/src/components/trajectories/trajectory_extension.jl b/src/components/trajectories/trajectory_extension.jl index 3d3462b..e17c954 100644 --- a/src/components/trajectories/trajectory_extension.jl +++ b/src/components/trajectories/trajectory_extension.jl @@ -12,7 +12,7 @@ Base.@kwdef struct NStepInserter <: AbstractInserter n::Int = 1 end -function Base.push!(t::VectSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, adder::NStepInserter) +function Base.push!(t::CircularSARTSATrajectory, 𝕥::CircularCompactSARTSATrajectory, adder::NStepInserter) N = length(𝕥[:terminal]) n = adder.n for i in 1:(N-n+1) @@ -39,7 +39,7 @@ end StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) = sample(Random.GLOBAL_RNG, t, sampler) -function StatsBase.sample(rng::AbstractRNG, t::VectSARTSATrajectory, sampler::UniformBatchSampler) +function StatsBase.sample(rng::AbstractRNG, t::Union{VectSARTSATrajectory, CircularSARTSATrajectory}, sampler::UniformBatchSampler) inds = rand(rng, 1:length(t), sampler.batch_size) ( state=Flux.batch(t[:state][inds]), diff --git a/src/core/hooks.jl b/src/core/hooks.jl index eba98dc..9177471 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -9,7 +9,8 @@ export AbstractHook, CumulativeReward, TimePerStep, DoEveryNEpisode, - DoEveryNStep + DoEveryNStep, + UploadTrajectoryEveryNStep """ A hook is called at different stage duiring a [`run`](@ref) to allow users to inject customized runtime logic. @@ -337,3 +338,20 @@ function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env) hook.f(hook.t, agent, env) end end + +""" + UploadTrajectoryEveryNStep(;mailbox, n, sealer=deepcopy) +""" +Base.@kwdef mutable struct UploadTrajectoryEveryNStep{M,S} <: AbstractHook + mailbox::M + n::Int + t::Int = -1 + sealer::S = deepcopy +end + +function (hook::UploadTrajectoryEveryNStep)(::PostActStage, agent, env) + hook.t += 1 + if hook.t > 0 && hook.t % hook.n == 0 + put!(hook.mailbox, hook.sealer(agent.trajectory)) + end +end diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index db85b1f..57d1495 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -9,18 +9,18 @@ const update! = ReinforcementLearningBase.update! ##### """ - ComposedStopCondition(stop_conditions; reducer = any) + ComposedStopCondition(stop_conditions...; reducer = any) The result of `stop_conditions` is reduced by `reducer`. """ -struct ComposedStopCondition{T<:Function} - stop_conditions::Vector{Any} +struct ComposedStopCondition{S,T} + stop_conditions::S reducer::T + function ComposedStopCondition(stop_conditions...; reducer = any) + new{typeof(stop_conditions), typeof(reducer)}(stop_conditions, reducer) + end end -ComposedStopCondition(stop_conditions; reducer = any) = - ComposedStopCondition(stop_conditions, reducer) - function (s::ComposedStopCondition)(args...) s.reducer(sc(args...) for sc in s.stop_conditions) end @@ -133,4 +133,4 @@ end Base.getindex(s::StopSignal) = s.is_stop[] Base.setindex!(s::StopSignal, v::Bool) = s.is_stop[] = v -(s::StopSignal)(args...) = s[] \ No newline at end of file +(s::StopSignal)(agent, env) = s[] \ No newline at end of file diff --git a/src/utils/circular_array_buffer.jl b/src/utils/circular_array_buffer.jl index d0f3c7c..ce523ed 100644 --- a/src/utils/circular_array_buffer.jl +++ b/src/utils/circular_array_buffer.jl @@ -164,7 +164,7 @@ end `update!` the last frame of `cb` with data. """ -function RLBase.update!(cb::CircularArrayBuffer{T,N}, data::AbstractArray) where {T,N} +function RLBase.update!(cb::CircularArrayBuffer{T,N}, data) where {T,N} select_last_dim(cb.buffer, _buffer_frame(cb, cb.length)) .= data cb end