From 6c438aea4e50a3a93a32d6932a71a18e529d3f3c Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 24 Nov 2020 01:15:48 +0800 Subject: [PATCH 01/12] bugfixes with trajectory --- src/components/trajectories/trajectory.jl | 9 ++++--- .../trajectories/trajectory_extension.jl | 26 ++++++++++++------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 1bd6347..6e39d45 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -44,6 +44,11 @@ Base.pop!(t::Trajectory, s::Symbol) = pop!(t[s]) isfull(t::Trajectory) = all(isfull, t.traces) +# !!! this is a strong assumption, always check it when implementing new Trajectories +# !!! we use `nframes` instead of `length` to avoid some corner cases. +# !!! For example, in `MultiThreadEnv`, the `length(t[:terminal])` is `n_ENV * n_transitions` +Base.length(t::AbstractTrajectory) = nframes(t[:terminal]) + ##### # SharedTrajectory ##### @@ -328,8 +333,6 @@ function VectSARTSATrajectory( ) end -Base.length(t::VectSARTSATrajectory) = length(t[:state]) - ##### # CircularSARTSATrajectory ##### @@ -360,8 +363,6 @@ function CircularSARTSATrajectory(; ) end -Base.length(t::CircularSARTSATrajectory) = length(t[:state]) - ##### # CircularCompactSARTSATrajectory ##### diff --git a/src/components/trajectories/trajectory_extension.jl b/src/components/trajectories/trajectory_extension.jl index e17c954..d7bdf43 100644 --- a/src/components/trajectories/trajectory_extension.jl +++ b/src/components/trajectories/trajectory_extension.jl @@ -39,14 +39,22 @@ end StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) = sample(Random.GLOBAL_RNG, t, sampler) -function StatsBase.sample(rng::AbstractRNG, t::Union{VectSARTSATrajectory, CircularSARTSATrajectory}, sampler::UniformBatchSampler) +function StatsBase.sample( + rng::AbstractRNG, + t::VectSARTSATrajectory, + sampler::UniformBatchSampler, + trace_names=(:state, :action, :reward, :terminal, :next_state, :next_action) +) inds = rand(rng, 1:length(t), sampler.batch_size) - ( - state=Flux.batch(t[:state][inds]), - action=Flux.batch(t[:action][inds]), - reward=Flux.batch(t[:reward][inds]), - terminal=Flux.batch(t[:terminal][inds]), - next_state=Flux.batch(t[:next_state][inds]), - next_action=Flux.batch(t[:next_action][inds]), - ) + NamedTuple{trace_names}(Flux.batch(view(t[x], inds)) for x in trace_names) +end + +function StatsBase.sample( + rng::AbstractRNG, + t::Union{CircularCompactSARTSATrajectory, CircularSARTSATrajectory}, + sampler::UniformBatchSampler, + trace_names=(:state, :action, :reward, :terminal, :next_state, :next_action) +) + inds = rand(rng, 1:length(t), sampler.batch_size) + NamedTuple{trace_names}(convert(Array, consecutive_view(t[x], inds)) for x in trace_names) end From 19ffc2d1ef6f3af52308970ce2554b1c60a4bc30 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 28 Nov 2020 16:03:25 +0800 Subject: [PATCH 02/12] simplify trajectories --- .gitignore | 2 +- Project.toml | 9 +- src/components/agents/abstract_agent.jl | 78 --- src/components/agents/agent.jl | 150 ++--- src/components/agents/agents.jl | 3 +- src/components/agents/base.jl | 65 +++ src/components/agents/dyna_agent.jl | 71 --- src/components/processors.jl | 6 +- .../trajectories/abstract_trajectory.jl | 68 +-- .../trajectories/reservoir_trajectory.jl | 6 +- src/components/trajectories/trajectory.jl | 527 ++---------------- .../trajectories/trajectory_extension.jl | 63 ++- src/core/run.jl | 6 +- 
src/extensions/ReinforcementLearningBase.jl | 18 - src/utils/circular_array_buffer.jl | 192 ------- src/utils/device.jl | 32 +- src/utils/utils.jl | 1 - test/Project.toml | 10 + test/components/agents.jl | 17 +- test/components/trajectories.jl | 210 +------ test/core/core.jl | 7 +- test/utils/circular_array_buffer.jl | 182 ------ test/utils/utils.jl | 1 - 23 files changed, 311 insertions(+), 1413 deletions(-) delete mode 100644 src/components/agents/abstract_agent.jl create mode 100644 src/components/agents/base.jl delete mode 100644 src/components/agents/dyna_agent.jl delete mode 100644 src/utils/circular_array_buffer.jl create mode 100644 test/Project.toml delete mode 100644 test/utils/circular_array_buffer.jl diff --git a/.gitignore b/.gitignore index ab43d59..cc11c32 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ .DS_Store -/Manifest.toml +Manifest.toml /dev/ /checkpoints/ /logs/ \ No newline at end of file diff --git a/Project.toml b/Project.toml index 7b52b53..0d2bcb1 100644 --- a/Project.toml +++ b/Project.toml @@ -8,10 +8,12 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" @@ -44,10 +46,3 @@ Setfield = "0.6, 0.7" StatsBase = "0.32, 0.33" Zygote = "0.5" julia = "1.4" - -[extras] -ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "ReinforcementLearningEnvironments"] diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl deleted file mode 100644 index 2196f57..0000000 --- a/src/components/agents/abstract_agent.jl +++ /dev/null @@ -1,78 +0,0 @@ -export AbstractAgent, - get_role, - PreExperimentStage, - PostExperimentStage, - PreEpisodeStage, - PostEpisodeStage, - PreActStage, - PostActStage, - PRE_EXPERIMENT_STAGE, - POST_EXPERIMENT_STAGE, - PRE_EPISODE_STAGE, - POST_EPISODE_STAGE, - PRE_ACT_STAGE, - POST_ACT_STAGE, - Training, - Testing - -""" - (agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) -> action - (agent::AbstractAgent)(stage::AbstractStage, env) - -Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action. -The main difference is that, we divide an experiment into the following stages: - -- `PRE_EXPERIMENT_STAGE` -- `PRE_EPISODE_STAGE` -- `PRE_ACT_STAGE` -- `POST_ACT_STAGE` -- `POST_EPISODE_STAGE` -- `POST_EXPERIMENT_STAGE` - -In each stage, different types of agents may have different behaviors, like updating experience buffer, environment model or policy. -""" -abstract type AbstractAgent end - -function get_role(::AbstractAgent) end - -""" - +-----------------------------------------------------------+ - |Episode | - | | -PRE_EXPERIMENT_STAGE | PRE_ACT_STAGE POST_ACT_STAGE | POST_EXPERIMENT_STAGE - | | | | | | - v | +-----+ v +-------+ v +-----+ | v - --------------------->+ env +------>+ agent +------->+ env +---> ... ------->...... 
- | ^ +-----+ +-------+ action +-----+ ^ | - | | | | - | +--PRE_EPISODE_STAGE POST_EPISODE_STAGE----+ | - | | - | | - +-----------------------------------------------------------+ -""" -abstract type AbstractStage end - -struct PreExperimentStage <: AbstractStage end -struct PostExperimentStage <: AbstractStage end -struct PreEpisodeStage <: AbstractStage end -struct PostEpisodeStage <: AbstractStage end -struct PreActStage <: AbstractStage end -struct PostActStage <: AbstractStage end - -const PRE_EXPERIMENT_STAGE = PreExperimentStage() -const POST_EXPERIMENT_STAGE = PostExperimentStage() -const PRE_EPISODE_STAGE = PreEpisodeStage() -const POST_EPISODE_STAGE = PostEpisodeStage() -const PRE_ACT_STAGE = PreActStage() -const POST_ACT_STAGE = PostActStage() - -(agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) -function (agent::AbstractAgent)(stage::AbstractStage, env) end - -struct Training{T<:AbstractStage} end -Training(s::T) where {T<:AbstractStage} = Training{T}() -struct Testing{T<:AbstractStage} end -Testing(s::T) where {T<:AbstractStage} = Testing{T}() - -Base.show(io::IO, agent::AbstractAgent) = - AbstractTrees.print_tree(io, StructTree(agent), get(io, :max_depth, 10)) diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index feb9b57..a27589f 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -1,16 +1,14 @@ -export Agent +export Agent, + role -using Flux -using BSON -using JLD -using Setfield +import Functors:functor +using Setfield: @set """ Agent(;kwargs...) -One of the most commonly used [`AbstractAgent`](@ref). - -Generally speaking, it does nothing but update the trajectory and policy appropriately in different stages. +A wrapper of an `AbstractPolicy`. Generally speaking, it does nothing but to +update the trajectory and policy appropriately in different stages and modes. # Keywords & Fields @@ -18,120 +16,78 @@ Generally speaking, it does nothing but update the trajectory and policy appropr - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment - `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents """ -Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent +Base.@kwdef struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R,M} <: AbstractPolicy policy::P trajectory::T = DUMMY_TRAJECTORY role::R = RLBase.DEFAULT_PLAYER - is_training::Bool = true + mode::M = TRAIN_MODE end -# avoid polluting trajectory +functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy + +role(agent::Agent) = agent.role +mode(agent::Agent) = agent.mode (agent::Agent)(env) = agent.policy(env) -Flux.functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy - -function save(dir::String, agent::Agent; is_save_trajectory = true) - mkpath(dir) - @info "saving agent to $dir ..." 
- - t = @elapsed begin - save(joinpath(dir, "policy.bson"), agent.policy) - if is_save_trajectory - JLD.save(joinpath(dir, "trajectory.jld"), "trajectory", agent.trajectory) - else - @warn "trajectory is skipped since you set `is_save_trajectory` to false" - end - BSON.bson( - joinpath(dir, "agent_meta.bson"), - Dict( - :role => agent.role, - :is_training => agent.is_training, - :policy_type => typeof(agent.policy), - ), - ) - end - @info "finished saving agent in $t seconds" -end +(agent::Agent)(stage::AbstractStage, env::AbstractEnv) = agent(env, stage, mode(agent)) -function load(dir::String, ::Type{<:Agent}) - @info "loading agent from $dir" - BSON.@load joinpath(dir, "agent_meta.bson") role is_training policy_type - policy = load(joinpath(dir, "policy.bson"), policy_type) - JLD.@load joinpath(dir, "trajectory.jld") trajectory - Agent(policy, trajectory, role, is_training) +function (agent::Agent)(env::AbstractEnv, stage::AbstractStage, mode::AbstractMode) + update!(agent.trajectory, agent.policy, env, stage, mode) + update!(agent.policy, agent.trajectory, env, stage, mode) end -get_role(agent::Agent) = agent.role +## TrainMode -function Flux.testmode!(agent::Agent, mode = true) - agent.is_training = !mode - testmode!(agent.policy, mode) +function (agent::Agent)(env::AbstractEnv, stage::PreActStage, mode::TrainMode) + action = update!(agent.trajectory, agent.policy, env, stage, mode) + update!(agent.policy, agent.trajectory, env, stage, mode) + action end -(agent::Agent)(stage::AbstractStage, env) = - agent.is_training ? agent(Training(stage), env) : agent(Testing(stage), env) - -(agent::Agent)(::Testing, env) = nothing -(agent::Agent)(::Testing{PreActStage}, env) = agent.policy(env) +## EvalMode -##### -# DummyTrajectory -##### +function (agent::Agent)(env::AbstractEnv, stage::PreActStage, mode::EvalMode) + update!(agent.trajectory, agent.policy, env, stage, mode) +end -(agent::Agent{<:AbstractPolicy,<:DummyTrajectory})(stage::AbstractStage, env) = nothing -(agent::Agent{<:AbstractPolicy,<:DummyTrajectory})(stage::PreActStage, env) = - agent.policy(env) +## TestMode -##### -# default behavior -##### +(agent::Agent)(::AbstractEnv, ::AbstractStage, ::TestMode) = nothing +(agent::Agent)(env::AbstractEnv, ::PreActStage, ::TestMode) = agent.policy(env) -function (agent::Agent)(::Training{PreEpisodeStage}, env) - if nframes(agent.trajectory[:full_state]) > 0 - pop!(agent.trajectory, :full_state) - end - if nframes(agent.trajectory[:full_action]) > 0 - pop!(agent.trajectory, :full_action) - end - if ActionStyle(env) === FULL_ACTION_SET && - nframes(agent.trajectory[:full_legal_actions_mask]) > 0 - pop!(agent.trajectory, :full_legal_actions_mask) - end -end +## update trajectory -function (agent::Agent)(::Training{PreActStage}, env) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - if ActionStyle(env) === FULL_ACTION_SET - push!(agent.trajectory; legal_actions_mask = get_legal_actions_mask(env)) +function RLBase.update!(trajectory::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::PreEpisodeStage, ::AbstractMode) + if length(trajectory) > 0 + pop!(trajectory[:state]) + pop!(trajectory[:action]) + haskey(trajectory, :legal_actions_mask) && pop!(trajectory[:legal_actions_mask]) end - update!(agent.policy, agent.trajectory) - action end -function (agent::Agent)(::Training{PostActStage}, env) - push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) - nothing +function RLBase.update!(trajectory::AbstractTrajectory, 
policy::AbstractPolicy, env::AbstractEnv, ::PostEpisodeStage, ::AbstractMode) + action = policy(env) + push!(trajectory[:state], get_state(env)) + push!(trajectory[:action], action) + haskey(trajectory, :legal_actions_mask) && push!(trajectory[:legal_actions_mask], get_legal_actions_mask(env)) end -function (agent::Agent)(::Training{PostEpisodeStage}, env) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - if ActionStyle(env) === FULL_ACTION_SET - push!(agent.trajectory; legal_actions_mask = get_legal_actions_mask(env)) - end - update!(agent.policy, agent.trajectory) +function RLBase.update!(trajectory::CircularArraySARTTrajectory, policy::AbstractPolicy, env::AbstractEnv, ::PreActStage, ::AbstractMode) + action = policy(env) + push!(trajectory[:state], get_state(env)) + push!(trajectory[:action], action) + haskey(trajectory, :legal_actions_mask) && push!(trajectory[:legal_actions_mask], get_legal_actions_mask(env)) action end -##### -# EpisodicCompactSARTSATrajectory -##### -function (agent::Agent{<:AbstractPolicy,<:EpisodicTrajectory})( - ::Training{PreEpisodeStage}, - env, -) - empty!(agent.trajectory) - nothing +function RLBase.update!(trajectory::AbstractTrajectory, ::AbstractPolicy, env::AbstractEnv, ::PostActStage, ::AbstractMode) + push!(trajectory[:reward], get_reward(env)) + push!(trajectory[:terminal], get_terminal(env)) end + +## update policy + +RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage, ::AbstractMode) = nothing + +RLBase.update!(policy::AbstractPolicy, trajectory::AbstractTrajectory, ::AbstractEnv, ::PreActStage, ::AbstractMode) = update!(policy, trajectory) diff --git a/src/components/agents/agents.jl b/src/components/agents/agents.jl index 3d063e0..3edb072 100644 --- a/src/components/agents/agents.jl +++ b/src/components/agents/agents.jl @@ -1,3 +1,2 @@ -include("abstract_agent.jl") +include("base.jl") include("agent.jl") -include("dyna_agent.jl") diff --git a/src/components/agents/base.jl b/src/components/agents/base.jl new file mode 100644 index 0000000..0a52e04 --- /dev/null +++ b/src/components/agents/base.jl @@ -0,0 +1,65 @@ +export AbstractStage, + PreExperimentStage, + PostExperimentStage, + PreEpisodeStage, + PostEpisodeStage, + PreActStage, + PostActStage, + PRE_EXPERIMENT_STAGE, + POST_EXPERIMENT_STAGE, + PRE_EPISODE_STAGE, + POST_EPISODE_STAGE, + PRE_ACT_STAGE, + POST_ACT_STAGE, + set_mode!, + mode, + AbstractMode, + TrainMode, + TRAIN_MODE, + TestMode, + TEST_MODE, + EvalMode, + EVAL_MODE + +##### +# Stage +##### + +abstract type AbstractStage end + +struct PreExperimentStage <: AbstractStage end +const PRE_EXPERIMENT_STAGE = PreExperimentStage() + +struct PostExperimentStage <: AbstractStage end +const POST_EXPERIMENT_STAGE = PostExperimentStage() + +struct PreEpisodeStage <: AbstractStage end +const PRE_EPISODE_STAGE = PreEpisodeStage() + +struct PostEpisodeStage <: AbstractStage end +const POST_EPISODE_STAGE = PostEpisodeStage() + +struct PreActStage <: AbstractStage end +const PRE_ACT_STAGE = PreActStage() + +struct PostActStage <: AbstractStage end +const POST_ACT_STAGE = PostActStage() + +##### +# Modes +##### + +abstract type AbstractMode end + +struct TrainMode <: AbstractMode end +const TRAIN_MODE = TrainMode() + +struct EvalMode <: AbstractMode end +const EVAL_MODE = EvalMode() + +struct TestMode <: AbstractMode end +const TEST_MODE = TestMode() + +function mode end + +function set_mode! 
end \ No newline at end of file diff --git a/src/components/agents/dyna_agent.jl b/src/components/agents/dyna_agent.jl deleted file mode 100644 index 57f3f2a..0000000 --- a/src/components/agents/dyna_agent.jl +++ /dev/null @@ -1,71 +0,0 @@ -export DynaAgent - -""" - DynaAgent(;kwargs...) - -`DynaAgent` is first introduced in: *Sutton, Richard S. "Dyna, an integrated architecture for learning, planning, and reacting." ACM Sigart Bulletin 2.4 (1991): 160-163.* - -# Keywords & Fields - -- `policy`::[`AbstractPolicy`](@ref): the policy to use -- `model`::[`AbstractEnvironmentModel`](@ref): describe the environment to interact with -- `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between agent and environment -- `role=:DEFAULT`: used to distinguish different agents -- `plan_step::Int=10`: the count of planning steps - -The main difference between [`DynaAgent`](@ref) and [`Agent`](@ref) is that an environment model is involved. It is best described in the book: *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.* - -![](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/raw/master/docs/src/assets/img/RL_book_fig_8_1.png) -![](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/raw/master/docs/src/assets/img/RL_book_fig_8_2.png) -""" -Base.@kwdef struct DynaAgent{ - P<:AbstractPolicy, - B<:AbstractTrajectory, - M<:AbstractEnvironmentModel, - R, -} <: AbstractAgent - policy::P - model::M - trajectory::B - role::R = :DEFAULT_PLAYER - plan_step::Int = 10 -end - -get_role(agent::DynaAgent) = agent.role - -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(::PreEpisodeStage, env) - empty!(agent.trajectory) - nothing -end - -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(::PreActStage, env) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.model, agent.trajectory, agent.policy) # model learning - update!(agent.policy, agent.trajectory) # direct learning - update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning - action -end - -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(::PostActStage, env) - push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) - nothing -end - -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(::PostEpisodeStage, env) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.model, agent.trajectory, agent.policy) # model learning - update!(agent.policy, agent.trajectory) # direct learning - update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning - action -end - -"By default, only use trajectory to update model" -RLBase.update!(model::AbstractEnvironmentModel, t::AbstractTrajectory, π::AbstractPolicy) = - update!(model, t) - -# function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory) -# transitions = extract_experience(buffer, model) -# isnothing(transitions) || update!(model, transitions) -# end diff --git a/src/components/processors.jl b/src/components/processors.jl index dc030b5..c414758 100644 --- a/src/components/processors.jl +++ b/src/components/processors.jl @@ -1,6 +1,8 @@ export StackFrames, ResizeImage using ImageTransformations: imresize! 
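# (illustrative sketch, not part of this hunk) `CircularArrayBuffer` now comes from the
# external CircularArrayBuffers.jl package; the `StackFrames` processor below keeps the
# most recent frames in such a buffer. Assuming the patched definitions, rough usage is:
#
#     p = StackFrames(Float32, 84, 84, 4)        # buffer for the last 4 frames of an 84×84 image, zero-filled
#     push!(p.buffer, rand(Float32, 84, 84))     # once full, the newest frame overwrites the oldest
#     CircularArrayBuffers.capacity(p.buffer)    # 4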
+import CircularArrayBuffers +using CircularArrayBuffers:CircularArrayBuffer """ ResizeImage(img::Array{T, N}) @@ -34,7 +36,7 @@ StackFrames(d::Int...) = StackFrames(Float32, d...) function StackFrames(::Type{T}, d::Vararg{Int,N}) where {T,N} p = StackFrames(CircularArrayBuffer{T}(d...)) - for _ in 1:capacity(p.buffer) + for _ in 1:CircularArrayBuffers.capacity(p.buffer) push!(p.buffer, zeros(T, size(p.buffer)[1:N-1])) end p @@ -55,7 +57,7 @@ end function RLBase.reset!(p::StackFrames{T,N}) where {T,N} empty!(p.buffer) - for _ in 1:capacity(p.buffer) + for _ in 1:CircularArrayBuffers.capacity(p.buffer) push!(p.buffer, zeros(T, size(p.buffer)[1:N-1])) end p diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl index d69b51f..8b1b1ab 100644 --- a/src/components/trajectories/abstract_trajectory.jl +++ b/src/components/trajectories/abstract_trajectory.jl @@ -1,57 +1,63 @@ -export AbstractTrajectory +export AbstractTrajectory, + SART, + SARTSA """ AbstractTrajectory -A trace is used to record some useful information +A trajectory is used to record some useful information during the interactions between agents and environments. +It behaves similar to a `NamedTuple` except that we extend it +with some optional methods. Required Methods: -- `Base.haskey(t::AbstractTrajectory, s::Symbol)` -- `Base.getindex(t::AbstractTrajectory, s::Symbol)` -- `Base.keys(t::AbstractTrajectory)` -- `Base.push!(t::AbstractTrajectory, kv::Pair{Symbol})` -- `Base.pop!(t::AbstractTrajectory, s::Symbol)` -- `Base.empty!(t::AbstractTrajectory)` +- `Base.getindex` +- `Base.keys` Optional Methods: -- `isfull` - +- `Base.length` +- `Base.isempty` +- `Base.empty!` +- `Base.haskey` +- `Base.push!` +- `Base.pop!` """ abstract type AbstractTrajectory end -function Base.push!(t::AbstractTrajectory; kwargs...) - for kv in kwargs - push!(t, kv) +Base.haskey(t::AbstractTrajectory, s::Symbol) = s in keys(t) +Base.isempty(t::AbstractTrajectory) = all(k -> isempty(t[k]), keys(t)) + +function Base.empty!(t::AbstractTrajectory) + for k in keys(t) + empty!(t[k]) end end -""" - Base.pop!(t::AbstractTrajectory, s::Symbol...) - -`pop!` out one element of the traces specified in `s` -""" -function Base.pop!(t::AbstractTrajectory, s::Tuple{Vararg{Symbol}}) - NamedTuple{s}(pop!(t, x) for x in s) +function Base.push!(t::AbstractTrajectory;kwargs...) 
+ for (k,v) in kwargs + push!(t[k], v) + end end -Base.pop!(t::AbstractTrajectory) = pop!(t, keys(t)) +function Base.pop!(t::AbstractTrajectory) + for k in keys(t) + pop!(t[k]) + end +end -function Base.empty!(t::AbstractTrajectory) - for s in keys(t) - empty!(t[s]) +function Base.show(io::IO, t::AbstractTrajectory) + println(io, "Trajectory of $(length(keys(t))) traces:") + for k in keys(t) + show(io, k) + println(" $(summary(t[k]))") end end ##### -# patch code +# Common Keys ##### -# avoid showing the inner structure -function AbstractTrees.children(t::StructTree{<:AbstractTrajectory}) - Tuple(k => StructTree(t.x[k]) for k in keys(t.x)) -end - -@deprecate get_trace(t::AbstractTrajectory, s::Symbol) t[s] +const SART = (:state, :action, :reward, :terminal) +const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) \ No newline at end of file diff --git a/src/components/trajectories/reservoir_trajectory.jl b/src/components/trajectories/reservoir_trajectory.jl index 2dc8f3f..53dd933 100644 --- a/src/components/trajectories/reservoir_trajectory.jl +++ b/src/components/trajectories/reservoir_trajectory.jl @@ -15,12 +15,12 @@ end Base.length(x::ReservoirTrajectory) = length(x.buffer[1]) function ReservoirTrajectory( - capacity, - kw::Pair{Symbol,DataType}...; + capacity; n = 0, rng = Random.GLOBAL_RNG, + kw... ) - buffer = Trajectory(; (s => Vector{t}() for (s, t) in kw)...) + buffer = VectorTrajectory(;kw...) ReservoirTrajectory(buffer, n, capacity, rng) end diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 47be052..da782fc 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -1,22 +1,16 @@ export Trajectory, - SharedTrajectory, - EpisodicTrajectory, - CombinedTrajectory, - CircularCompactSATrajectory, - VectCompactSATrajectory, - ElasticCompactSATrajectory, - CircularCompactSALTrajectory, - CircularCompactSARTSATrajectory, - VectCompactSARTSATrajectory, - ElasticCompactSARTSATrajectory, - CircularCompactPSARTSATrajectory, - CircularCompactSALRTSALTrajectory, - CircularCompactPSALRTSALTrajectory, - VectSARTSATrajectory, - CircularSARTSATrajectory + DUMMY_TRAJECTORY, + DummyTrajectory, + CircularArrayTrajectory, + CircularVectorTrajectory, + CircularArraySARTTrajectory, + CircularVectorSARTTrajectory, + CircularVectorSARTSATrajectory, + VectorTrajectory using MacroTools: @forward using ElasticArrays +using CircularArrayBuffers: CircularArrayBuffer, CircularVectorBuffer ##### # Trajectory @@ -25,7 +19,7 @@ using ElasticArrays """ Trajectory(;[trace_name=trace_container]...) -Simply a wrapper of `NamedTuple`. +A simple wrapper of `NamedTuple`. Define our own type here to avoid type piracy with `NamedTuple` """ struct Trajectory{T} <: AbstractTrajectory @@ -34,496 +28,97 @@ end Trajectory(; kwargs...) = Trajectory(kwargs.data) -const DUMMY_TRAJECTORY = Trajectory() -const DummyTrajectory = typeof(DUMMY_TRAJECTORY) - -@forward Trajectory.traces Base.keys, Base.haskey, Base.getindex - -Base.push!(t::Trajectory, kv::Pair{Symbol}) = push!(t[first(kv)], last(kv)) -Base.pop!(t::Trajectory, s::Symbol) = pop!(t[s]) +@forward Trajectory.traces Base.getindex, Base.keys -isfull(t::Trajectory) = all(isfull, t.traces) +Base.merge(a::Trajectory, b::Trajectory) = Trajectory(merge(a.traces, b.traces)) +Base.merge(a::Trajectory, b::NamedTuple) = Trajectory(merge(a.traces, b)) +Base.merge(a::NamedTuple, b::Trajectory) = Trajectory(merge(a, b.traces)) -# !!! 
this is a strong assumption, always check it when implementing new Trajectories -# !!! we use `nframes` instead of `length` to avoid some corner cases. -# !!! For example, in `MultiThreadEnv`, the `length(t[:terminal])` is `n_ENV * n_transitions` -Base.length(t::AbstractTrajectory) = nframes(t[:terminal]) - -##### -# SharedTrajectory ##### -struct SharedTrajectoryMeta - start_shift::Int - end_shift::Int -end - -""" - SharedTrajectory(trace_container, meta::NamedTuple{([trace_name::Symbol],...), Tuple{[SharedTrajectoryMeta]...}}) - -Create multiple traces sharing the same underlying container. -""" -struct SharedTrajectory{X,M} <: AbstractTrajectory - x::X - meta::M -end - -""" - SharedTrajectory(trace_container, s::Symbol) - -Automatically create the following three traces: - -- `s`, share the data in `trace_container` in the range of `1:end-1` -- `s` with a prefix of `next_`, share the data in `trace_container` in the range of `2:end` -- `s` with a prefix of `full_`, a view of `trace_container` -""" -function SharedTrajectory(x, s::Symbol) - SharedTrajectory( - x, - (; - s => SharedTrajectoryMeta(1, -1), - Symbol(:next_, s) => SharedTrajectoryMeta(2, 0), - Symbol(:full_, s) => SharedTrajectoryMeta(1, 0), - ), - ) -end - -@forward SharedTrajectory.meta Base.keys, Base.haskey - -function Base.getindex(t::SharedTrajectory, s::Symbol) - m = t.meta[s] - select_last_dim(t.x, m.start_shift:(nframes(t.x)+m.end_shift)) -end - -Base.push!(t::SharedTrajectory, kv::Pair{Symbol}) = push!(t.x, last(kv)) -Base.empty!(t::SharedTrajectory) = empty!(t.x) -Base.pop!(t::SharedTrajectory, s::Symbol) = pop!(t.x) - -function Base.pop!(t::SharedTrajectory) - s = first(keys(t)) - (; s => pop!(t.x)) -end - -isfull(t::SharedTrajectory) = isfull(t.x) - -##### -# EpisodicTrajectory -##### - -""" - EpisodicTrajectory(traces::T, flag_trace=:terminal) - -Assuming that the `flag_trace` is in `traces` and it's an `AbstractVector{Bool}`, -meaning whether an environment reaches terminal or not. The last element in -`flag_trace` will be used to determine whether the whole trace is full or not. -""" -struct EpisodicTrajectory{T,flag_trace} <: AbstractTrajectory - traces::T -end - -EpisodicTrajectory(traces::T, flag_trace = :terminal) where {T} = - EpisodicTrajectory{T,flag_trace}(traces) - -@forward EpisodicTrajectory.traces Base.keys, -Base.haskey, -Base.getindex, -Base.push!, -Base.pop!, -Base.empty! - -function isfull(t::EpisodicTrajectory{<:Any,F}) where {F} - x = t.traces[F] - (nframes(x) > 0) && select_last_frame(x) -end - -##### -# CombinedTrajectory -##### - -""" - CombinedTrajectory(t1::AbstractTrajectory, t2::AbstractTrajectory) -""" -struct CombinedTrajectory{T1,T2} <: AbstractTrajectory - t1::T1 - t2::T2 -end - -Base.haskey(t::CombinedTrajectory, s::Symbol) = haskey(t.t1, s) || haskey(t.t2, s) -Base.getindex(t::CombinedTrajectory, s::Symbol) = - if haskey(t.t1, s) - getindex(t.t1, s) - elseif haskey(t.t2, s) - getindex(t.t2, s) - else - throw(ArgumentError("unknown key: $s")) - end - -Base.keys(t::CombinedTrajectory) = (keys(t.t1)..., keys(t.t2)...) 
- -Base.push!(t::CombinedTrajectory, kv::Pair{Symbol}) = - if haskey(t.t1, first(kv)) - push!(t.t1, kv) - elseif haskey(t.t2, first(kv)) - push!(t.t2, kv) - else - throw(ArgumentError("unknown kv: $kv")) - end - -Base.pop!(t::CombinedTrajectory, s::Symbol) = - if haskey(t.t1, s) - pop!(t.t1, s) - elseif haskey(t.t2, s) - pop!(t.t2, s) - else - throw(ArgumentError("unknown key: $s")) - end - -Base.pop!(t::CombinedTrajectory) = merge(pop!(t.t1), pop!(t.t2)) - -function Base.empty!(t::CombinedTrajectory) - empty!(t.t1) - empty!(t.t2) -end - -isfull(t::CombinedTrajectory) = isfull(t.t1) && isfull(t.t2) +const DUMMY_TRAJECTORY = Trajectory() +const DummyTrajectory = typeof(DUMMY_TRAJECTORY) ##### -# VectCompactSATrajectory -##### - -const VectCompactSATrajectory = CombinedTrajectory{ - <:SharedTrajectory{<:Vector,<:NamedTuple{(:state, :next_state, :full_state)}}, - <:SharedTrajectory{<:Vector,<:NamedTuple{(:action, :next_action, :full_action)}}, -} -function VectCompactSATrajectory(; state_type = Int, action_type = Int) - CombinedTrajectory( - SharedTrajectory(Vector{state_type}(), :state), - SharedTrajectory(Vector{action_type}(), :action), - ) +function CircularArrayTrajectory(;capacity, kwargs...) + Trajectory(map(kwargs.data) do x + CircularArrayBuffer{eltype(first(x))}(last(x)..., capacity) + end) end -##### -# CircularCompactSATrajectory -##### - -const CircularCompactSATrajectory = CombinedTrajectory{ - <:SharedTrajectory{ - <:CircularArrayBuffer, - <:NamedTuple{(:state, :next_state, :full_state)}, - }, - <:SharedTrajectory{ - <:CircularArrayBuffer, - <:NamedTuple{(:action, :next_action, :full_action)}, - }, -} - -function CircularCompactSATrajectory(; - capacity, - state_type = Int, - state_size = (), - action_type = Int, - action_size = (), -) - CombinedTrajectory( - SharedTrajectory( - CircularArrayBuffer{state_type}(state_size..., capacity + 1), - :state, - ), - SharedTrajectory( - CircularArrayBuffer{action_type}(action_size..., capacity + 1), - :action, - ), - ) +function CircularVectorTrajectory(;capacity, kwargs...) 
+ Trajectory(map(kwargs.data) do x + CircularVectorBuffer{x}(capacity) + end) end -##### -# ElasticCompactSATrajectory ##### -const ElasticCompactSATrajectory = CombinedTrajectory{ - <:SharedTrajectory{<:ElasticArray,<:NamedTuple{(:state, :next_state, :full_state)}}, - <:SharedTrajectory{<:ElasticArray,<:NamedTuple{(:action, :next_action, :full_action)}}, -} +const CircularArraySARTTrajectory = Trajectory{ + <:NamedTuple{ + (:state, :action, :reward, :terminal), + <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:CircularArrayBuffer, <:CircularArrayBuffer}}} -function ElasticCompactSATrajectory(; - state_type = Int, - state_size = (), - action_type = Int, - action_size = (), +CircularArraySARTTrajectory(;capacity::Int, state=Int=>(), action=Int=>(), reward=Float32=>(), terminal=Bool=>()) = merge( + CircularArrayTrajectory(;capacity=capacity+1, state=state, action=action), + CircularArrayTrajectory(;capacity=capacity, reward=reward, terminal=terminal), ) - CombinedTrajectory( - SharedTrajectory(ElasticArray{state_type}(undef, state_size..., 0), :state), - SharedTrajectory(ElasticArray{action_type}(undef, action_size..., 0), :action), - ) -end - -##### -# CircularCompactSALTrajectory -##### - -const CircularCompactSALTrajectory = CombinedTrajectory{ - <:SharedTrajectory{ - <:CircularArrayBuffer, - <:NamedTuple{ - (:legal_actions_mask, :next_legal_actions_mask, :full_legal_actions_mask), - }, - }, - <:CircularCompactSATrajectory, -} +const CircularArraySALRTTrajectory = Trajectory{ + <:NamedTuple{ + (:state, :action, :legal_actions_mask, :reward, :terminal), + <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:CircularArrayBuffer, <:CircularArrayBuffer, <:CircularArrayBuffer}}} -function CircularCompactSALTrajectory(; - capacity, - legal_actions_mask_size, - legal_actions_mask_type = Bool, - kw..., +CircularArraySALRTTrajectory(;capacity::Int, state=Int=>(), action=Int=>(), legal_actions_mask, reward=Float32=>(), terminal=Bool=>()) = merge( + CircularArrayTrajectory(;capacity=capacity+1, state=state, action=action, legal_actions_mask=legal_actions_mask), + CircularArrayTrajectory(;capacity=capacity, reward=reward, terminal=terminal), ) - CombinedTrajectory( - SharedTrajectory( - CircularArrayBuffer{legal_actions_mask_type}( - legal_actions_mask_size..., - capacity + 1, - ), - :legal_actions_mask, - ), - CircularCompactSATrajectory(; capacity = capacity, kw...), - ) -end ##### -# VectCompactSARTSATrajectory -##### - -const VectCompactSARTSATrajectory = CombinedTrajectory{ - <:Trajectory{<:NamedTuple{(:reward, :terminal),<:Tuple{<:Vector,<:Vector}}}, - <:VectCompactSATrajectory, -} -function VectCompactSARTSATrajectory(; reward_type = Float32, terminal_type = Bool, kw...) 
- CombinedTrajectory( - Trajectory(reward = Vector{reward_type}(), terminal = Vector{terminal_type}()), - VectCompactSATrajectory(; kw...), - ) -end - -##### -# VectSARTSATrajectory -##### - -const VectSARTSATrajectory = Trajectory{ +const CircularVectorSARTTrajectory = Trajectory{ <:NamedTuple{ - (:state, :action, :reward, :terminal, :next_state, :next_action), - <:Tuple{<:Vector,<:Vector,<:Vector,<:Vector,<:Vector,<:Vector}, - }, -} + (:state, :action, :reward, :terminal), + <:Tuple{<:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer}}} -function VectSARTSATrajectory(; - state_type = Int, - action_type = Int, - reward_type = Float32, - terminal_type = Bool, - next_state_type = state_type, - next_action_type = action_type, +CircularVectorSARTTrajectory(;capacity::Int, state=Int, action=Int, reward=Float32, terminal=Bool) = merge( + CircularVectorTrajectory(;capacity=capacity+1, state=state, action=action), + CircularVectorTrajectory(;capacity=capacity, reward=reward, terminal=terminal), ) - Trajectory(; - state = Vector{state_type}(), - action = Vector{action_type}(), - reward = Vector{reward_type}(), - terminal = Vector{terminal_type}(), - next_state = Vector{next_state_type}(), - next_action = Vector{next_action_type}(), - ) -end -##### -# CircularSARTSATrajectory ##### -const CircularSARTSATrajectory = Trajectory{ +const CircularVectorSARTSATrajectory = Trajectory{ <:NamedTuple{ (:state, :action, :reward, :terminal, :next_state, :next_action), - <:Tuple{ - <:CircularArrayBuffer, - <:CircularArrayBuffer, - <:CircularArrayBuffer, - <:CircularArrayBuffer, - <:CircularArrayBuffer, - <:CircularArrayBuffer, - }, - }, -} - -function CircularSARTSATrajectory(; - capacity, - state_type = Float32, - state_size = (), - action_type = Int, - action_size = (), - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), -) - Trajectory( - state = CircularArrayBuffer{state_type}(state_size..., capacity), - action = CircularArrayBuffer{action_type}(action_size..., capacity), - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - next_state = CircularArrayBuffer{state_type}(state_size..., capacity), - next_action = CircularArrayBuffer{action_type}(action_size..., capacity), - ) -end - -##### -# CircularCompactSARTSATrajectory -##### - -const CircularCompactSARTSATrajectory = CombinedTrajectory{ - <:Trajectory{ - <:NamedTuple{ - (:reward, :terminal), - <:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer}, - }, - }, - <:CircularCompactSATrajectory, -} - -function CircularCompactSARTSATrajectory(; - capacity, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw..., -) - CombinedTrajectory( - Trajectory( - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - ), - CircularCompactSATrajectory(; capacity = capacity, kw...), - ) -end - -##### -# ElasticCompactSARTSATrajectory -##### + <:Tuple{<:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer, <:CircularVectorBuffer}}} -const ElasticCompactSARTSATrajectory = CombinedTrajectory{ - <:Trajectory{<:NamedTuple{(:reward, :terminal),<:Tuple{<:ElasticArray,<:ElasticArray}}}, - <:ElasticCompactSATrajectory, -} +CircularVectorSARTSATrajectory(;capacity::Int, state=Int, action=Int, reward=Float32, 
terminal=Bool, next_state=state, next_action=action) = CircularVectorTrajectory(;capacity=capacity, state=state, action=action, reward=reward,terminal=terminal,next_state=next_state, next_action=next_action), -function ElasticCompactSARTSATrajectory(; - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw..., -) - CombinedTrajectory( - Trajectory( - reward = ElasticArray{reward_type}(undef, reward_size..., 0), - terminal = ElasticArray{terminal_type}(undef, terminal_size..., 0), - ), - ElasticCompactSATrajectory(; kw...), - ) -end - -##### -# CircularCompactSALRTSALTrajectory ##### -const CircularCompactSALRTSALTrajectory = CombinedTrajectory{ - <:Trajectory{ - <:NamedTuple{ - (:reward, :terminal), - <:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer}, - }, - }, - <:CircularCompactSALTrajectory, -} - -function CircularCompactSALRTSALTrajectory(; - capacity, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw..., -) - CombinedTrajectory( - Trajectory( - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - ), - CircularCompactSALTrajectory(; capacity = capacity, kw...), - ) +function ElasticArrayTrajectory(;kwargs...) + Trajectory(map(kwargs.data) do x + ElasticArray{eltype(first(x))}(undef, last(x)..., 0) + end) end ##### -# CircularCompactPSARTSATrajectory +# VectorTrajectory ##### -const CircularCompactPSARTSATrajectory = CombinedTrajectory{ - <:Trajectory{ - <:NamedTuple{ - (:reward, :terminal, :priority), - <:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer,<:SumTree}, - }, - }, - <:CircularCompactSATrajectory, -} - -function CircularCompactPSARTSATrajectory(; - capacity, - priority_type = Float32, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw..., -) - CombinedTrajectory( - Trajectory( - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - priority = SumTree(priority_type, capacity), - ), - CircularCompactSATrajectory(; capacity = capacity, kw...), - ) +function VectorTrajectory(;kwargs...) 
+ Trajectory(map(kwargs.data) do x + Vector{x}() + end) end ##### -# CircularCompactPSALRTSALTrajectory +# Common ##### -const CircularCompactPSALRTSALTrajectory = CombinedTrajectory{ - <:Trajectory{ - <:NamedTuple{ - (:reward, :terminal, :priority), - <:Tuple{<:CircularArrayBuffer,<:CircularArrayBuffer,<:SumTree}, - }, - }, - <:CircularCompactSALTrajectory, -} - -function CircularCompactPSALRTSALTrajectory(; - capacity, - priority_type = Float32, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw..., -) - CombinedTrajectory( - Trajectory( - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - priority = SumTree(priority_type, capacity), - ), - CircularCompactSALTrajectory(; capacity = capacity, kw...), - ) -end +function Base.length(t::Union{<:CircularArraySARTTrajectory,<:CircularVectorSARTSATrajectory}) + x = t[:terminal] + size(x, ndims(x)) +end \ No newline at end of file diff --git a/src/components/trajectories/trajectory_extension.jl b/src/components/trajectories/trajectory_extension.jl index c37fbda..8e67fd8 100644 --- a/src/components/trajectories/trajectory_extension.jl +++ b/src/components/trajectories/trajectory_extension.jl @@ -13,22 +13,18 @@ Base.@kwdef struct NStepInserter <: AbstractInserter end function Base.push!( - t::CircularSARTSATrajectory, - 𝕥::CircularCompactSARTSATrajectory, - adder::NStepInserter, + t::CircularVectorSARTSATrajectory, + 𝕥::CircularArraySARTTrajectory, + inserter::NStepInserter, ) - N = length(𝕥[:terminal]) - n = adder.n + N = length(𝕥) + n = inserter.n for i in 1:(N-n+1) - push!( - t; - state = select_last_dim(𝕥[:state], i), - action = select_last_dim(𝕥[:action], i), - reward = select_last_dim(𝕥[:reward], i), - terminal = select_last_dim(𝕥[:terminal], i), - next_state = select_last_dim(𝕥[:next_state], i + n - 1), - next_action = select_last_dim(𝕥[:next_action], i + n - 1), - ) + for k in SART + push!(t[k], select_last_dim(𝕥[k], i)) + end + push!(t[:next_state], select_last_dim(𝕥[:state], i+n)) + push!(t[:next_action], select_last_dim(𝕥[:action], i+n)) end end @@ -42,25 +38,28 @@ struct UniformBatchSampler <: AbstractSampler batch_size::Int end -StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) = - sample(Random.GLOBAL_RNG, t, sampler) +""" + sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=keys(trajectory)]) -function StatsBase.sample( - rng::AbstractRNG, - t::VectSARTSATrajectory, - sampler::UniformBatchSampler, - trace_names=(:state, :action, :reward, :terminal, :next_state, :next_action) -) - inds = rand(rng, 1:length(t), sampler.batch_size) - NamedTuple{trace_names}(Flux.batch(view(t[x], inds)) for x in trace_names) +!!! note + Here we return a copy instead of a view: + 1. Each sample is independent of the original `trajectory` so that `trajectory` can be updated async. + 2. [Copy is not always so bad](https://docs.julialang.org/en/v1/manual/performance-tips/#Copying-data-is-not-always-bad). 
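(An illustrative call pattern, not part of the diff; it assumes a filled `CircularVectorSARTSATrajectory` named `t`:)

    batch = sample(t, UniformBatchSampler(32))                                        # GLOBAL_RNG, all traces
    batch = sample(Random.MersenneTwister(123), t, UniformBatchSampler(32), SARTSA)   # explicit rng and trace subset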
+""" +function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler, traces=keys(t)) + sample(Random.GLOBAL_RNG, t, sampler, traces) end -function StatsBase.sample( - rng::AbstractRNG, - t::Union{CircularCompactSARTSATrajectory, CircularSARTSATrajectory}, - sampler::UniformBatchSampler, - trace_names=(:state, :action, :reward, :terminal, :next_state, :next_action) -) - inds = rand(rng, 1:length(t), sampler.batch_size) - NamedTuple{trace_names}(convert(Array, consecutive_view(t[x], inds)) for x in trace_names) +function StatsBase.sample(rng::AbstractRNG, t::CircularVectorSARTSATrajectory, s::UniformBatchSampler, traces) + inds = rand(rng, 1:length(t), s.batch_size) + map(traces) do x + Flux.batch(view(t[x], inds)) + end end + +function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler, traces) + inds = rand(rng, 1:length(t), s.batch_size) + map(traces) do x + convert(Array, consecutive_view(t[x], inds)) + end +end \ No newline at end of file diff --git a/src/core/run.jl b/src/core/run.jl index 39ef1f3..bc72154 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -9,7 +9,7 @@ _run(agent, env, args...) = _run(DynamicStyle(env), NumAgentStyle(env), agent, e function _run( ::Sequential, ::SingleAgent, - agent::AbstractAgent, + agent::Agent, env::AbstractEnv, stop_condition, hook::AbstractHook = EmptyHook(), @@ -49,7 +49,7 @@ end function _run( ::Sequential, ::SingleAgent, - agent::AbstractAgent, + agent::Agent, env::MultiThreadEnv, stop_condition, hook::AbstractHook = EmptyHook(), @@ -75,7 +75,7 @@ end function _run( ::Sequential, ::MultiAgent, - agents::Tuple{Vararg{<:AbstractAgent}}, + agents::Tuple{Vararg{<:Agent}}, env::AbstractEnv, stop_condition, hooks = [EmptyHook() for _ in agents], diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index 6032434..faac067 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -9,14 +9,6 @@ RLBase.update!(p::RandomPolicy, x) = nothing Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s) -# avoid fallback silently -Flux.testmode!(p::AbstractPolicy, mode = true) = - @error "someone forgets to implement this method!!!" - -function Flux.testmode!(p::RandomStartPolicy, mode = true) - testmode!(p.policy, mode) -end - Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10)) @@ -26,13 +18,3 @@ AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = print( io, "$(RLBase.get_name(t.x)): $(join([f(t.x) for f in RLBase.get_env_traits()], ","))", ) - -function save(f::String, p::AbstractPolicy) - policy = cpu(p) - BSON.@save f policy -end - -function load(f::String, ::Type{<:AbstractPolicy}) - BSON.@load f policy - policy -end diff --git a/src/utils/circular_array_buffer.jl b/src/utils/circular_array_buffer.jl deleted file mode 100644 index a094201..0000000 --- a/src/utils/circular_array_buffer.jl +++ /dev/null @@ -1,192 +0,0 @@ -export CircularArrayBuffer, capacity, isfull - -using ReinforcementLearningBase - -""" - CircularArrayBuffer{T}(d::Integer...) -> CircularArrayBuffer{T, N} - -`CircularArrayBuffer` uses a `N`-dimension `Array` of size `d` to serve as a buffer for -`N-1`-dimension `Array`s with the same size. - -# Examples - -```julia-repl -julia> b = CircularArrayBuffer{Float64}(2, 2, 3) -2×2×0 CircularArrayBuffer{Float64,3} - -julia> capacity(b) -3 - -julia> length(b) -0 - -julia> push!(b, [1. 1.; 2. 
2.]) -2×2×1 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 1.0 1.0 - 2.0 2.0 - -julia> b -2×2×1 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 1.0 1.0 - 2.0 2.0 - -julia> length(b) -4 - -julia> nframes(b) -1 - -julia> ones(2,2) -2×2 Array{Float64,2}: - 1.0 1.0 - 1.0 1.0 - -julia> 3 .* ones(2,2) -2×2 Array{Float64,2}: - 3.0 3.0 - 3.0 3.0 - -julia> 3 * ones(2,2) -2×2 Array{Float64,2}: - 3.0 3.0 - 3.0 3.0 - -julia> b = CircularArrayBuffer{Float64}(2, 2, 3) -2×2×0 CircularArrayBuffer{Float64,3} - -julia> capacity(b) -3 - -julia> nframes(b) -0 - -julia> push!(b, 1 .* ones(2,2)) -2×2×1 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 1.0 1.0 - 1.0 1.0 - -julia> b -2×2×1 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 1.0 1.0 - 1.0 1.0 - -julia> nframes(b) -1 - -julia> for i in 2:4 - push!(b, i .* ones(2,2)) - end - -julia> b -2×2×3 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 2.0 2.0 - 2.0 2.0 - -[:, :, 2] = - 3.0 3.0 - 3.0 3.0 - -[:, :, 3] = - 4.0 4.0 - 4.0 4.0 - -julia> isfull(b) -true - -julia> nframes(b) -3 - -julia> size(b) -(2, 2, 3) -``` -""" -mutable struct CircularArrayBuffer{T,N} <: AbstractArray{T,N} - buffer::Array{T,N} - first::Int - nframes::Int - step_size::Int - - function CircularArrayBuffer{T}(d::Integer...) where {T} - N = length(d) - new{T,N}(Array{T}(undef, d...), 1, 0, N == 1 ? 1 : *(d[1:end-1]...)) - end -end - -Base.IndexStyle(::CircularArrayBuffer) = IndexLinear() -Base.size(cb::CircularArrayBuffer{<:Any,N}, i::Integer) where {N} = - i == N ? cb.nframes : size(cb.buffer, i) -Base.size(cb::CircularArrayBuffer{<:Any,N}) where {N} = ntuple(i -> size(cb, i), N) -Base.getindex(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N} = - getindex(cb.buffer, _buffer_index(cb, i)) -Base.setindex!(cb::CircularArrayBuffer{T,N}, v, i::Int) where {T,N} = - setindex!(cb.buffer, v, _buffer_index(cb, i)) -capacity(cb::CircularArrayBuffer{T,N}) where {T,N} = size(cb.buffer, N) -isfull(cb::CircularArrayBuffer) = cb.nframes == capacity(cb) -Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0 - -@inline function _buffer_index(cb::CircularArrayBuffer, i::Int) - ind = (cb.first - 1) * cb.step_size + i - mod1(ind, length(cb.buffer)) -end - -@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int) - n = capacity(cb) - idx = cb.first + i - 1 - mod1(idx, n) -end - -_buffer_frame(cb::CircularArrayBuffer, I::Vector{Int}) = map(i -> _buffer_frame(cb, i), I) - -function Base.empty!(cb::CircularArrayBuffer) - cb.nframes = 0 - cb -end - -""" - update!(cb::CircularArrayBuffer{T,N}, data::AbstractArray) - -`update!` the last frame of `cb` with data. -""" -function RLBase.update!(cb::CircularArrayBuffer{T,N}, data) where {T,N} - select_last_dim(cb.buffer, _buffer_frame(cb, cb.nframes)) .= data - cb -end - -function RLBase.update!(cb::CircularArrayBuffer{T,1}, data) where {T} - cb.buffer[_buffer_frame(cb, cb.nframes)] = data - cb -end - -function Base.push!(cb::CircularArrayBuffer, data) - # length(data) == cb.step_size || throw(ArgumentError("length of , $(cb.step_size) != $(length(data))")) - push!(cb, missing) - update!(cb, data) - cb -end - -function Base.push!(cb::CircularArrayBuffer{T,N}, ::Missing) where {T,N} - if cb.nframes == capacity(cb) - cb.first = (cb.first == capacity(cb) ? 
1 : cb.first + 1) - else - cb.nframes += 1 - end - cb -end - -function Base.pop!(cb::CircularArrayBuffer) - res = select_last_frame(cb) - if cb.nframes <= 0 - throw(ArgumentError("buffer must be non-empty")) - else - cb.nframes -= 1 - end - res -end - -frame_type(::CircularArrayBuffer{T,N}) where {T,N} = Array{T,N - 1} -frame_type(::CircularArrayBuffer{T,1}) where {T} = T diff --git a/src/utils/device.jl b/src/utils/device.jl index f72bdca..f3b3fd1 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -9,33 +9,9 @@ using Random import CUDA: device send_to_host(x) = send_to_device(Val(:cpu), x) -send_to_device(::Val{:cpu}, x) = x # cpu(x) is not very efficient! So by default we do nothing here. -send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x) -send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x) - -const KnownArrayVariants = Union{ - SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - Base.ReshapedArray{ - <:Any, - <:Any, - <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - }, - Base.ReshapedArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - SubArray{ - <:Any, - <:Any, - <:Base.ReshapedArray{ - <:Any, - <:Any, - <:SubArray{<:Any,<:Any,<:Union{CircularArrayBuffer,ElasticArray}}, - }, - }, -} - -# https://github.com/JuliaReinforcementLearning/ReinforcementLearningCore.jl/issues/130 -send_to_device(::Val{:cpu}, x::KnownArrayVariants) = Array(x) -send_to_device(::Val{:gpu}, x::Union{KnownArrayVariants,ElasticArray}) = CuArray(x) +send_to_device(::Val{:cpu}, m) = fmap(x -> adapt(Array, x), m) +send_to_device(::Val{:gpu}, m) = fmap(CUDA.cu, m) """ device(model) @@ -50,6 +26,8 @@ device(::Array) = Val(:cpu) device(x::Tuple{}) = nothing device(x::NamedTuple{(),Tuple{}}) = nothing device(x::ElasticArray) = device(x.data) +device(x::SubArray) = device(parent(x)) +device(x::Base.ReshapedArray) = device(parent(x)) function device(x::Random.AbstractRNG) if x isa CUDA.CURAND.RNG @@ -60,7 +38,7 @@ function device(x::Random.AbstractRNG) end function device(x::Union{Tuple,NamedTuple}) - d1 = device(x[1]) + d1 = device(first(x)) if isnothing(d1) device(Base.tail(x)) else diff --git a/src/utils/utils.jl b/src/utils/utils.jl index be1556d..f350076 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,5 +1,4 @@ include("printing.jl") include("base.jl") -include("circular_array_buffer.jl") include("device.jl") include("sum_tree.jl") diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..988160a --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,10 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" +ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/components/agents.jl b/test/components/agents.jl index bf6c592..7d9abc5 100644 --- a/test/components/agents.jl +++ b/test/components/agents.jl @@ -1,19 +1,12 @@ @testset "Agent" begin - action_space = DiscreteSpace(3) + env = CartPoleEnv(;T=Float32) agent = Agent(; - policy = RandomPolicy(action_space), - trajectory = CircularCompactSARTSATrajectory(; + policy = RandomPolicy(env), + trajectory = CircularArraySARTTrajectory(; 
capacity = 10_000, - state_type = Float32, - state_size = (4,), + state = Vector{Float32} => (4,), ), ) - @testset "loading/saving Agent" begin - mktempdir() do dir - RLCore.save(dir, agent) - @test length(readdir(dir)) != 0 - agent = RLCore.load(dir, Agent) - end - end + # TODO: test de/serialization end diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl index 3745cab..ed09ab3 100644 --- a/test/components/trajectories.jl +++ b/test/components/trajectories.jl @@ -16,201 +16,45 @@ @test t[:reward] == Bool[] end - @testset "SharedTrajectory" begin - t = SharedTrajectory(Int[], :state) - @test (:state, :next_state, :full_state) == keys(t) - @test haskey(t, :state) - @test haskey(t, :next_state) - @test haskey(t, :full_state) - @test t[:state] == Int[] - @test t[:next_state] == Int[] - @test t[:full_state] == Int[] - push!(t; state = 1, next_state = 2) - @test t[:state] == [1] - @test t[:next_state] == [2] - @test t[:full_state] == [1, 2] - empty!(t) - @test t[:state] == Int[] - @test t[:next_state] == Int[] - @test t[:full_state] == Int[] - end - - @testset "EpisodicTrajectory" begin - t = EpisodicTrajectory( - Trajectory(; state = Vector{Int}(), reward = Vector{Bool}()), - :reward, - ) - - @test isfull(t) == false - - @test (:state, :reward) == keys(t) - @test haskey(t, :state) - @test haskey(t, :reward) - push!(t; state = 3, reward = true) - - @test isfull(t) == true - - push!(t; state = 4, reward = false) - @test t[:state] == [3, 4] - @test t[:reward] == [true, false] - pop!(t) - @test t[:state] == [3] - @test t[:reward] == [true] - empty!(t) - @test t[:state] == Int[] - @test t[:reward] == Bool[] + @testset "DummyTrajectory" begin + @test keys(DUMMY_TRAJECTORY) == () end - @testset "CombinedTrajectory" begin - t = CircularCompactPSALRTSALTrajectory(; + @testset "CircularArraySARTTrajectory" begin + t = CircularArraySARTTrajectory(; capacity = 3, - legal_actions_mask_size = (2,), - ) - push!(t; state = 1, action = 1, legal_actions_mask = [false, false]) - push!( - t; - reward = 0.0f0, - terminal = false, - priority = 100, - state = 2, - action = 2, - legal_actions_mask = [false, true], - ) - - @test t[:state] == [1] - @test t[:action] == [1] - @test t[:legal_actions_mask] == [false false]' - @test t[:reward] == [0.0f0] - @test t[:terminal] == [false] - @test t[:priority] == [100] - @test t[:next_state] == [2] - @test t[:next_action] == [2] - @test t[:next_legal_actions_mask] == [false true]' - @test t[:full_state] == [1, 2] - @test t[:full_action] == [1, 2] - @test t[:full_legal_actions_mask] == [ - false false - false true - ] - - push!( - t; - reward = 1.0f0, - terminal = true, - priority = 200, - state = 3, - action = 3, - legal_actions_mask = [true, true], + state = Vector{Int} => (4, ), + action = Int => (), + reward = Float32 => (), + terminal = Bool => () ) - @test t[:state] == [1, 2] - @test t[:action] == [1, 2] - @test t[:legal_actions_mask] == [ - false false - false true - ] - @test t[:reward] == [0.0f0, 1.0f0] - @test t[:terminal] == [false, true] - @test t[:priority] == [100, 200] - @test t[:next_state] == [2, 3] - @test t[:next_action] == [2, 3] - @test t[:next_legal_actions_mask] == [ - false true - true true - ] - @test t[:full_state] == [1, 2, 3] - @test t[:full_action] == [1, 2, 3] - @test t[:full_legal_actions_mask] == [ - false false true - false true true - ] - - pop!(t) - - @test t[:state] == [1] - @test t[:action] == [1] - @test t[:legal_actions_mask] == [false false]' - @test t[:reward] == [0.0f0] - @test t[:terminal] == [false] - @test 
t[:priority] == [100] - @test t[:next_state] == [2] - @test t[:next_action] == [2] - @test t[:next_legal_actions_mask] == [false true]' - @test t[:full_state] == [1, 2] - @test t[:full_action] == [1, 2] - @test t[:full_legal_actions_mask] == [ - false false - false true - ] - - - empty!(t) - - @test t[:state] == [] - @test t[:action] == [] - @test t[:reward] == [] - @test t[:terminal] == [] - @test t[:next_state] == [] - @test t[:next_action] == [] - @test t[:full_state] == [] - @test t[:full_action] == [] - end - - @testset "VectCompactSARTSATrajectory" begin - t = VectCompactSARTSATrajectory(; - state_type = Vector{Float32}, - action_type = Int, - reward_type = Float32, - terminal_type = Bool, - ) - push!(t; state = Float32[1, 1], action = 1) - push!(t; reward = 1.0f0, terminal = false, state = Float32[2, 2], action = 2) - push!(t; reward = 2.0f0, terminal = true, state = Float32[3, 3], action = 3) + @test length(t) == 0 + push!(t; state=ones(Int, 4), action = 1) + @test length(t) == 0 + push!(t; reward=1.f0, terminal=false, state=2 * ones(Int, 4), action = 2) + @test length(t) == 1 - @test t[:state] == [Float32[1, 1], Float32[2, 2]] - @test t[:action] == [1, 2] - @test t[:reward] == [1.0f0, 2.0f0] - @test t[:terminal] == [false, true] - @test t[:next_state] == [Float32[2, 2], Float32[3, 3]] - @test t[:next_action] == [2, 3] - end + @test t[:state] == hcat(ones(Int, 4), 2*ones(Int, 4)) - @testset "ElasticCompactSARTSATrajectory" begin - t = ElasticCompactSARTSATrajectory(; - state_type = Float32, - state_size = (2,), - action_type = Int, - reward_type = Float32, - terminal_type = Bool, - ) - push!(t; state = Float32[1, 1], action = 1) - push!(t; reward = 1.0f0, terminal = false, state = Float32[2, 2], action = 2) - push!(t; reward = 2.0f0, terminal = true, state = Float32[3, 3], action = 3) - - @test t[:state] == Float32[1 2; 1 2] - @test t[:action] == [1, 2] - @test t[:reward] == [1.0f0, 2.0f0] - @test t[:terminal] == [false, true] - @test t[:next_state] == Float32[2 3; 2 3] - @test t[:next_action] == [2, 3] + push!(t; reward=2.f0, terminal=false, state=3 * ones(Int, 4), action = 3) + @test length(t) == 2 - @test pop!(t) == - (reward = 2.0f0, terminal = true, state = Float32[3.0, 3.0], action = 3) - push!(t; reward = 1.0f0, terminal = false, state = Float32[2, 2], action = 2) - @test t[:state] == Float32[1 2; 1 2] - @test t[:action] == [1, 2] - @test t[:reward] == [1.0f0, 1.0f0] - @test t[:terminal] == [false, false] - @test t[:next_state] == Float32[2 2; 2 2] - @test t[:next_action] == [2, 2] + push!(t; reward=3.f0, terminal=false, state=4 * ones(Int, 4), action = 4) + @test length(t) == 3 + @test t[:state] == [j for i in 1:4, j in 1:4] + @test t[:reward] == [1, 2, 3] - empty!(t) - @test length(t[:state]) == 0 + # test circle works as expected + push!(t; reward=4.f0, terminal=true, state=5 * ones(Int, 4), action = 5) + @test length(t) == 3 + @test t[:state] == [j for i in 1:4, j in 2:5] + @test t[:reward] == [2, 3, 4] end @testset "ReservoirTrajectory" begin # test length - t = ReservoirTrajectory(3, :a => Array{Float64,2}, :b => Bool) + t = ReservoirTrajectory(3; a = Array{Float64,2}, b = Bool) push!(t; a = rand(2, 3), b = rand(Bool)) @test length(t) == 1 push!(t; a = rand(2, 3), b = rand(Bool)) @@ -230,7 +74,7 @@ k, n, N = 3, 10, 10000 stats = Dict(i => 0 for i in 1:n) for _ in 1:N - t = ReservoirTrajectory(k, :a => Array{Int,2}, :b => Int) + t = ReservoirTrajectory(k; a = Array{Int,2}, b = Int) for i in 1:n push!(t; a = i .* ones(Int, 2, 3), b = i) end diff --git a/test/core/core.jl 
b/test/core/core.jl index 1ec11aa..3465412 100644 --- a/test/core/core.jl +++ b/test/core/core.jl @@ -2,13 +2,12 @@ env = CartPoleEnv{Float32}() |> StateOverriddenEnv(deepcopy) agent = Agent(; policy = RandomPolicy(env), - trajectory = CircularCompactSARTSATrajectory(; + trajectory = CircularArraySARTTrajectory(; capacity = 10_000, - state_type = Float32, - state_size = (4,), + state = Vector{Float32} => (4,), ), ) - N_EPISODE = 10000 + N_EPISODE = 10_000 hook = TotalRewardPerEpisode() run(agent, env, StopAfterEpisode(N_EPISODE), hook) diff --git a/test/utils/circular_array_buffer.jl b/test/utils/circular_array_buffer.jl deleted file mode 100644 index 71a636f..0000000 --- a/test/utils/circular_array_buffer.jl +++ /dev/null @@ -1,182 +0,0 @@ -@testset "CircularArrayBuffer" begin - A = ones(2, 2) - C = ones(Float32, 2, 2) - @testset "1D Int" begin - b = CircularArrayBuffer{Int}(3) - - @test eltype(b) == Int - @test capacity(b) == 3 - @test isfull(b) == false - @test isempty(b) == true - @test length(b) == 0 - @test nframes(b) == 0 - @test size(b) == (0,) - # element must has the exact same length with the element of buffer - @test_throws Exception push!(b, [1, 2]) - - for x in 1:3 - push!(b, x) - end - - @test capacity(b) == 3 - @test isfull(b) == true - @test length(b) == 3 - @test nframes(b) == 3 - @test size(b) == (3,) - @test b[1] == 1 - @test b[end] == 3 - @test b[1:end] == [1, 2, 3] - - for x in 4:5 - push!(b, x) - end - - @test capacity(b) == 3 - @test length(b) == 3 - @test nframes(b) == 3 - @test size(b) == (3,) - @test b[1] == 3 - @test b[end] == 5 - @test b[1:end] == [3, 4, 5] - - empty!(b) - @test isfull(b) == false - @test isempty(b) == true - @test length(b) == 0 - @test nframes(b) == 0 - @test size(b) == (0,) - - push!(b, 6) - @test isfull(b) == false - @test isempty(b) == false - @test length(b) == 1 - @test nframes(b) == 1 - @test size(b) == (1,) - @test b[1] == 6 - - push!(b, 7) - push!(b, 8) - @test isfull(b) == true - @test isempty(b) == false - @test length(b) == 3 - @test nframes(b) == 3 - @test size(b) == (3,) - @test b[[1, 2, 3]] == [6, 7, 8] - - push!(b, 9) - @test isfull(b) == true - @test isempty(b) == false - @test length(b) == 3 - @test nframes(b) == 3 - @test size(b) == (3,) - @test b[[1, 2, 3]] == [7, 8, 9] - - update!(b, 0) - @test length(b) == 3 - @test b[[1, 2, 3]] == [7, 8, 0] - - update!(b, 1) - @test length(b) == 3 - @test b[[1, 2, 3]] == [7, 8, 1] - - x = pop!(b) - @test x == 1 - @test length(b) == 2 - @test b[[1, 2]] == [7, 8] - - x = pop!(b) - @test x == 8 - @test length(b) == 1 - @test b[1] == 7 - - x = pop!(b) - @test x == 7 - @test length(b) == 0 - - @test_throws ArgumentError pop!(b) - end - - @testset "2D Float64" begin - b = CircularArrayBuffer{Float64}(2, 2, 3) - - @test eltype(b) == Float64 - @test capacity(b) == 3 - @test isfull(b) == false - @test length(b) == 0 - @test nframes(b) == 0 - @test size(b) == (2, 2, 0) - - for x in 1:3 - push!(b, x * A) - end - - @test capacity(b) == 3 - @test isfull(b) == true - @test nframes(b) == 3 - @test length(b) == 2 * 2 * 3 - @test size(b) == (2, 2, 3) - for i in 1:3 - @test b[:, :, i] == i * A - end - @test b[:, :, end] == 3 * A - - for x in 4:5 - push!(b, x * ones(2, 2)) # collection is also OK - end - - @test capacity(b) == 3 - @test length(b) == 2 * 2 * 3 - @test nframes(b) == 3 - @test size(b) == (2, 2, 3) - @test b[:, :, 1] == 3 * A - @test b[:, :, end] == 5 * A - - @test b == reshape([c for x in 3:5 for c in x * A], 2, 2, 3) - - push!(b, 6 * ones(Float32, 2, 2)) - push!(b, 7 * ones(Int, 2, 2)) - 
@test b == reshape([c for x in 5:7 for c in x * A], 2, 2, 3) - end - - @testset "2D Float32" begin - b = CircularArrayBuffer{Float32}(2, 2, 3) - - @test eltype(b) == Float32 - @test capacity(b) == 3 - @test isfull(b) == false - @test length(b) == 0 - @test nframes(b) == 0 - @test size(b) == (2, 2, 0) - - for x in 1:3 - push!(b, x * C) - end - - @test capacity(b) == 3 - @test isfull(b) == true - @test nframes(b) == 3 - @test length(b) == 2 * 2 * 3 - @test size(b) == (2, 2, 3) - for i in 1:3 - @test b[:, :, i] == i * C - end - @test b[:, :, end] == 3 * C - - for x in 4:5 - push!(b, x * ones(Float32, 2, 2)) # collection is also OK - end - - @test capacity(b) == 3 - @test length(b) == 2 * 2 * 3 - @test nframes(b) == 3 - @test size(b) == (2, 2, 3) - @test b[:, :, 1] == 3 * C - @test b[:, :, end] == 5 * C - - @test b == reshape([c for x in 3:5 for c in x * C], 2, 2, 3) - - push!(b, 6 * ones(Float64, 2, 2)) - push!(b, 7 * ones(Int, 2, 2)) - @test b == reshape([c for x in 5:7 for c in x * C], 2, 2, 3) - end -end diff --git a/test/utils/utils.jl b/test/utils/utils.jl index 7a7715c..e122ec8 100644 --- a/test/utils/utils.jl +++ b/test/utils/utils.jl @@ -1,3 +1,2 @@ include("base.jl") -include("circular_array_buffer.jl") include("device.jl") From 33ca83105da39814a7e321e78038b374e187b124 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 28 Nov 2020 18:26:00 +0800 Subject: [PATCH 03/12] clean up utils --- src/utils/base.jl | 94 ++++++++++++++++++++++++++----------------- src/utils/device.jl | 2 +- src/utils/printing.jl | 48 ++++++++++++---------- test/utils/base.jl | 3 -- 4 files changed, 85 insertions(+), 62 deletions(-) diff --git a/src/utils/base.jl b/src/utils/base.jl index b5e18ac..29506d2 100644 --- a/src/utils/base.jl +++ b/src/utils/base.jl @@ -1,25 +1,18 @@ export nframes, - frame_type, select_last_dim, select_last_frame, consecutive_view, find_all_max, - huber_loss, - huber_loss_unreduced, discount_rewards, discount_rewards!, discount_rewards_reduced, generalized_advantage_estimation, generalized_advantage_estimation!, - logitcrossentropy_unreduced, - flatten_batch, - unflatten_batch + flatten_batch using StatsBase nframes(a::AbstractArray{T,N}) where {T,N} = size(a, N) -frame_type(::Array{T,N}) where {T,N} = Array{T,N - 1} -frame_type(::Vector{T}) where {T} = T select_last_dim(xs::AbstractArray{T,N}, inds) where {T,N} = @views xs[ntuple(_ -> (:), N - 1)..., inds] @@ -54,30 +47,80 @@ julia> flatten_batch(x) 2 4 6 8 10 12 ``` """ -flatten_batch(x::AbstractArray) = - reshape(x, (size(x) |> reverse |> Base.tail |> Base.tail |> reverse)..., :) # much faster than `reshape(x, size(x)[1:end-2]..., :)` +flatten_batch(x::AbstractArray) = reshape(x, size(x)[1:end-2]..., :) -unflatten_batch(x::AbstractArray, i::Int...) = - reshape(x, (size(x) |> reverse |> Base.tail |> reverse)..., i...) +""" + consecutive_view(x::AbstractArray, inds; n_stack = nothing, n_horizon = nothing) + +By default, it behaves the same with `select_last_dim(x, inds)`. +If `n_stack` is set to an int, then for each frame specified by `inds`, +the previous `n_stack` frames (including the current one) are concatenated as a new dimension. +If `n_horizon` is set to an int, then for each frame specified by `inds`, +the next `n_horizon` frames (including the current one) are concatenated as a new dimension. 
+ +# Example + +```julia +julia> x = collect(1:5) +5-element Array{Int64,1}: + 1 + 2 + 3 + 4 + 5 + +julia> consecutive_view(x, [2,4]) # just the same with `select_last_dim(x, [2,4])` +2-element view(::Array{Int64,1}, [2, 4]) with eltype Int64: + 2 + 4 + +julia> consecutive_view(x, [2,4];n_stack = 2) +2×2 view(::Array{Int64,1}, [1 3; 2 4]) with eltype Int64: + 1 3 + 2 4 + +julia> consecutive_view(x, [2,4];n_horizon = 2) +2×2 view(::Array{Int64,1}, [2 4; 3 5]) with eltype Int64: + 2 4 + 3 5 + +julia> consecutive_view(x, [2,4];n_horizon = 2, n_stack=2) # note the order here, first we stack, then we apply the horizon +2×2×2 view(::Array{Int64,1}, [1 2; 2 3] + +[3 4; 4 5]) with eltype Int64: +[:, :, 1] = + 1 2 + 2 3 + +[:, :, 2] = + 3 4 + 4 5 +``` +See also [Frame Skipping and Preprocessing for Deep Q networks](https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/) +to gain a better understanding of state stacking and n-step learning. +""" consecutive_view( cb::AbstractArray, inds::Vector{Int}; n_stack = nothing, n_horizon = nothing, ) = consecutive_view(cb, inds, n_stack, n_horizon) -consecutive_view(cb::AbstractArray, inds::Vector{Int}, ::Nothing, ::Nothing) = - select_last_dim(cb, inds) + +consecutive_view(cb::AbstractArray, inds::Vector{Int}, ::Nothing, ::Nothing) = select_last_dim(cb, inds) + consecutive_view(cb::AbstractArray, inds::Vector{Int}, n_stack::Int, ::Nothing) = select_last_dim( cb, reshape([i for x in inds for i in x-n_stack+1:x], n_stack, length(inds)), ) + consecutive_view(cb::AbstractArray, inds::Vector{Int}, ::Nothing, n_horizon::Int) = select_last_dim( cb, reshape([i for x in inds for i in x:x+n_horizon-1], n_horizon, length(inds)), ) + consecutive_view(cb::AbstractArray, inds::Vector{Int}, n_stack::Int, n_horizon::Int) = select_last_dim( cb, @@ -107,29 +150,6 @@ _rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) Base.findmax(A::AbstractVector, mask::AbstractVector{Bool}) = findmax(i -> A[i], view(keys(A), mask)) -function logitcrossentropy_unreduced(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) - return vec(-sum(y .* logsoftmax(logŷ), dims = 1)) -end - -""" - huber_loss_unreduced(labels, predictions; δ = 1.0f0) - -Similar to [`huber_loss`](@ref), but it doesn't do the `mean` operation in the last step. 
-""" -function huber_loss_unreduced(labels, predictions; δ = 1.0f0) - abs_error = abs.(predictions .- labels) - quadratic = min.(abs_error, δ) - linear = abs_error .- quadratic - 0.5f0 .* quadratic .* quadratic .+ δ .* linear -end - -""" - huber_loss(labels, predictions; δ = 1.0f0) - -See [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) -""" -huber_loss(labels, predictions; δ = 1.0f0) = - huber_loss_unreduced(labels, predictions; δ = δ) |> mean const VectorOrMatrix = Union{AbstractMatrix,AbstractVector} diff --git a/src/utils/device.jl b/src/utils/device.jl index f3b3fd1..71a2c99 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -1,10 +1,10 @@ export device, send_to_host, send_to_device -using ElasticArrays using Flux using CUDA using Adapt using Random +using ElasticArrays import CUDA: device diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 772912f..0a53316 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -1,8 +1,6 @@ -export StructTree - using AbstractTrees using Random -using ProgressMeter +using ProgressMeter:Progress const AT = AbstractTrees @@ -10,20 +8,35 @@ struct StructTree{X} x::X end -AT.children(t::StructTree{X}) where {X} = - Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X)) -AT.children( - t::StructTree{T}, -) where { - T<:Union{AbstractArray,AbstractDict,MersenneTwister,ProgressMeter.Progress,Function}, -} = () +is_expand(x::T) where T = is_expand(T) +is_expand(::Type{T}) where T = true + +is_expand(::AbstractArray) = false +is_expand(::AbstractDict) = false +is_expand(::AbstractRNG) = false +is_expand(::Progress) = false +is_expand(::Function) = false +is_expand(::UnionAll) = false + +function AT.children(t::StructTree{X}) where {X} + if is_expand(t.x) + Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X)) + else + () + end +end + AT.children(t::Pair{Symbol,<:StructTree}) = children(last(t)) -AT.children(t::StructTree{UnionAll}) = () +AT.printnode(io::IO, t::StructTree{T}) where {T} = print(io, T.name) AT.printnode(io::IO, t::StructTree{<:Union{Number,Symbol}}) = print(io, t.x) AT.printnode(io::IO, t::StructTree{UnionAll}) = print(io, t.x) -AT.printnode(io::IO, t::StructTree{T}) where {T} = print(io, T.name) -AT.printnode(io::IO, t::StructTree{<:AbstractArray}) where {T} = summary(io, t.x) +AT.printnode(io::IO, t::StructTree{<:AbstractArray}) = summary(io, t.x) + +function AT.printnode(io::IO, t::Pair{Symbol,<:StructTree}) + print(io, first(t), " => ") + AT.printnode(io, last(t)) +end function AT.printnode(io::IO, t::StructTree{String}) s = t.x @@ -43,11 +56,4 @@ function AT.printnode(io::IO, t::StructTree{String}) end end -function AT.printnode(io::IO, t::Pair{Symbol,<:StructTree}) - print(io, first(t), " => ") - AT.printnode(io, last(t)) -end - -function AT.printnode(io::IO, t::Pair{Symbol,<:StructTree{<:Tuple}}) - print(io, first(t)) -end +AT.printnode(io::IO, t::Pair{Symbol,<:StructTree{<:Tuple}}) = print(io, first(t)) diff --git a/test/utils/base.jl b/test/utils/base.jl index 9623c90..1d8618d 100644 --- a/test/utils/base.jl +++ b/test/utils/base.jl @@ -82,9 +82,6 @@ x = rand(2, 3, 4) y = flatten_batch(x) @test size(y) == (2, 12) - - z = unflatten_batch(y, 3, 4) - @test x == z end @testset "discount_rewards" begin From 2fe081d5e5bda4f57595fbc2b502e2be10430c39 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 28 Nov 2020 18:52:39 +0800 Subject: [PATCH 04/12] revisit extensions --- src/extensions/ElasticArrays.jl | 2 +- src/extensions/ReinforcementLearningBase.jl | 2 +- 2 files changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/extensions/ElasticArrays.jl b/src/extensions/ElasticArrays.jl index 80dde92..d3c9037 100644 --- a/src/extensions/ElasticArrays.jl +++ b/src/extensions/ElasticArrays.jl @@ -5,7 +5,7 @@ Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x]) Base.empty!(a::ElasticArray) = ElasticArrays.resize_lastdim!(a, 0) function Base.pop!(a::ElasticArray) - last_frame = select_last_frame(a) |> copy # !!! ensure that we will not access invalid data + last_frame = select_last_frame(a) #= |> copy =# # !!! ensure that we will not access invalid data ElasticArrays.resize!(a.data, length(a.data) - a.kernel_length.divisor) last_frame end diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index faac067..a4eff67 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -12,7 +12,7 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10)) -AbstractTrees.children(t::StructTree{<:AbstractEnv}) = () +is_expand(::AbstractEnv) = false AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = print( io, From f6f0ff70a22f628ab1580120593a1253a30a8026 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 28 Nov 2020 19:20:37 +0800 Subject: [PATCH 05/12] update processors --- .../approximators/neural_network_approximator.jl | 6 +++--- src/components/processors.jl | 15 ++++----------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/components/approximators/neural_network_approximator.jl b/src/components/approximators/neural_network_approximator.jl index 6f3307f..cde597e 100644 --- a/src/components/approximators/neural_network_approximator.jl +++ b/src/components/approximators/neural_network_approximator.jl @@ -1,6 +1,7 @@ export NeuralNetworkApproximator, ActorCritic using Flux +import Functors: functor """ NeuralNetworkApproximator(;kwargs) @@ -19,8 +20,7 @@ end (app::NeuralNetworkApproximator)(x) = app.model(x) -# !!! watch https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L2 -Flux.functor(x::NeuralNetworkApproximator) = +functor(x::NeuralNetworkApproximator) = (model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer) device(app::NeuralNetworkApproximator) = device(app.model) @@ -48,7 +48,7 @@ Base.@kwdef struct ActorCritic{A,C,O} <: AbstractApproximator optimizer::O = ADAM() end -Flux.functor(x::ActorCritic) = +functor(x::ActorCritic) = (actor = x.actor, critic = x.critic), y -> ActorCritic(y.actor, y.critic, x.optimizer) RLBase.update!(app::ActorCritic, gs) = Flux.Optimise.update!(app.optimizer, params(app), gs) diff --git a/src/components/processors.jl b/src/components/processors.jl index c414758..de62a71 100644 --- a/src/components/processors.jl +++ b/src/components/processors.jl @@ -44,21 +44,14 @@ end function (p::StackFrames{T,N})(state::AbstractArray) where {T,N} push!(p.buffer, state) - p.buffer + p end -# !!! side effect? 
-function Base.push!( - cb::CircularArrayBuffer{T,N}, - stacked_data::CircularArrayBuffer{T,N}, -) where {T,N} - push!(cb, select_last_frame(stacked_data)) +function Base.push!(cb::CircularArrayBuffer, p::StackFrames) + push!(cb, select_last_frame(p.buffer)) end function RLBase.reset!(p::StackFrames{T,N}) where {T,N} - empty!(p.buffer) - for _ in 1:CircularArrayBuffers.capacity(p.buffer) - push!(p.buffer, zeros(T, size(p.buffer)[1:N-1])) - end + fill!(p.buffer, zero(T)) p end From 0bc953feda9a610f59b4b335465950f8cebe9063 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 1 Dec 2020 10:57:41 +0800 Subject: [PATCH 06/12] pass tests --- Project.toml | 1 + src/components/agents/agents.jl | 2 -- src/components/components.jl | 1 - src/components/{agents => policies}/agent.jl | 0 src/components/{agents => policies}/base.jl | 0 src/components/policies/policies.jl | 2 ++ src/components/processors.jl | 9 +++++++-- src/core/experiment.jl | 4 ++-- src/core/run.jl | 12 +++++++++--- 9 files changed, 21 insertions(+), 10 deletions(-) delete mode 100644 src/components/agents/agents.jl rename src/components/{agents => policies}/agent.jl (100%) rename src/components/{agents => policies}/base.jl (100%) diff --git a/Project.toml b/Project.toml index 0d2bcb1..be3e820 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/components/agents/agents.jl b/src/components/agents/agents.jl deleted file mode 100644 index 3edb072..0000000 --- a/src/components/agents/agents.jl +++ /dev/null @@ -1,2 +0,0 @@ -include("base.jl") -include("agent.jl") diff --git a/src/components/components.jl b/src/components/components.jl index de09f7b..c70444b 100644 --- a/src/components/components.jl +++ b/src/components/components.jl @@ -4,4 +4,3 @@ include("approximators/approximators.jl") include("explorers/explorers.jl") include("learners/learners.jl") include("policies/policies.jl") -include("agents/agents.jl") diff --git a/src/components/agents/agent.jl b/src/components/policies/agent.jl similarity index 100% rename from src/components/agents/agent.jl rename to src/components/policies/agent.jl diff --git a/src/components/agents/base.jl b/src/components/policies/base.jl similarity index 100% rename from src/components/agents/base.jl rename to src/components/policies/base.jl diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl index 6c04293..2f6dd4e 100644 --- a/src/components/policies/policies.jl +++ b/src/components/policies/policies.jl @@ -1,4 +1,6 @@ +include("base.jl") include("V_based_policy.jl") include("Q_based_policy.jl") include("off_policy.jl") include("static_policy.jl") +include("agent.jl") diff --git a/src/components/processors.jl b/src/components/processors.jl index de62a71..b1d4869 100644 --- a/src/components/processors.jl +++ b/src/components/processors.jl @@ -3,6 +3,7 @@ export StackFrames, ResizeImage using ImageTransformations: imresize! import CircularArrayBuffers using CircularArrayBuffers:CircularArrayBuffer +using MacroTools:@forward """ ResizeImage(img::Array{T, N}) @@ -26,12 +27,16 @@ end """ StackFrames(::Type{T}=Float32, d::Int...) 
-Use a pre-initialized [`CircularArrayBuffer`](@ref) to store the latest several states specified by `d`. Before processing any observation, the buffer is filled with `zero{T}`. +Use a pre-initialized [`CircularArrayBuffer`](@ref) to store the latest several states specified by `d`. Before processing any observation, the buffer is filled with `zero{T} +by default. """ -struct StackFrames{T,N} +struct StackFrames{T,N} <: AbstractArray{T,N} buffer::CircularArrayBuffer{T,N} end +@forward StackFrames.buffer Base.size, Base.getindex +Base.IndexStyle(x::StackFrames) = IndexStyle(x.buffer) + StackFrames(d::Int...) = StackFrames(Float32, d...) function StackFrames(::Type{T}, d::Vararg{Int,N}) where {T,N} diff --git a/src/core/experiment.jl b/src/core/experiment.jl index 8e8efff..d2405cf 100644 --- a/src/core/experiment.jl +++ b/src/core/experiment.jl @@ -1,14 +1,14 @@ export @experiment_cmd, @E_cmd, Experiment -using BSON using Markdown +using Dates Base.@kwdef mutable struct Experiment agent::Any env::Any stop_condition::Any hook::Any - description::String + description::String = "Experiment created at $(now())" end function Base.show(io::IO, x::Experiment) diff --git a/src/core/run.jl b/src/core/run.jl index 19add12..ba16806 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -2,9 +2,15 @@ export expected_policy_values import Base: run -run(agent, env::AbstractEnv, args...) = _run(agent, env, args...) +function run(agent::Agent, env::AbstractEnv, stop_condition=StopAfterEpisode(1), hook=EmptyHook()) + check(agent, env) + _run(agent, env, stop_condition, hook) +end + +"Inject some customized checkings here by overwriting this function" +function check(agent, env) end -_run(agent, env, args...) = _run(DynamicStyle(env), NumAgentStyle(env), agent, env, args...) 
+_run(agent, env, stop_condition, hook) = _run(DynamicStyle(env), NumAgentStyle(env), agent, env, stop_condition, hook) function _run( ::Sequential, @@ -12,7 +18,7 @@ function _run( agent::Agent, env::AbstractEnv, stop_condition, - hook::AbstractHook = EmptyHook(), + hook::AbstractHook, ) while true # run episodes forever From 4c950c693d05fbbca4deb5cccb2ae7cfb8a8ec34 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 1 Dec 2020 17:34:32 +0800 Subject: [PATCH 07/12] reorganize code structure --- src/ReinforcementLearningCore.jl | 2 +- src/components/components.jl | 6 --- src/components/policies/V_based_policy.jl | 47 ------------------- src/components/policies/off_policy.jl | 15 ------ src/components/policies/policies.jl | 6 --- src/components/policies/static_policy.jl | 24 ---------- .../policies => policies/agents}/agent.jl | 12 ++--- src/policies/agents/agents.jl | 3 ++ .../policies => policies/agents}/base.jl | 0 .../trajectories/abstract_trajectory.jl | 2 + .../trajectories/reservoir_trajectory.jl | 0 .../agents}/trajectories/trajectories.jl | 0 .../agents}/trajectories/trajectory.jl | 0 .../trajectories/trajectory_extension.jl | 17 ++++--- src/policies/policies.jl | 2 + .../explorers/UCB_explorer.jl | 0 .../explorers/abstract_explorer.jl | 0 .../explorers/batch_explorer.jl | 0 .../explorers/epsilon_greedy_explorer.jl | 0 .../q_based_policies}/explorers/explorers.jl | 0 .../explorers/gumbel_softmax_explorer.jl | 0 .../explorers/weighted_explorer.jl | 0 .../explorers/weighted_softmax_explorer.jl | 0 .../learners/abstract_learner.jl | 0 .../approximators/abstract_approximator.jl | 0 .../learners}/approximators/approximators.jl | 0 .../neural_network_approximator.jl | 0 .../approximators/tabular_approximator.jl | 0 .../q_based_policies}/learners/learners.jl | 1 + .../learners/tabular_learner.jl | 0 .../q_based_policies/q_based_policies.jl | 2 + .../q_based_policies/q_based_policy.jl} | 6 ++- src/{components => utils}/processors.jl | 0 src/utils/utils.jl | 1 + 34 files changed, 30 insertions(+), 116 deletions(-) delete mode 100644 src/components/components.jl delete mode 100644 src/components/policies/V_based_policy.jl delete mode 100644 src/components/policies/off_policy.jl delete mode 100644 src/components/policies/policies.jl delete mode 100644 src/components/policies/static_policy.jl rename src/{components/policies => policies/agents}/agent.jl (92%) create mode 100644 src/policies/agents/agents.jl rename src/{components/policies => policies/agents}/base.jl (100%) rename src/{components => policies/agents}/trajectories/abstract_trajectory.jl (94%) rename src/{components => policies/agents}/trajectories/reservoir_trajectory.jl (100%) rename src/{components => policies/agents}/trajectories/trajectories.jl (100%) rename src/{components => policies/agents}/trajectories/trajectory.jl (100%) rename src/{components => policies/agents}/trajectories/trajectory_extension.jl (81%) create mode 100644 src/policies/policies.jl rename src/{components => policies/q_based_policies}/explorers/UCB_explorer.jl (100%) rename src/{components => policies/q_based_policies}/explorers/abstract_explorer.jl (100%) rename src/{components => policies/q_based_policies}/explorers/batch_explorer.jl (100%) rename src/{components => policies/q_based_policies}/explorers/epsilon_greedy_explorer.jl (100%) rename src/{components => policies/q_based_policies}/explorers/explorers.jl (100%) rename src/{components => policies/q_based_policies}/explorers/gumbel_softmax_explorer.jl (100%) rename src/{components => 
policies/q_based_policies}/explorers/weighted_explorer.jl (100%) rename src/{components => policies/q_based_policies}/explorers/weighted_softmax_explorer.jl (100%) rename src/{components => policies/q_based_policies}/learners/abstract_learner.jl (100%) rename src/{components => policies/q_based_policies/learners}/approximators/abstract_approximator.jl (100%) rename src/{components => policies/q_based_policies/learners}/approximators/approximators.jl (100%) rename src/{components => policies/q_based_policies/learners}/approximators/neural_network_approximator.jl (100%) rename src/{components => policies/q_based_policies/learners}/approximators/tabular_approximator.jl (100%) rename src/{components => policies/q_based_policies}/learners/learners.jl (59%) rename src/{components => policies/q_based_policies}/learners/tabular_learner.jl (100%) create mode 100644 src/policies/q_based_policies/q_based_policies.jl rename src/{components/policies/Q_based_policy.jl => policies/q_based_policies/q_based_policy.jl} (92%) rename src/{components => utils}/processors.jl (100%) diff --git a/src/ReinforcementLearningCore.jl b/src/ReinforcementLearningCore.jl index fdca2e6..316309d 100644 --- a/src/ReinforcementLearningCore.jl +++ b/src/ReinforcementLearningCore.jl @@ -13,7 +13,7 @@ export RLCore include("utils/utils.jl") include("extensions/extensions.jl") -include("components/components.jl") +include("policies/policies.jl") include("core/core.jl") end # module diff --git a/src/components/components.jl b/src/components/components.jl deleted file mode 100644 index c70444b..0000000 --- a/src/components/components.jl +++ /dev/null @@ -1,6 +0,0 @@ -include("processors.jl") -include("trajectories/trajectories.jl") -include("approximators/approximators.jl") -include("explorers/explorers.jl") -include("learners/learners.jl") -include("policies/policies.jl") diff --git a/src/components/policies/V_based_policy.jl b/src/components/policies/V_based_policy.jl deleted file mode 100644 index 2f36c2d..0000000 --- a/src/components/policies/V_based_policy.jl +++ /dev/null @@ -1,47 +0,0 @@ -export VBasedPolicy - -using MacroTools: @forward - -""" - VBasedPolicy(;learner, mapping, explorer=GreedyExplorer()) - -# Key words & Fields - -- `learner`::[`AbstractLearner`](@ref), learn how to estimate state values. -- `mapping`, a customized function `(env, learner) -> action_values` -- `explorer`::[`AbstractExplorer`](@ref), decide which action to take based on action values. 
-""" -Base.@kwdef struct VBasedPolicy{L<:AbstractLearner,M,E<:AbstractExplorer} <: AbstractPolicy - learner::L - mapping::M - explorer::E = GreedyExplorer() -end - -(p::VBasedPolicy)(env) = p(env, ActionStyle(env)) - -(p::VBasedPolicy)(env, ::MinimalActionSet) = p.mapping(env, p.learner) |> p.explorer - -function (p::VBasedPolicy)(env, ::FullActionSet) - action_values = p.mapping(env, p.learner) - p.explorer(action_values, get_legal_actions_mask(env)) -end - -RLBase.get_prob(p::VBasedPolicy, env, action::Integer) = - get_prob(p, env, ActionStyle(env), action) - -RLBase.get_prob(p::VBasedPolicy, env, ::MinimalActionSet) = - get_prob(p.explorer, p.mapping(env, p.learner)) -RLBase.get_prob(p::VBasedPolicy, env, ::MinimalActionSet, action) = - get_prob(p.explorer, p.mapping(env, p.learner), action) - -function RLBase.get_prob(p::VBasedPolicy, env, ::FullActionSet) - action_values = p.mapping(env, p.learner) - get_prob(p.explorer, action_values, get_legal_actions_mask(env)) -end - -function RLBase.get_prob(p::VBasedPolicy, env, ::FullActionSet, action) - action_values = p.mapping(env, p.learner) - get_prob(p.explorer, action_values, get_legal_actions_mask(env), action) -end - -@forward VBasedPolicy.learner RLBase.get_priority, RLBase.update! diff --git a/src/components/policies/off_policy.jl b/src/components/policies/off_policy.jl deleted file mode 100644 index c37c50b..0000000 --- a/src/components/policies/off_policy.jl +++ /dev/null @@ -1,15 +0,0 @@ -export OffPolicy - -using MacroTools: @forward - -""" - OffPolicy(π_target::P, π_behavior::B) -> OffPolicy{P,B} -""" -Base.@kwdef struct OffPolicy{P,B} <: AbstractPolicy - π_target::P - π_behavior::B -end - -(π::OffPolicy)(env) = π.π_behavior(env) - -@forward OffPolicy.π_behavior RLBase.get_priority, RLBase.get_prob diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl deleted file mode 100644 index 2f6dd4e..0000000 --- a/src/components/policies/policies.jl +++ /dev/null @@ -1,6 +0,0 @@ -include("base.jl") -include("V_based_policy.jl") -include("Q_based_policy.jl") -include("off_policy.jl") -include("static_policy.jl") -include("agent.jl") diff --git a/src/components/policies/static_policy.jl b/src/components/policies/static_policy.jl deleted file mode 100644 index 1d708aa..0000000 --- a/src/components/policies/static_policy.jl +++ /dev/null @@ -1,24 +0,0 @@ -export StaticPolicy - -using MacroTools: @forward - -""" - StaticPolicy(policy) - -Create a policy wrapper so that it will do nothing when calling -`update!(policy::StaticPolicy, args...)`. Usually used in the -distributed mode as a worker. -""" -struct StaticPolicy{P<:AbstractPolicy} <: AbstractPolicy - p::P -end - -(π::StaticPolicy)(env) = π.p(env) - -@forward StaticPolicy.p RLBase.get_priority, RLBase.get_prob - -RLBase.update!(p::StaticPolicy, args...) = nothing - -RLBase.update!(p::StaticPolicy, ps::Params) = update!(p.p, ps) - -Flux.@functor StaticPolicy diff --git a/src/components/policies/agent.jl b/src/policies/agents/agent.jl similarity index 92% rename from src/components/policies/agent.jl rename to src/policies/agents/agent.jl index a27589f..3371e62 100644 --- a/src/components/policies/agent.jl +++ b/src/policies/agents/agent.jl @@ -1,5 +1,5 @@ export Agent, - role + get_role import Functors:functor using Setfield: @set @@ -16,20 +16,20 @@ update the trajectory and policy appropriately in different stages and modes. 
- `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment - `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents """ -Base.@kwdef struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R,M} <: AbstractPolicy +Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractPolicy policy::P trajectory::T = DUMMY_TRAJECTORY role::R = RLBase.DEFAULT_PLAYER - mode::M = TRAIN_MODE + mode::Union{TrainMode,EvalMode,TestMode} = TRAIN_MODE end functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy -role(agent::Agent) = agent.role +get_role(agent::Agent) = agent.role mode(agent::Agent) = agent.mode -(agent::Agent)(env) = agent.policy(env) - +set_mode!(agent::Agent, mode::AbstractMode) = agent.mode = mode +(agent::Agent)(env) = agent.policy(env) (agent::Agent)(stage::AbstractStage, env::AbstractEnv) = agent(env, stage, mode(agent)) function (agent::Agent)(env::AbstractEnv, stage::AbstractStage, mode::AbstractMode) diff --git a/src/policies/agents/agents.jl b/src/policies/agents/agents.jl new file mode 100644 index 0000000..005a535 --- /dev/null +++ b/src/policies/agents/agents.jl @@ -0,0 +1,3 @@ +include("base.jl") +include("trajectories/trajectories.jl") +include("agent.jl") \ No newline at end of file diff --git a/src/components/policies/base.jl b/src/policies/agents/base.jl similarity index 100% rename from src/components/policies/base.jl rename to src/policies/agents/base.jl diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/policies/agents/trajectories/abstract_trajectory.jl similarity index 94% rename from src/components/trajectories/abstract_trajectory.jl rename to src/policies/agents/trajectories/abstract_trajectory.jl index 8b1b1ab..51a27cf 100644 --- a/src/components/trajectories/abstract_trajectory.jl +++ b/src/policies/agents/trajectories/abstract_trajectory.jl @@ -1,5 +1,6 @@ export AbstractTrajectory, SART, + SARTS, SARTSA """ @@ -60,4 +61,5 @@ end ##### const SART = (:state, :action, :reward, :terminal) +const SARTS = (:state, :action, :reward, :terminal, :next_state) const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) \ No newline at end of file diff --git a/src/components/trajectories/reservoir_trajectory.jl b/src/policies/agents/trajectories/reservoir_trajectory.jl similarity index 100% rename from src/components/trajectories/reservoir_trajectory.jl rename to src/policies/agents/trajectories/reservoir_trajectory.jl diff --git a/src/components/trajectories/trajectories.jl b/src/policies/agents/trajectories/trajectories.jl similarity index 100% rename from src/components/trajectories/trajectories.jl rename to src/policies/agents/trajectories/trajectories.jl diff --git a/src/components/trajectories/trajectory.jl b/src/policies/agents/trajectories/trajectory.jl similarity index 100% rename from src/components/trajectories/trajectory.jl rename to src/policies/agents/trajectories/trajectory.jl diff --git a/src/components/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl similarity index 81% rename from src/components/trajectories/trajectory_extension.jl rename to src/policies/agents/trajectories/trajectory_extension.jl index 8e67fd8..8701221 100644 --- a/src/components/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -39,27 +39,26 @@ struct UniformBatchSampler <: AbstractSampler end """ - sample([rng=Random.GLOBAL_RNG], trajectory, sampler, 
[traces=keys(trajectory)]) + sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=Val(keys(trajectory))]) !!! note Here we return a copy instead of a view: 1. Each sample is independent of the original `trajectory` so that `trajectory` can be updated async. 2. [Copy is not always so bad](https://docs.julialang.org/en/v1/manual/performance-tips/#Copying-data-is-not-always-bad). """ -function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler, traces=keys(t)) +function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler, traces=Val(keys(t))) sample(Random.GLOBAL_RNG, t, sampler, traces) end function StatsBase.sample(rng::AbstractRNG, t::CircularVectorSARTSATrajectory, s::UniformBatchSampler, traces) inds = rand(rng, 1:length(t), s.batch_size) - map(traces) do x - Flux.batch(view(t[x], inds)) - end + NamedTuple{traces}(Flux.batch(view(t[x], inds)) for x in traces) end -function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler, traces) +function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler, ::Val{SARTS}) inds = rand(rng, 1:length(t), s.batch_size) - map(traces) do x - convert(Array, consecutive_view(t[x], inds)) - end + NamedTuple{SARTS}(( + (convert(Array, consecutive_view(t[x], inds)) for x in SART)..., + convert(Array,consecutive_view(t[:state], inds.+1)) + )) end \ No newline at end of file diff --git a/src/policies/policies.jl b/src/policies/policies.jl new file mode 100644 index 0000000..ea4faf5 --- /dev/null +++ b/src/policies/policies.jl @@ -0,0 +1,2 @@ +include("q_based_policies/q_based_policies.jl") +include("agents/agents.jl") diff --git a/src/components/explorers/UCB_explorer.jl b/src/policies/q_based_policies/explorers/UCB_explorer.jl similarity index 100% rename from src/components/explorers/UCB_explorer.jl rename to src/policies/q_based_policies/explorers/UCB_explorer.jl diff --git a/src/components/explorers/abstract_explorer.jl b/src/policies/q_based_policies/explorers/abstract_explorer.jl similarity index 100% rename from src/components/explorers/abstract_explorer.jl rename to src/policies/q_based_policies/explorers/abstract_explorer.jl diff --git a/src/components/explorers/batch_explorer.jl b/src/policies/q_based_policies/explorers/batch_explorer.jl similarity index 100% rename from src/components/explorers/batch_explorer.jl rename to src/policies/q_based_policies/explorers/batch_explorer.jl diff --git a/src/components/explorers/epsilon_greedy_explorer.jl b/src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl similarity index 100% rename from src/components/explorers/epsilon_greedy_explorer.jl rename to src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl diff --git a/src/components/explorers/explorers.jl b/src/policies/q_based_policies/explorers/explorers.jl similarity index 100% rename from src/components/explorers/explorers.jl rename to src/policies/q_based_policies/explorers/explorers.jl diff --git a/src/components/explorers/gumbel_softmax_explorer.jl b/src/policies/q_based_policies/explorers/gumbel_softmax_explorer.jl similarity index 100% rename from src/components/explorers/gumbel_softmax_explorer.jl rename to src/policies/q_based_policies/explorers/gumbel_softmax_explorer.jl diff --git a/src/components/explorers/weighted_explorer.jl b/src/policies/q_based_policies/explorers/weighted_explorer.jl similarity index 100% rename from src/components/explorers/weighted_explorer.jl rename to 
src/policies/q_based_policies/explorers/weighted_explorer.jl diff --git a/src/components/explorers/weighted_softmax_explorer.jl b/src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl similarity index 100% rename from src/components/explorers/weighted_softmax_explorer.jl rename to src/policies/q_based_policies/explorers/weighted_softmax_explorer.jl diff --git a/src/components/learners/abstract_learner.jl b/src/policies/q_based_policies/learners/abstract_learner.jl similarity index 100% rename from src/components/learners/abstract_learner.jl rename to src/policies/q_based_policies/learners/abstract_learner.jl diff --git a/src/components/approximators/abstract_approximator.jl b/src/policies/q_based_policies/learners/approximators/abstract_approximator.jl similarity index 100% rename from src/components/approximators/abstract_approximator.jl rename to src/policies/q_based_policies/learners/approximators/abstract_approximator.jl diff --git a/src/components/approximators/approximators.jl b/src/policies/q_based_policies/learners/approximators/approximators.jl similarity index 100% rename from src/components/approximators/approximators.jl rename to src/policies/q_based_policies/learners/approximators/approximators.jl diff --git a/src/components/approximators/neural_network_approximator.jl b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl similarity index 100% rename from src/components/approximators/neural_network_approximator.jl rename to src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl diff --git a/src/components/approximators/tabular_approximator.jl b/src/policies/q_based_policies/learners/approximators/tabular_approximator.jl similarity index 100% rename from src/components/approximators/tabular_approximator.jl rename to src/policies/q_based_policies/learners/approximators/tabular_approximator.jl diff --git a/src/components/learners/learners.jl b/src/policies/q_based_policies/learners/learners.jl similarity index 59% rename from src/components/learners/learners.jl rename to src/policies/q_based_policies/learners/learners.jl index 2f85ef6..f6e4c69 100644 --- a/src/components/learners/learners.jl +++ b/src/policies/q_based_policies/learners/learners.jl @@ -1,2 +1,3 @@ include("abstract_learner.jl") +include("approximators/approximators.jl") include("tabular_learner.jl") diff --git a/src/components/learners/tabular_learner.jl b/src/policies/q_based_policies/learners/tabular_learner.jl similarity index 100% rename from src/components/learners/tabular_learner.jl rename to src/policies/q_based_policies/learners/tabular_learner.jl diff --git a/src/policies/q_based_policies/q_based_policies.jl b/src/policies/q_based_policies/q_based_policies.jl new file mode 100644 index 0000000..b2704ec --- /dev/null +++ b/src/policies/q_based_policies/q_based_policies.jl @@ -0,0 +1,2 @@ +include("learners/learners.jl") +include("explorers/explorers.jl") \ No newline at end of file diff --git a/src/components/policies/Q_based_policy.jl b/src/policies/q_based_policies/q_based_policy.jl similarity index 92% rename from src/components/policies/Q_based_policy.jl rename to src/policies/q_based_policies/q_based_policy.jl index fea0cae..1703ceb 100644 --- a/src/components/policies/Q_based_policy.jl +++ b/src/policies/q_based_policies/q_based_policy.jl @@ -2,7 +2,7 @@ export QBasedPolicy, TabularRandomPolicy using MacroTools: @forward using Flux -using Setfield +using Setfield: @set """ QBasedPolicy(;learner::Q, explorer::S) @@ -29,7 +29,9 
@@ RLBase.get_prob(p::QBasedPolicy, env, ::MinimalActionSet) = RLBase.get_prob(p::QBasedPolicy, env, ::FullActionSet) = get_prob(p.explorer, p.learner(env), get_legal_actions_mask(env)) -@forward QBasedPolicy.learner RLBase.get_priority, RLBase.update! +@forward QBasedPolicy.learner RLBase.get_priority + +RLBase.update!(p::QBasedPolicy, trajectory::AbstractTrajectory) = update!(p.learner, trajectory) function Flux.testmode!(p::QBasedPolicy, mode = true) testmode!(p.learner, mode) diff --git a/src/components/processors.jl b/src/utils/processors.jl similarity index 100% rename from src/components/processors.jl rename to src/utils/processors.jl diff --git a/src/utils/utils.jl b/src/utils/utils.jl index f350076..529c002 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -2,3 +2,4 @@ include("printing.jl") include("base.jl") include("device.jl") include("sum_tree.jl") +include("processors.jl") From d62b7affd59558852ad1d0717df87e6aee4c2174 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 2 Dec 2020 13:20:35 +0800 Subject: [PATCH 08/12] fix training/testing mode --- src/policies/agents/agent.jl | 9 ++++++--- src/policies/agents/agents.jl | 1 - .../agents/trajectories/trajectory_extension.jl | 16 +++++++++------- src/policies/{agents => }/base.jl | 13 ++++++++++--- src/policies/policies.jl | 3 ++- .../q_based_policies/explorers/UCB_explorer.jl | 2 -- .../explorers/abstract_explorer.jl | 5 ----- .../q_based_policies/explorers/batch_explorer.jl | 2 -- .../explorers/epsilon_greedy_explorer.jl | 2 -- .../learners/abstract_learner.jl | 4 ---- .../approximators/neural_network_approximator.jl | 12 +++++------- src/policies/q_based_policies/q_based_policy.jl | 10 +++++----- test/components/explorers.jl | 7 ------- 13 files changed, 37 insertions(+), 49 deletions(-) rename src/policies/{agents => }/base.jl (84%) diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 3371e62..15f486c 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -26,11 +26,14 @@ end functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy get_role(agent::Agent) = agent.role -mode(agent::Agent) = agent.mode -set_mode!(agent::Agent, mode::AbstractMode) = agent.mode = mode + +function set_mode!(agent::Agent, mode::AbstractMode) + agent.mode = mode + set_mode!(agent.policy, mode) +end (agent::Agent)(env) = agent.policy(env) -(agent::Agent)(stage::AbstractStage, env::AbstractEnv) = agent(env, stage, mode(agent)) +(agent::Agent)(stage::AbstractStage, env::AbstractEnv) = agent(env, stage, agent.mode) function (agent::Agent)(env::AbstractEnv, stage::AbstractStage, mode::AbstractMode) update!(agent.trajectory, agent.policy, env, stage, mode) diff --git a/src/policies/agents/agents.jl b/src/policies/agents/agents.jl index 005a535..3e8628f 100644 --- a/src/policies/agents/agents.jl +++ b/src/policies/agents/agents.jl @@ -1,3 +1,2 @@ -include("base.jl") include("trajectories/trajectories.jl") include("agent.jl") \ No newline at end of file diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 8701221..fb42165 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -32,12 +32,14 @@ end # Samplers ##### -abstract type AbstractSampler end +abstract type AbstractSampler{traces} end -struct UniformBatchSampler <: AbstractSampler +struct UniformBatchSampler{traces} <: AbstractSampler{traces} batch_size::Int end 
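A minimal usage sketch of the sampler introduced above, assuming the `CircularArraySARTTrajectory` keyword constructor exercised in the tests earlier in this series (values are made up for illustration; a later patch renames the sampler to `BatchSampler`/`NStepBatchSampler`):

```julia
using StatsBase  # `sample` is the extended StatsBase.sample

# build and fill a small SART trajectory, following the constructor used in the tests
t = CircularArraySARTTrajectory(;
    capacity = 100,
    state = Vector{Float32} => (4,),
    action = Int => (),
    reward = Float32 => (),
    terminal = Bool => (),
)
push!(t; state = rand(Float32, 4), action = 1)
for _ in 1:10
    push!(t; reward = 1.0f0, terminal = false, state = rand(Float32, 4), action = rand(1:2))
end

# draw a batch of SARTS transitions; the result is a NamedTuple of plain arrays
batch = sample(t, UniformBatchSampler{SARTS}(32))
size(batch.state)  # expected (4, 32): one column per sampled transition
```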
+UniformBatchSampler(batch_size::Int) = UniformBatchSampler{SARTSA}(batch_size) + """ sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=Val(keys(trajectory))]) @@ -46,16 +48,16 @@ end 1. Each sample is independent of the original `trajectory` so that `trajectory` can be updated async. 2. [Copy is not always so bad](https://docs.julialang.org/en/v1/manual/performance-tips/#Copying-data-is-not-always-bad). """ -function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler, traces=Val(keys(t))) - sample(Random.GLOBAL_RNG, t, sampler, traces) +function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) + sample(Random.GLOBAL_RNG, t, sampler) end -function StatsBase.sample(rng::AbstractRNG, t::CircularVectorSARTSATrajectory, s::UniformBatchSampler, traces) +function StatsBase.sample(rng::AbstractRNG, t::CircularVectorSARTSATrajectory, s::UniformBatchSampler{SARTSA}) inds = rand(rng, 1:length(t), s.batch_size) - NamedTuple{traces}(Flux.batch(view(t[x], inds)) for x in traces) + NamedTuple{SARTSA}(Flux.batch(view(t[x], inds)) for x in SARTSA) end -function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler, ::Val{SARTS}) +function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler{SARTS}) inds = rand(rng, 1:length(t), s.batch_size) NamedTuple{SARTS}(( (convert(Array, consecutive_view(t[x], inds)) for x in SART)..., diff --git a/src/policies/agents/base.jl b/src/policies/base.jl similarity index 84% rename from src/policies/agents/base.jl rename to src/policies/base.jl index 0a52e04..c212e62 100644 --- a/src/policies/agents/base.jl +++ b/src/policies/base.jl @@ -12,7 +12,6 @@ export AbstractStage, PRE_ACT_STAGE, POST_ACT_STAGE, set_mode!, - mode, AbstractMode, TrainMode, TRAIN_MODE, @@ -60,6 +59,14 @@ const EVAL_MODE = EvalMode() struct TestMode <: AbstractMode end const TEST_MODE = TestMode() -function mode end +function set_mode!(p, ::TrainMode) + for x in Flux.trainable(p) + Flux.trainmode!(x) + end +end -function set_mode! 
end \ No newline at end of file +function set_mode!(p, ::Union{TestMode, EvalMode}) + for x in Flux.trainable(p) + Flux.testmode!(x) + end +end \ No newline at end of file diff --git a/src/policies/policies.jl b/src/policies/policies.jl index ea4faf5..4db539e 100644 --- a/src/policies/policies.jl +++ b/src/policies/policies.jl @@ -1,2 +1,3 @@ -include("q_based_policies/q_based_policies.jl") +include("base.jl") include("agents/agents.jl") +include("q_based_policies/q_based_policies.jl") diff --git a/src/policies/q_based_policies/explorers/UCB_explorer.jl b/src/policies/q_based_policies/explorers/UCB_explorer.jl index 3110754..4cb31d0 100644 --- a/src/policies/q_based_policies/explorers/UCB_explorer.jl +++ b/src/policies/q_based_policies/explorers/UCB_explorer.jl @@ -11,8 +11,6 @@ Base.@kwdef mutable struct UCBExplorer{R<:AbstractRNG} <: AbstractExplorer is_training::Bool = true end -Flux.testmode!(p::UCBExplorer, mode = true) = p.is_training = !mode - """ UCBExplorer(na; c=2.0, ϵ=1e-10, step=1, seed=nothing) diff --git a/src/policies/q_based_policies/explorers/abstract_explorer.jl b/src/policies/q_based_policies/explorers/abstract_explorer.jl index f2f7080..a460d9e 100644 --- a/src/policies/q_based_policies/explorers/abstract_explorer.jl +++ b/src/policies/q_based_policies/explorers/abstract_explorer.jl @@ -26,8 +26,3 @@ function RLBase.get_prob(p::AbstractExplorer, x) end Similart to `get_prob(p::AbstractExplorer, x)`, but here only the `mask`ed elements are considered. """ function RLBase.get_prob(p::AbstractExplorer, x, mask) end - -# see discussion https://github.com/hill-a/stable-baselines/issues/819 -Flux.testmode!(p::AbstractExplorer, mode = true) = - @warn "trainmode/testmode will not take effect on explorer, you may consider switching to GreedyExplorer in testmode" maxlog = - 1 diff --git a/src/policies/q_based_policies/explorers/batch_explorer.jl b/src/policies/q_based_policies/explorers/batch_explorer.jl index f81bd7d..13e0cb1 100644 --- a/src/policies/q_based_policies/explorers/batch_explorer.jl +++ b/src/policies/q_based_policies/explorers/batch_explorer.jl @@ -21,5 +21,3 @@ Apply inner explorer to each column of `values`. (x::BatchExplorer)(v::AbstractVector) = x.explorer(v) (x::BatchExplorer)(v::AbstractVector, m::AbstractVector) = x.explorer(v, m) - -Flux.testmode!(x::BatchExplorer, mode = true) = testmode!(x.explorer, mode) diff --git a/src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl b/src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl index f0e6b67..7ef9e42 100644 --- a/src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl +++ b/src/policies/q_based_policies/explorers/epsilon_greedy_explorer.jl @@ -72,8 +72,6 @@ function EpsilonGreedyExplorer(; ) end -Flux.testmode!(p::EpsilonGreedyExplorer, mode = true) = p.is_training = !mode - EpsilonGreedyExplorer(ϵ; kwargs...) = EpsilonGreedyExplorer(; ϵ_stable = ϵ, kwargs...) 
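A short sketch of the explorer interface touched above, restricted to calls already exercised in the test suite (the concrete values are hypothetical):

```julia
explorer = EpsilonGreedyExplorer(0.1)   # convenience constructor: ϵ_stable = 0.1
values = [0.0, 10.0, -3.0]

explorer(values)                 # index of the selected action (greedy w.p. 1-ϵ)
get_prob(explorer, values)       # action-selection probabilities at the current step

mask = Bool[1, 0, 1]
explorer(values, mask)           # only legal (masked-in) actions are considered
```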
function get_ϵ(s::EpsilonGreedyExplorer{:linear}, step) diff --git a/src/policies/q_based_policies/learners/abstract_learner.jl b/src/policies/q_based_policies/learners/abstract_learner.jl index aa88c81..653064c 100644 --- a/src/policies/q_based_policies/learners/abstract_learner.jl +++ b/src/policies/q_based_policies/learners/abstract_learner.jl @@ -15,7 +15,3 @@ function (learner::AbstractLearner)(env) end get_priority(p::AbstractLearner, experience) """ function RLBase.get_priority(p::AbstractLearner, experience) end - -# TODO: deprecate this default function -Flux.testmode!(learner::AbstractLearner, mode = true) = - Flux.testmode!(learner.approximator, mode) diff --git a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index cde597e..e1c2b91 100644 --- a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -18,7 +18,9 @@ Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator optimizer::O = nothing end -(app::NeuralNetworkApproximator)(x) = app.model(x) +function (app::NeuralNetworkApproximator)(x) + app.model(send_to_device(device(app.model), x)) +end functor(x::NeuralNetworkApproximator) = (model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer) @@ -31,8 +33,6 @@ RLBase.update!(app::NeuralNetworkApproximator, gs) = Base.copyto!(dest::NeuralNetworkApproximator, src::NeuralNetworkApproximator) = Flux.loadparams!(dest.model, params(src)) -Flux.testmode!(app::NeuralNetworkApproximator, mode = true) = testmode!(app.model, mode) - ##### # ActorCritic ##### @@ -53,7 +53,5 @@ functor(x::ActorCritic) = RLBase.update!(app::ActorCritic, gs) = Flux.Optimise.update!(app.optimizer, params(app), gs) -function Flux.testmode!(app::ActorCritic, mode = true) - testmode!(app.actor, mode) - testmode!(app.critic, mode) -end +actor(ac::ActorCritic) = x -> ac.actor(send_to_device(device(ac.actor), x)) +critic(ac::ActorCritic) = x -> ac.critic(send_to_device(device(ac.actor), x)) diff --git a/src/policies/q_based_policies/q_based_policy.jl b/src/policies/q_based_policies/q_based_policy.jl index 1703ceb..35cbe57 100644 --- a/src/policies/q_based_policies/q_based_policy.jl +++ b/src/policies/q_based_policies/q_based_policy.jl @@ -18,6 +18,11 @@ end Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner +function set_mode!(p::QBasedPolicy, m::AbstractMode) + @warn "setting a `QBasedPolicy` to $m, the inner `explorer` will not be modified!" 
maxlog=1 + set_mode!(p.learner, m) +end + (π::QBasedPolicy)(env) = π(env, ActionStyle(env)) (π::QBasedPolicy)(env, ::MinimalActionSet) = get_actions(env)[π.explorer(π.learner(env))] (π::QBasedPolicy)(env, ::FullActionSet) = @@ -33,11 +38,6 @@ RLBase.get_prob(p::QBasedPolicy, env, ::FullActionSet) = RLBase.update!(p::QBasedPolicy, trajectory::AbstractTrajectory) = update!(p.learner, trajectory) -function Flux.testmode!(p::QBasedPolicy, mode = true) - testmode!(p.learner, mode) - testmode!(p.explorer, mode) -end - ##### # TabularRandomPolicy ##### diff --git a/test/components/explorers.jl b/test/components/explorers.jl index 09177aa..411120e 100644 --- a/test/components/explorers.jl +++ b/test/components/explorers.jl @@ -20,13 +20,6 @@ target_prob; atol = 0.005, )) - - testmode!(explorer) - @test explorer(values) == 3 - @test isapprox(probs(get_prob(explorer, values)), [0, 0, 1, 0]) - mask = Bool[1, 0, 0, 1] - @test explorer(values, mask) == 1 - @test isapprox(probs(get_prob(explorer, values, mask)), [1, 0, 0, 0]) end @testset "linear" begin From a05a7a68e0d92fc59dc6cee67a09205b6763ec55 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 2 Dec 2020 15:13:45 +0800 Subject: [PATCH 09/12] fix problems found in BasicDQN --- .../learners/approximators/neural_network_approximator.jl | 7 +------ src/policies/q_based_policies/q_based_policies.jl | 3 ++- src/utils/printing.jl | 5 ++--- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index e1c2b91..fcca965 100644 --- a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -18,9 +18,7 @@ Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator optimizer::O = nothing end -function (app::NeuralNetworkApproximator)(x) - app.model(send_to_device(device(app.model), x)) -end +(app::NeuralNetworkApproximator)(x) = app.model(x) functor(x::NeuralNetworkApproximator) = (model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer) @@ -52,6 +50,3 @@ functor(x::ActorCritic) = (actor = x.actor, critic = x.critic), y -> ActorCritic(y.actor, y.critic, x.optimizer) RLBase.update!(app::ActorCritic, gs) = Flux.Optimise.update!(app.optimizer, params(app), gs) - -actor(ac::ActorCritic) = x -> ac.actor(send_to_device(device(ac.actor), x)) -critic(ac::ActorCritic) = x -> ac.critic(send_to_device(device(ac.actor), x)) diff --git a/src/policies/q_based_policies/q_based_policies.jl b/src/policies/q_based_policies/q_based_policies.jl index b2704ec..048003a 100644 --- a/src/policies/q_based_policies/q_based_policies.jl +++ b/src/policies/q_based_policies/q_based_policies.jl @@ -1,2 +1,3 @@ include("learners/learners.jl") -include("explorers/explorers.jl") \ No newline at end of file +include("explorers/explorers.jl") +include("q_based_policy.jl") \ No newline at end of file diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 0a53316..dca8ba1 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -8,15 +8,14 @@ struct StructTree{X} x::X end -is_expand(x::T) where T = is_expand(T) -is_expand(::Type{T}) where T = true - +is_expand(x) = true is_expand(::AbstractArray) = false is_expand(::AbstractDict) = false is_expand(::AbstractRNG) = false is_expand(::Progress) = false is_expand(::Function) = false is_expand(::UnionAll) = false 
+is_expand(::DataType) = false function AT.children(t::StructTree{X}) where {X} if is_expand(t.x) From 52eb087786dfe21282ae20a46222123641d5ad8e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 2 Dec 2020 20:43:13 +0800 Subject: [PATCH 10/12] improve trajectories --- .../trajectories/abstract_trajectory.jl | 4 +- .../trajectories/trajectory_extension.jl | 80 ++++++++++++++++--- 2 files changed, 72 insertions(+), 12 deletions(-) diff --git a/src/policies/agents/trajectories/abstract_trajectory.jl b/src/policies/agents/trajectories/abstract_trajectory.jl index 51a27cf..a6a6062 100644 --- a/src/policies/agents/trajectories/abstract_trajectory.jl +++ b/src/policies/agents/trajectories/abstract_trajectory.jl @@ -62,4 +62,6 @@ end const SART = (:state, :action, :reward, :terminal) const SARTS = (:state, :action, :reward, :terminal, :next_state) -const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) \ No newline at end of file +const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) +const SLARTSL = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask) +const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action) \ No newline at end of file diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index fb42165..9b673a8 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -1,4 +1,4 @@ -export NStepInserter, UniformBatchSampler +export NStepInserter, BatchSampler, NStepBatchSampler using Random @@ -34,12 +34,6 @@ end abstract type AbstractSampler{traces} end -struct UniformBatchSampler{traces} <: AbstractSampler{traces} - batch_size::Int -end - -UniformBatchSampler(batch_size::Int) = UniformBatchSampler{SARTSA}(batch_size) - """ sample([rng=Random.GLOBAL_RNG], trajectory, sampler, [traces=Val(keys(trajectory))]) @@ -52,15 +46,79 @@ function StatsBase.sample(t::AbstractTrajectory, sampler::AbstractSampler) sample(Random.GLOBAL_RNG, t, sampler) end -function StatsBase.sample(rng::AbstractRNG, t::CircularVectorSARTSATrajectory, s::UniformBatchSampler{SARTSA}) - inds = rand(rng, 1:length(t), s.batch_size) - NamedTuple{SARTSA}(Flux.batch(view(t[x], inds)) for x in SARTSA) +##### +## BatchSampler +##### + +struct BatchSampler{traces} <: AbstractSampler{traces} + batch_size::Int end -function StatsBase.sample(rng::AbstractRNG, t::CircularArraySARTTrajectory, s::UniformBatchSampler{SARTS}) +BatchSampler(batch_size::Int) = BatchSampler{SARTSA}(batch_size) + +function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler) inds = rand(rng, 1:length(t), s.batch_size) + sample(inds, t, s) +end + +function StatsBase.sample(inds::Vector{Int}, t::CircularVectorSARTSATrajectory, s::BatchSampler{traces}) where traces + NamedTuple{SARTSA}(Flux.batch(view(t[x], inds)) for x in traces) +end + +function StatsBase.sample(inds::Vector{Int}, t::CircularArraySARTTrajectory, s::BatchSampler{SARTS}) NamedTuple{SARTS}(( (convert(Array, consecutive_view(t[x], inds)) for x in SART)..., convert(Array,consecutive_view(t[:state], inds.+1)) )) +end + +##### +## NStepBatchSampler +##### + +Base.@kwdef struct NStepBatchSampler{traces} <: AbstractSampler{traces} + γ::Float32 + n::Int = 1 + batch_size::Int = 32 + stack_size::Union{Nothing,Int} = nothing +end + +function StatsBase.sample(rng::AbstractRNG, 
t::AbstractTrajectory, s::NStepBatchSampler) + inds = rand(rng, 1:(length(t)-s.n+1), s.batch_size) + sample(inds, t, s) +end + +function StatsBase.sample(inds::Vector{Int}, traj::CircularArraySARTTrajectory, s::NStepBatchSampler{traces}) where traces + γ, n, bz, sz = s.γ, s.n, s.batch_size, s.stack_size + next_inds = inds .+ n + + s = convert(Array, consecutive_view(traj[:state], inds;n_stack = sz)) + a = convert(Array, consecutive_view(traj[:action], inds)) + s′ = convert(Array, consecutive_view(traj[:state], next_inds;n_stack = sz)) + + consecutive_rewards = consecutive_view(traj[:reward], inds; n_horizon = n) + consecutive_terminals = consecutive_view(traj[:terminal], inds; n_horizon = n) + r, t = zeros(Float32, bz), fill(false, bz) + + # make sure that we only consider experiences in current episode + for i in 1:bz + m = findfirst(view(consecutive_terminals, :, i)) + if isnothing(m) + t[i] = false + r[i] = discount_rewards_reduced(view(consecutive_rewards, :, i), γ) + else + t[i] = true + r[i] = discount_rewards_reduced(view(consecutive_rewards, 1:m, i), γ) + end + end + + if traces == SARTS + NamedTuple{SARTS}((s, a, r, t, s′)) + elseif traces == SLARTSL + l = convert(Array, consecutive_view(traj[:legal_actions_mask], inds)) + l′ = convert(Array, consecutive_view(traj[:next_legal_actions_mask], next_inds)) + NamedTuple{SLARTSL}((s, l, a, r, t, s′, l′)) + else + @error "unsupported traces $traces" + end end \ No newline at end of file From 704a96fe9b47f4d638546f1831e31dba86e1375f Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 3 Dec 2020 00:05:28 +0800 Subject: [PATCH 11/12] support PrioritizedTrajectory --- src/policies/agents/agent.jl | 2 +- .../agents/trajectories/trajectory.jl | 28 ++++++++++++++++++- .../trajectories/trajectory_extension.jl | 20 +++++++++++-- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 15f486c..03cbcfe 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -76,7 +76,7 @@ function RLBase.update!(trajectory::AbstractTrajectory, policy::AbstractPolicy, haskey(trajectory, :legal_actions_mask) && push!(trajectory[:legal_actions_mask], get_legal_actions_mask(env)) end -function RLBase.update!(trajectory::CircularArraySARTTrajectory, policy::AbstractPolicy, env::AbstractEnv, ::PreActStage, ::AbstractMode) +function RLBase.update!(trajectory::AbstractTrajectory, policy::AbstractPolicy, env::AbstractEnv, ::PreActStage, ::AbstractMode) action = policy(env) push!(trajectory[:state], get_state(env)) push!(trajectory[:action], action) diff --git a/src/policies/agents/trajectories/trajectory.jl b/src/policies/agents/trajectories/trajectory.jl index da782fc..047709d 100644 --- a/src/policies/agents/trajectories/trajectory.jl +++ b/src/policies/agents/trajectories/trajectory.jl @@ -1,4 +1,5 @@ export Trajectory, + PrioritizedTrajectory, DUMMY_TRAJECTORY, DummyTrajectory, CircularArrayTrajectory, @@ -6,6 +7,7 @@ export Trajectory, CircularArraySARTTrajectory, CircularVectorSARTTrajectory, CircularVectorSARTSATrajectory, + CircularArrayPSARTTrajectory, VectorTrajectory using MacroTools: @forward @@ -114,6 +116,30 @@ function VectorTrajectory(;kwargs...) end) end +##### + +Base.@kwdef struct PrioritizedTrajectory{P,T} <: AbstractTrajectory + priority::P + traj::T +end + +Base.keys(t::PrioritizedTrajectory) = (:priority, keys(t.traj)...) 
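# `PrioritizedTrajectory` is a thin wrapper: the `SumTree` keeps one priority
# per stored transition and every other trace is forwarded to the wrapped
# trajectory, so existing SART-style update code keeps working unchanged.
# A minimal usage sketch, assuming the wrapped `CircularArraySARTTrajectory`
# supplies defaults for its trace types (see the `CircularArrayPSARTTrajectory`
# convenience constructor added below):
t = CircularArrayPSARTTrajectory(capacity = 10)
t[:priority]    # the SumTree of priorities
t[:state]       # forwarded to the inner CircularArraySARTTrajectory
keys(t)         # (:priority, :state, :action, :reward, :terminal)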
+ +Base.length(t::PrioritizedTrajectory) = length(t.priority) + +Base.getindex(t::PrioritizedTrajectory, s::Symbol) = if s == :priority + t.priority +else + getindex(t.traj, s) +end + +const CircularArrayPSARTTrajectory = PrioritizedTrajectory{<:SumTree, <:CircularArraySARTTrajectory} + +CircularArrayPSARTTrajectory(;capacity, kwargs...) = PrioritizedTrajectory( + SumTree(capacity), + CircularArraySARTTrajectory(;capacity=capacity, kwargs...) +) + ##### # Common ##### @@ -121,4 +147,4 @@ end function Base.length(t::Union{<:CircularArraySARTTrajectory,<:CircularVectorSARTSATrajectory}) x = t[:terminal] size(x, ndims(x)) -end \ No newline at end of file +end diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 9b673a8..c67b535 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -58,7 +58,7 @@ BatchSampler(batch_size::Int) = BatchSampler{SARTSA}(batch_size) function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler) inds = rand(rng, 1:length(t), s.batch_size) - sample(inds, t, s) + inds, sample(inds, t, s) end function StatsBase.sample(inds::Vector{Int}, t::CircularVectorSARTSATrajectory, s::BatchSampler{traces}) where traces @@ -85,7 +85,23 @@ end function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler) inds = rand(rng, 1:(length(t)-s.n+1), s.batch_size) - sample(inds, t, s) + inds, sample(inds, t, s) +end + +function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory{<:SumTree}, s::NStepBatchSampler) + bz, sz = s.batch_size, s.stack_size + inds = Vector{Int}(undef, bz) + priorities = Vector{Float32}(undef, bz) + valid_ind_range = isnothing(sz) ? (1:(length(t)-s.n+1)) : (sz:(length(t)-s.n+1)) + for i in 1:bz + ind, p = sample(rng, t.priority) + while ind ∉ valid_ind_range + ind, p = sample(rng, t.priority) + end + inds[i] = ind + priorities[i] = p + end + inds, (priority=priorities, sample(inds, t.traj, s)...) 
end function StatsBase.sample(inds::Vector{Int}, traj::CircularArraySARTTrajectory, s::NStepBatchSampler{traces}) where traces From bc31baafddda98dc0531d77ee2923ca6ccdc2047 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 3 Dec 2020 12:50:43 +0800 Subject: [PATCH 12/12] resolve comments --- Project.toml | 3 +-- src/extensions/ElasticArrays.jl | 11 ++++++++--- src/extensions/ReinforcementLearningBase.jl | 6 ------ src/policies/agents/agent.jl | 13 ++++--------- .../agents/trajectories/abstract_trajectory.jl | 4 +++- .../agents/trajectories/trajectory_extension.jl | 12 ++++++------ src/policies/base.jl | 13 ------------- .../q_based_policies/learners/abstract_learner.jl | 2 ++ src/policies/q_based_policies/q_based_policy.jl | 5 ----- 9 files changed, 24 insertions(+), 45 deletions(-) diff --git a/Project.toml b/Project.toml index be3e820..3c1b069 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.5.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -31,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractTrees = "0.3" Adapt = "2" -BSON = "0.2" +CircularArrayBuffers = "0.1" CUDA = "1, 2.1" Distributions = "0.24" ElasticArrays = "1.2" diff --git a/src/extensions/ElasticArrays.jl b/src/extensions/ElasticArrays.jl index d3c9037..25c8cfc 100644 --- a/src/extensions/ElasticArrays.jl +++ b/src/extensions/ElasticArrays.jl @@ -5,7 +5,12 @@ Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x]) Base.empty!(a::ElasticArray) = ElasticArrays.resize_lastdim!(a, 0) function Base.pop!(a::ElasticArray) - last_frame = select_last_frame(a) #= |> copy =# # !!! ensure that we will not access invalid data - ElasticArrays.resize!(a.data, length(a.data) - a.kernel_length.divisor) - last_frame + if length(a) > 0 + last_frame_inds = length(a.data) - a.kernel_length.divisor + 1 : length(a.data) + d = reshape(view(a.data, last_frame_inds), a.kernel_size) + ElasticArrays.resize!(a.data, length(a.data) - a.kernel_length.divisor) + d + else + @error "can not pop! from an empty ElasticArray" + end end diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index a4eff67..8f3c68c 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -2,7 +2,6 @@ using CUDA using Distributions: pdf using Random using Flux -using BSON using AbstractTrees RLBase.update!(p::RandomPolicy, x) = nothing @@ -13,8 +12,3 @@ Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10)) is_expand(::AbstractEnv) = false - -AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = print( - io, - "$(RLBase.get_name(t.x)): $(join([f(t.x) for f in RLBase.get_env_traits()], ","))", -) diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 03cbcfe..f9a9c55 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -16,7 +16,7 @@ update the trajectory and policy appropriately in different stages and modes. 
- `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment - `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents """ -Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractPolicy +Base.@kwdef struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractPolicy policy::P trajectory::T = DUMMY_TRAJECTORY role::R = RLBase.DEFAULT_PLAYER @@ -27,11 +27,6 @@ functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy get_role(agent::Agent) = agent.role -function set_mode!(agent::Agent, mode::AbstractMode) - agent.mode = mode - set_mode!(agent.policy, mode) -end - (agent::Agent)(env) = agent.policy(env) (agent::Agent)(stage::AbstractStage, env::AbstractEnv) = agent(env, stage, agent.mode) @@ -40,7 +35,7 @@ function (agent::Agent)(env::AbstractEnv, stage::AbstractStage, mode::AbstractMo update!(agent.policy, agent.trajectory, env, stage, mode) end -## TrainMode +## TrainMode: update both policy and trajectory function (agent::Agent)(env::AbstractEnv, stage::PreActStage, mode::TrainMode) action = update!(agent.trajectory, agent.policy, env, stage, mode) @@ -48,13 +43,13 @@ function (agent::Agent)(env::AbstractEnv, stage::PreActStage, mode::TrainMode) action end -## EvalMode +## EvalMode: upate only trajectory function (agent::Agent)(env::AbstractEnv, stage::PreActStage, mode::EvalMode) update!(agent.trajectory, agent.policy, env, stage, mode) end -## TestMode +## TestMode: do not update (agent::Agent)(::AbstractEnv, ::AbstractStage, ::TestMode) = nothing (agent::Agent)(env::AbstractEnv, ::PreActStage, ::TestMode) = agent.policy(env) diff --git a/src/policies/agents/trajectories/abstract_trajectory.jl b/src/policies/agents/trajectories/abstract_trajectory.jl index a6a6062..078328f 100644 --- a/src/policies/agents/trajectories/abstract_trajectory.jl +++ b/src/policies/agents/trajectories/abstract_trajectory.jl @@ -1,7 +1,9 @@ export AbstractTrajectory, SART, SARTS, - SARTSA + SARTSA, + SLARTSL, + SLARTSLA """ AbstractTrajectory diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index c67b535..66ae3e2 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -58,14 +58,14 @@ BatchSampler(batch_size::Int) = BatchSampler{SARTSA}(batch_size) function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::BatchSampler) inds = rand(rng, 1:length(t), s.batch_size) - inds, sample(inds, t, s) + inds, select(inds, t, s) end -function StatsBase.sample(inds::Vector{Int}, t::CircularVectorSARTSATrajectory, s::BatchSampler{traces}) where traces +function select(inds::Vector{Int}, t::CircularVectorSARTSATrajectory, s::BatchSampler{traces}) where traces NamedTuple{SARTSA}(Flux.batch(view(t[x], inds)) for x in traces) end -function StatsBase.sample(inds::Vector{Int}, t::CircularArraySARTTrajectory, s::BatchSampler{SARTS}) +function select(inds::Vector{Int}, t::CircularArraySARTTrajectory, s::BatchSampler{SARTS}) NamedTuple{SARTS}(( (convert(Array, consecutive_view(t[x], inds)) for x in SART)..., convert(Array,consecutive_view(t[:state], inds.+1)) @@ -85,7 +85,7 @@ end function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler) inds = rand(rng, 1:(length(t)-s.n+1), s.batch_size) - inds, sample(inds, t, s) + inds, select(inds, t, s) end function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory{<:SumTree}, 
s::NStepBatchSampler) @@ -101,10 +101,10 @@ function StatsBase.sample(rng::AbstractRNG, t::PrioritizedTrajectory{<:SumTree}, inds[i] = ind priorities[i] = p end - inds, (priority=priorities, sample(inds, t.traj, s)...) + inds, (priority=priorities, select(inds, t.traj, s)...) end -function StatsBase.sample(inds::Vector{Int}, traj::CircularArraySARTTrajectory, s::NStepBatchSampler{traces}) where traces +function select(inds::Vector{Int}, traj::CircularArraySARTTrajectory, s::NStepBatchSampler{traces}) where traces γ, n, bz, sz = s.γ, s.n, s.batch_size, s.stack_size next_inds = inds .+ n diff --git a/src/policies/base.jl b/src/policies/base.jl index c212e62..e2b9dc7 100644 --- a/src/policies/base.jl +++ b/src/policies/base.jl @@ -11,7 +11,6 @@ export AbstractStage, POST_EPISODE_STAGE, PRE_ACT_STAGE, POST_ACT_STAGE, - set_mode!, AbstractMode, TrainMode, TRAIN_MODE, @@ -58,15 +57,3 @@ const EVAL_MODE = EvalMode() struct TestMode <: AbstractMode end const TEST_MODE = TestMode() - -function set_mode!(p, ::TrainMode) - for x in Flux.trainable(p) - Flux.trainmode!(x) - end -end - -function set_mode!(p, ::Union{TestMode, EvalMode}) - for x in Flux.trainable(p) - Flux.testmode!(x) - end -end \ No newline at end of file diff --git a/src/policies/q_based_policies/learners/abstract_learner.jl b/src/policies/q_based_policies/learners/abstract_learner.jl index 653064c..24fda21 100644 --- a/src/policies/q_based_policies/learners/abstract_learner.jl +++ b/src/policies/q_based_policies/learners/abstract_learner.jl @@ -15,3 +15,5 @@ function (learner::AbstractLearner)(env) end get_priority(p::AbstractLearner, experience) """ function RLBase.get_priority(p::AbstractLearner, experience) end + +Base.show(io::IO, p::AbstractLearner) = AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10)) \ No newline at end of file diff --git a/src/policies/q_based_policies/q_based_policy.jl b/src/policies/q_based_policies/q_based_policy.jl index 35cbe57..ae6631e 100644 --- a/src/policies/q_based_policies/q_based_policy.jl +++ b/src/policies/q_based_policies/q_based_policy.jl @@ -18,11 +18,6 @@ end Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner -function set_mode!(p::QBasedPolicy, m::AbstractMode) - @warn "setting a `QBasedPolicy` to $m, the inner `explorer` will not be modified!" maxlog=1 - set_mode!(p.learner, m) -end - (π::QBasedPolicy)(env) = π(env, ActionStyle(env)) (π::QBasedPolicy)(env, ::MinimalActionSet) = get_actions(env)[π.explorer(π.learner(env))] (π::QBasedPolicy)(env, ::FullActionSet) =
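# The call methods above differ only in what the explorer sees: with a
# `MinimalActionSet` the learner's values are passed to the explorer directly,
# while the `FullActionSet` variant presumably also hands it the legal-actions
# mask (cf. `get_legal_actions_mask` pushed into the trajectory in agent.jl),
# so that masked-out actions are never selected.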