diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 4d2cfde..817c6c9 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -31,10 +31,28 @@ function check(agent::Agent, env::AbstractEnv) check(agent.policy, env) end -##### -# update! -##### - +""" +Here we extend the definition of `(p::AbstractPolicy)(::AbstractEnv)` in +`RLBase` to accept an `AbstractStage` as the first argument. Algorithm designers +may customize these behaviors individually. The default behaviors are: +1. Update the inner `trajectory` given the context of `policy`, `env`, and + `stage`. + 1. By default we do nothing. + 2. In `PreActStage`, we `push!` the current **state** of the `env` and the + **action** generated by `policy(env)` into the `trajectory`. And the + **action** is returned. + 3. In `PostActStage`, we query the `reward` and `is_terminated` info from + `env` and push them into `trajectory`. + 4. For `CircularSARTTrajectory`: + 1. In the `PostEpisodeStage`, we push the `state` at the end of an episode + and a dummy action into the `trajectory`. + 1. In the `PreEpisodeStage`, we pop out the latest `state` and `action` + pair (which are dummy ones) from `trajectory`. +2. Update the inner `policy` given the context of `trajectory`, `env`, and + `stage`. + 1. By default, we only `update!` the `policy` in the `PreActStage`. And it's + dispatched to `update!(policy, trajectory)`. +""" function (agent::Agent)(stage::AbstractStage, env::AbstractEnv) update!(agent.trajectory, agent.policy, env, stage) update!(agent.policy, agent.trajectory, env, stage) @@ -48,11 +66,10 @@ end RLBase.update!(::AbstractPolicy, ::AbstractTrajectory, ::AbstractEnv, ::AbstractStage) = nothing + RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreActStage) = update!(p, t) -## update trajectory - RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) = nothing