Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/core/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ function _run(
hook(PRE_EPISODE_STAGE, policy, env)

while !is_terminated(env) # one episode
action = policy(PRE_ACT_STAGE, env)
action = policy(env)

policy(PRE_ACT_STAGE, env, action)
hook(PRE_ACT_STAGE, policy, env, action)

env(action)
Expand Down
83 changes: 50 additions & 33 deletions src/policies/agents/agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@ function check(agent::Agent, env::AbstractEnv)
check(agent.policy, env)
end

#####
# Default behaviors
#####

"""
Here we extend the definition of `(p::AbstractPolicy)(::AbstractEnv)` in
`RLBase` to accept an `AbstractStage` as the first argument. Algorithm designers
may customize these behaviors respectively. The default behaviors are:
may customize these behaviors respectively by implementing:

- `(p::YourPolicy)(::AbstractStage, ::AbstractEnv)`
- `(p::YourPolicy)(::PreActStage, ::AbstractEnv, action)`

The default behaviors for `Agent` are:

1. Update the inner `trajectory` given the context of `policy`, `env`, and
`stage`.
1. By default we do nothing.
2. In `PreActStage`, we `push!` the current **state** of the `env` and the
**action** generated by `policy(env)` into the `trajectory`. And the
**action** is returned.
2. In `PreActStage`, we `push!` the current **state** and the **action** into
the `trajectory`.
3. In `PostActStage`, we query the `reward` and `is_terminated` info from
`env` and push them into `trajectory`.
4. For `CircularSARTTrajectory`:
Expand All @@ -58,39 +67,33 @@ function (agent::Agent)(stage::AbstractStage, env::AbstractEnv)
update!(agent.policy, agent.trajectory, env, stage)
end

function (agent::Agent)(stage::PreActStage, env::AbstractEnv)
action = update!(agent.trajectory, agent.policy, env, stage)
function (agent::Agent)(stage::PreActStage, env::AbstractEnv, action)
update!(agent.trajectory, agent.policy, env, stage, action)
update!(agent.policy, agent.trajectory, env, stage)
action
end

RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage) =
nothing
function RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage)
end

RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage) =
function RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage)
update!(p, t)
end

RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) =
nothing
function RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage)
end

function RLBase.update!(
trajectory::Union{
CircularArraySARTTrajectory,
PrioritizedTrajectory{<:CircularArraySARTTrajectory},
},
::AbstractPolicy,
::AbstractEnv,
::PreEpisodeStage,
)
if length(trajectory) > 0
pop!(trajectory[:state])
pop!(trajectory[:action])
end
function RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::PreActStage, action)
end

#####
# Default behaviors for known trajectories
#####

function RLBase.update!(
trajectory::Union{
CircularArraySARTTrajectory,
CircularArraySLARTTrajectory,
PrioritizedTrajectory{<:CircularArraySARTTrajectory},
PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
},
::AbstractPolicy,
Expand All @@ -100,39 +103,53 @@ function RLBase.update!(
if length(trajectory) > 0
pop!(trajectory[:state])
pop!(trajectory[:action])
pop!(trajectory[:legal_actions_mask])
if haskey(trajectory, :legal_actions_mask)
pop!(trajectory[:legal_actions_mask])
end
end
end

function RLBase.update!(
trajectory::Union{
CircularArraySARTTrajectory,
CircularArraySLARTTrajectory,
PrioritizedTrajectory{<:CircularArraySARTTrajectory},
PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
},
policy::AbstractPolicy,
env::AbstractEnv,
::Union{PreActStage,PostEpisodeStage},
::PreActStage,
action
)
action = policy(env)
push!(trajectory[:state], state(env))
push!(trajectory[:action], action)
action
if haskey(trajectory, :legal_actions_mask)
push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
end
end

function RLBase.update!(
trajectory::Union{
CircularArraySARTTrajectory,
CircularArraySLARTTrajectory,
PrioritizedTrajectory{<:CircularArraySARTTrajectory},
PrioritizedTrajectory{<:CircularArraySLARTTrajectory},
},
policy::AbstractPolicy,
env::AbstractEnv,
::Union{PreActStage,PostEpisodeStage},
::PostEpisodeStage,
)
action = policy(env)
# Note that for trajectories like `CircularArraySARTTrajectory`, data are
# stored in a SARSA format, which means we still need to generate a dummy
# action at the end of an episode. Here we simply select a random one using
# the global rng. In theory it shouldn't affect the performance of specific algorithm.
action = rand(action_space(env))

push!(trajectory[:state], state(env))
push!(trajectory[:action], action)
push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
action
if haskey(trajectory, :legal_actions_mask)
push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
end
end

function RLBase.update!(
Expand Down
2 changes: 1 addition & 1 deletion src/policies/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ struct PostActStage <: AbstractStage end
const POST_ACT_STAGE = PostActStage()

(p::AbstractPolicy)(::AbstractStage, ::AbstractEnv) = nothing
(p::AbstractPolicy)(::PreActStage, env::AbstractEnv) = p(env)
(p::AbstractPolicy)(::AbstractStage, ::AbstractEnv, action) = nothing