This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
29 changes: 23 additions & 6 deletions src/policies/agents/agent.jl
@@ -31,10 +31,28 @@ function check(agent::Agent, env::AbstractEnv)
    check(agent.policy, env)
end

#####
# update!
#####

"""
Here we extend the definition of `(p::AbstractPolicy)(::AbstractEnv)` in
`RLBase` to accept an `AbstractStage` as the first argument. Algorithm designers
may customize the behavior for each stage. The default behaviors are:

1. Update the inner `trajectory` given the context of `policy`, `env`, and
   `stage`.
   1. By default, we do nothing.
   2. In `PreActStage`, we `push!` the current **state** of the `env` and the
      **action** generated by `policy(env)` into the `trajectory`, and
      return that **action**.
   3. In `PostActStage`, we query the `reward` and `is_terminated` info from
      `env` and push them into `trajectory`.
   4. For `CircularSARTTrajectory`:
      1. In the `PostEpisodeStage`, we push the `state` at the end of an episode
         and a dummy action into the `trajectory`.
      2. In the `PreEpisodeStage`, we pop the latest `state` and `action`
         pair (which are dummy ones) from `trajectory`.
2. Update the inner `policy` given the context of `trajectory`, `env`, and
   `stage`.
   1. By default, we only `update!` the `policy` in the `PreActStage`, which
      is dispatched to `update!(policy, trajectory)`.
"""
function (agent::Agent)(stage::AbstractStage, env::AbstractEnv)
    update!(agent.trajectory, agent.policy, env, stage)
    update!(agent.policy, agent.trajectory, env, stage)
@@ -48,11 +66,10 @@ end

RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage) =
    nothing

RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage) =
    update!(p, t)

## update trajectory

RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) =
    nothing
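
The stage-based dispatch described in the docstring can be illustrated with a minimal, self-contained sketch. Everything below is a hypothetical stand-in written for illustration: `CounterEnv`, `ToyTrajectory`, `ToyPolicy`, `ToyAgent`, `update_trajectory!`, `update_policy!`, and `run_episode!` are not part of `RLBase` or this package, and the guard against popping from an empty trajectory is an assumption. The sketch only mirrors the pattern of a do-nothing fallback plus per-stage methods.

```julia
# Hypothetical stand-ins; these are NOT the RLBase/RLCore definitions.
abstract type AbstractStage end
struct PreEpisodeStage <: AbstractStage end
struct PreActStage <: AbstractStage end
struct PostActStage <: AbstractStage end
struct PostEpisodeStage <: AbstractStage end

# Toy environment: a counter that terminates after three steps.
mutable struct CounterEnv
    t::Int
end
state(env::CounterEnv) = env.t
reward(env::CounterEnv) = 1.0
is_terminated(env::CounterEnv) = env.t >= 3
step!(env::CounterEnv, action) = (env.t += 1; nothing)

# Toy trajectory holding states, actions, rewards, and terminal flags.
struct ToyTrajectory
    states::Vector{Int}
    actions::Vector{Int}
    rewards::Vector{Float64}
    terminals::Vector{Bool}
end
ToyTrajectory() = ToyTrajectory(Int[], Int[], Float64[], Bool[])

# Toy policy: always picks action 1 and counts how often it was updated.
mutable struct ToyPolicy
    n_updates::Int
end
(p::ToyPolicy)(env::CounterEnv) = 1

struct ToyAgent
    policy::ToyPolicy
    trajectory::ToyTrajectory
end

# Fallbacks: by default, neither update does anything.
update_trajectory!(::ToyTrajectory, ::ToyPolicy, ::CounterEnv, ::AbstractStage) = nothing
update_policy!(::ToyPolicy, ::ToyTrajectory, ::CounterEnv, ::AbstractStage) = nothing

# PreActStage: record the state and the chosen action, and return the action.
function update_trajectory!(t::ToyTrajectory, p::ToyPolicy, env::CounterEnv, ::PreActStage)
    action = p(env)
    push!(t.states, state(env))
    push!(t.actions, action)
    return action
end

# PostActStage: record the reward and the termination flag.
function update_trajectory!(t::ToyTrajectory, ::ToyPolicy, env::CounterEnv, ::PostActStage)
    push!(t.rewards, reward(env))
    push!(t.terminals, is_terminated(env))
    return nothing
end

# PostEpisodeStage: record the final state with a dummy action (0 here).
function update_trajectory!(t::ToyTrajectory, ::ToyPolicy, env::CounterEnv, ::PostEpisodeStage)
    push!(t.states, state(env))
    push!(t.actions, 0)
    return nothing
end

# PreEpisodeStage: drop the dummy pair left by the previous episode, if any.
function update_trajectory!(t::ToyTrajectory, ::ToyPolicy, ::CounterEnv, ::PreEpisodeStage)
    if !isempty(t.states)  # assumption: nothing to pop before the first episode
        pop!(t.states)
        pop!(t.actions)
    end
    return nothing
end

# The policy is only updated in PreActStage; here the "update" is just a counter.
function update_policy!(p::ToyPolicy, ::ToyTrajectory, ::CounterEnv, ::PreActStage)
    p.n_updates += 1
    return nothing
end

# The agent updates the trajectory first, then the policy, for every stage.
function (agent::ToyAgent)(stage::AbstractStage, env::CounterEnv)
    result = update_trajectory!(agent.trajectory, agent.policy, env, stage)
    update_policy!(agent.policy, agent.trajectory, env, stage)
    return result  # the action in PreActStage, `nothing` otherwise
end

function run_episode!(agent::ToyAgent, env::CounterEnv)
    agent(PreEpisodeStage(), env)
    while !is_terminated(env)
        action = agent(PreActStage(), env)
        step!(env, action)
        agent(PostActStage(), env)
    end
    agent(PostEpisodeStage(), env)
    return agent
end

agent = ToyAgent(ToyPolicy(0), ToyTrajectory())
run_episode!(agent, CounterEnv(0))
```

After one episode in this sketch, `agent.trajectory.states` is `[0, 1, 2, 3]` and `agent.trajectory.actions` is `[1, 1, 1, 0]`, where the trailing `0` is the dummy action pushed at the end of the episode. Running a second episode would first pop that dummy pair in `PreEpisodeStage`.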
