diff --git a/src/core/run.jl b/src/core/run.jl
index 6d4374c..3ae0d50 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -34,7 +34,9 @@ function _run(
         hook(PRE_EPISODE_STAGE, policy, env)
 
         while !is_terminated(env) # one episode
-            action = policy(PRE_ACT_STAGE, env)
+            action = policy(env)
+
+            policy(PRE_ACT_STAGE, env, action)
             hook(PRE_ACT_STAGE, policy, env, action)
             env(action)
 
diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl
index 817c6c9..8e45d86 100644
--- a/src/policies/agents/agent.jl
+++ b/src/policies/agents/agent.jl
@@ -31,16 +31,25 @@ function check(agent::Agent, env::AbstractEnv)
     check(agent.policy, env)
 end
 
+#####
+# Default behaviors
+#####
+
 """
 Here we extend the definition of `(p::AbstractPolicy)(::AbstractEnv)` in
 `RLBase` to accept an `AbstractStage` as the first argument. Algorithm designers
-may customize these behaviors respectively. The default behaviors are:
+may customize these behaviors by implementing:
+
+- `(p::YourPolicy)(::AbstractStage, ::AbstractEnv)`
+- `(p::YourPolicy)(::PreActStage, ::AbstractEnv, action)`
+
+The default behaviors for `Agent` are:
+
 1. Update the inner `trajectory` given the context of `policy`, `env`, and
    `stage`.
   1. By default we do nothing.
-  2. In `PreActStage`, we `push!` the current **state** of the `env` and the
-     **action** generated by `policy(env)` into the `trajectory`. And the
-     **action** is returned.
+  2. In `PreActStage`, we `push!` the current **state** and the **action** into
+     the `trajectory`.
   3. In `PostActStage`, we query the `reward` and `is_terminated` info from
      `env` and push them into `trajectory`.
   4. For `CircularSARTTrajectory`:
@@ -58,39 +67,33 @@ function (agent::Agent)(stage::AbstractStage, env::AbstractEnv)
     update!(agent.policy, agent.trajectory, env, stage)
 end
 
-function (agent::Agent)(stage::PreActStage, env::AbstractEnv)
-    action = update!(agent.trajectory, agent.policy, env, stage)
+function (agent::Agent)(stage::PreActStage, env::AbstractEnv, action)
+    update!(agent.trajectory, agent.policy, env, stage, action)
     update!(agent.policy, agent.trajectory, env, stage)
-    action
 end
 
-RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage) =
-    nothing
+function RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage)
+end
 
-RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage) =
+function RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage)
     update!(p, t)
+end
 
-RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) =
-    nothing
+function RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage)
+end
 
-function RLBase.update!(
-    trajectory::Union{
-        CircularArraySARTTrajectory,
-        PrioritizedTrajectory{<:CircularArraySARTTrajectory},
-    },
-    ::AbstractPolicy,
-    ::AbstractEnv,
-    ::PreEpisodeStage,
-)
-    if length(trajectory) > 0
-        pop!(trajectory[:state])
-        pop!(trajectory[:action])
-    end
+function RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::PreActStage, action)
 end
 
+#####
+# Default behaviors for known trajectories
+#####
+
 function RLBase.update!(
     trajectory::Union{
+        CircularArraySARTTrajectory,
         CircularArraySLARTTrajectory,
+        PrioritizedTrajectory{<:CircularArraySARTTrajectory},
         PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
     },
     ::AbstractPolicy,
@@ -100,39 +103,53 @@ function RLBase.update!(
     if length(trajectory) > 0
         pop!(trajectory[:state])
         pop!(trajectory[:action])
-        pop!(trajectory[:legal_actions_mask])
+        if haskey(trajectory, :legal_actions_mask)
+            pop!(trajectory[:legal_actions_mask])
+        end
     end
 end
 
 function RLBase.update!(
     trajectory::Union{
         CircularArraySARTTrajectory,
+        CircularArraySLARTTrajectory,
         PrioritizedTrajectory{<:CircularArraySARTTrajectory},
+        PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
     },
     policy::AbstractPolicy,
     env::AbstractEnv,
-    ::Union{PreActStage,PostEpisodeStage},
+    ::PreActStage,
+    action
 )
-    action = policy(env)
     push!(trajectory[:state], state(env))
     push!(trajectory[:action], action)
-    action
+    if haskey(trajectory, :legal_actions_mask)
+        push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
+    end
 end
 
 function RLBase.update!(
     trajectory::Union{
+        CircularArraySARTTrajectory,
         CircularArraySLARTTrajectory,
+        PrioritizedTrajectory{<:CircularArraySARTTrajectory},
         PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
     },
     policy::AbstractPolicy,
    env::AbstractEnv,
-    ::Union{PreActStage,PostEpisodeStage},
+    ::PostEpisodeStage,
 )
-    action = policy(env)
+    # Note that for trajectories like `CircularArraySARTTrajectory`, data are
+    # stored in a SARSA format, which means we still need to generate a dummy
+    # action at the end of an episode. Here we simply select a random one using
+    # the global rng. In theory it shouldn't affect the performance of specific algorithms.
+    action = rand(action_space(env))
+
     push!(trajectory[:state], state(env))
     push!(trajectory[:action], action)
-    push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
-    action
+    if haskey(trajectory, :legal_actions_mask)
+        push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
+    end
 end
 
 function RLBase.update!(
diff --git a/src/policies/base.jl b/src/policies/base.jl
index c2439b8..618e66c 100644
--- a/src/policies/base.jl
+++ b/src/policies/base.jl
@@ -37,4 +37,4 @@ struct PostActStage <: AbstractStage end
 const POST_ACT_STAGE = PostActStage()
 
 (p::AbstractPolicy)(::AbstractStage, ::AbstractEnv) = nothing
-(p::AbstractPolicy)(::PreActStage, env::AbstractEnv) = p(env)
+(p::AbstractPolicy)(::AbstractStage, ::AbstractEnv, action) = nothing
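
For downstream algorithm designers, the practical consequence of this diff is that action selection (`policy(env)`) is now decoupled from the stage notification (`policy(PRE_ACT_STAGE, env, action)`), which receives the already-chosen action instead of being asked to generate one. Below is a minimal sketch of a custom policy under the new contract; it is not part of the diff. `LoggingRandomPolicy` is a hypothetical name, and the stage types and `RLBase` names are assumed to be re-exported by `ReinforcementLearningCore`:

```julia
# Sketch only: a trivial policy adapted to the new interface. Assumes
# `AbstractPolicy`, `AbstractEnv`, `PreActStage`, and `action_space` are in
# scope via ReinforcementLearningCore (which re-exports RLBase).
using ReinforcementLearningCore

struct LoggingRandomPolicy <: AbstractPolicy end

# Action selection is now a plain one-argument functor call, with no stage:
(p::LoggingRandomPolicy)(env::AbstractEnv) = rand(action_space(env))

# The three-argument stage hook merely observes the already-chosen action.
# Implementing it is optional: the `nothing` fallback added in
# src/policies/base.jl covers policies that don't.
function (p::LoggingRandomPolicy)(::PreActStage, env::AbstractEnv, action)
    @info "about to act" action
end
```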