diff --git a/src/core/run.jl b/src/core/run.jl index 39ef1f3..37a14e2 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -15,33 +15,25 @@ function _run( hook::AbstractHook = EmptyHook(), ) - reset!(env) - agent(PRE_EPISODE_STAGE, env) - hook(PRE_EPISODE_STAGE, agent, env) - action = agent(PRE_ACT_STAGE, env) - hook(PRE_ACT_STAGE, agent, env, action) + while true # run episodes forever + reset!(env) + agent(PRE_EPISODE_STAGE, env) + hook(PRE_EPISODE_STAGE, agent, env) - while true - env(action) - agent(POST_ACT_STAGE, env) - hook(POST_ACT_STAGE, agent, env) + while !get_terminal(env) # one episode + action = agent(PRE_ACT_STAGE, env) + hook(PRE_ACT_STAGE, agent, env, action) - if get_terminal(env) - agent(POST_EPISODE_STAGE, env) # let the agent see the last observation - hook(POST_EPISODE_STAGE, agent, env) + env(action) - stop_condition(agent, env) && break + agent(POST_ACT_STAGE, env) + hook(POST_ACT_STAGE, agent, env) - reset!(env) - agent(PRE_EPISODE_STAGE, env) - hook(PRE_EPISODE_STAGE, agent, env) - action = agent(PRE_ACT_STAGE, env) - hook(PRE_ACT_STAGE, agent, env, action) - else - stop_condition(agent, env) && break - action = agent(PRE_ACT_STAGE, env) - hook(PRE_ACT_STAGE, agent, env, action) - end + stop_condition(agent, env) && return hook # early stop + end # end of an episode + + agent(POST_EPISODE_STAGE, env) # let the agent see the last observation + hook(POST_EPISODE_STAGE, agent, env) end hook end @@ -84,52 +76,34 @@ function _run( hooks = Dict(get_role(agent) => hook for (agent, hook) in zip(agents, hooks)) agents = Dict(get_role(agent) => agent for agent in agents) - reset!(env) - - agent = agents[get_current_player(env)] - hook = hooks[get_current_player(env)] - for p in get_players(env) - agents[p](PRE_EPISODE_STAGE, env) - hooks[p](PRE_EPISODE_STAGE, agents[p], env) - end - - action = agent(PRE_ACT_STAGE, env) - hook(PRE_ACT_STAGE, agent, env, action) - - while true - env(action) - agent(POST_ACT_STAGE, env) - hook(POST_ACT_STAGE, agent, env) - - if get_terminal(env) - for p in get_players(env) - agents[p](POST_EPISODE_STAGE, env) - hooks[p](POST_EPISODE_STAGE, agents[p], env) - end - - stop_condition(agent, env) && break - - reset!(env) + while true # run episodes forever + reset!(env) - for p in get_players(env) - agents[p](PRE_EPISODE_STAGE, env) - hooks[p](PRE_EPISODE_STAGE, agents[p], env) - end + for p in get_players(env) + agents[p](PRE_EPISODE_STAGE, env) + hooks[p](PRE_EPISODE_STAGE, agents[p], env) + end + while !get_terminal(env) # one episode agent = agents[get_current_player(env)] hook = hooks[get_current_player(env)] - action = agent(PRE_ACT_STAGE, env) - hook(PRE_ACT_STAGE, agent, env, action) - else - stop_condition(agent, env) && break - agent = agents[get_current_player(env)] - hook = hooks[get_current_player(env)] action = agent(PRE_ACT_STAGE, env) hook(PRE_ACT_STAGE, agent, env, action) + + env(action) + + agent(POST_ACT_STAGE, env) + hook(POST_ACT_STAGE, agent, env) + + stop_condition(agent, env) && return hooks # early stop + end # end of an episode + + for p in get_players(env) + agents[p](POST_EPISODE_STAGE, env) + hooks[p](POST_EPISODE_STAGE, agents[p], env) end end - hooks end