This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
6 changes: 3 additions & 3 deletions src/components/learners/tabular_learner.jl
@@ -10,8 +10,8 @@ struct TabularLearner{S,T} <: AbstractPolicy
end

TabularLearner() = TabularLearner{Int,Float32}()
-TabularLearner{S}() = TabularLearner{S,Float32}()
-TabularLearner{S,T}() = TabularLearner(Dict{S,Vector{T}}())
+TabularLearner{S}() where S = TabularLearner{S,Float32}()
+TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())

function (p::TabularLearner)(env::AbstractEnv)
    s = get_state(env)
@@ -30,5 +30,5 @@ function (p::TabularLearner)(env::AbstractEnv)
    end
end

-update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
+RLBase.update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
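Note (illustrative, not part of the diff): the added `where` clauses are what make these zero-argument parametric constructors well formed; without them, `S` and `T` are unbound on the left-hand side and the definitions fail with an UndefVarError when the file is loaded. A minimal sketch of the same pattern with a hypothetical stand-in type:

```julia
# Hypothetical stand-in type; TabularLearner itself lives in this package.
struct Learner{S,T}
    table::Dict{S,Vector{T}}
end

Learner() = Learner{Int,Float32}()
Learner{S}() where S = Learner{S,Float32}()               # `where S` binds the type parameter
Learner{S,T}() where {S,T} = Learner(Dict{S,Vector{T}}())

Learner{String}()          # Learner{String,Float32} with an empty table
Learner{String,Float64}()  # Learner{String,Float64}
```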

26 changes: 13 additions & 13 deletions src/core/hooks.jl
@@ -129,19 +129,19 @@ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
    reward::Float64 = 0.0
end

-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env)
-    hook.reward += get_reward(env)
-end
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv)
-    hook.reward += get_reward(env.env)
-end
-
-function (hook::TotalRewardPerEpisode)(
-    ::Union{PostEpisodeStage,PostExperimentStage},
-    agent,
-    env,
-)
+(hook::TotalRewardPerEpisode)(s::AbstractStage, agent, env) = hook(s, agent, env, RewardStyle(env), NumAgentStyle(env))
+(hook::TotalRewardPerEpisode)(::AbstractStage, agent, env, ::Any, ::Any) = nothing
+
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env, get_role(agent))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env.env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env.env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env.env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env.env, get_role(agent))
+
+function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::StepReward, ::Any)
    push!(hook.rewards, hook.reward)
    hook.reward = 0
end
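Aside (illustrative, not part of the diff): the rewrite replaces the per-stage methods with a single entry point that re-dispatches on `RewardStyle(env)` and `NumAgentStyle(env)`; the `::Any, ::Any` method is a no-op fallback for stage/trait combinations that have no specialized handler. A self-contained sketch of that dispatch pattern, with all names hypothetical:

```julia
# Hypothetical traits, env, and hook; they stand in for RewardStyle/NumAgentStyle
# and TotalRewardPerEpisode, only to illustrate the trait-based re-dispatch.
abstract type AbstractRewardStyle end
struct TerminalRewardStyle <: AbstractRewardStyle end
struct StepRewardStyle <: AbstractRewardStyle end

struct ToyEnv
    reward::Float64
end
reward_style(::ToyEnv) = StepRewardStyle()   # the real code asks RewardStyle(env)

mutable struct TotalReward
    rewards::Vector{Float64}
    reward::Float64
end
TotalReward() = TotalReward(Float64[], 0.0)

# Single entry point: look up the environment's trait, then re-dispatch on it.
(h::TotalReward)(env) = h(env, reward_style(env))
(h::TotalReward)(env, ::AbstractRewardStyle) = nothing                        # fallback: no-op
(h::TotalReward)(env, ::StepRewardStyle) = (h.reward += env.reward)           # accumulate per step
(h::TotalReward)(env, ::TerminalRewardStyle) = push!(h.rewards, env.reward)   # record at episode end

h = TotalReward()
h(ToyEnv(1.0))   # dispatches to the StepRewardStyle method
h.reward         # 1.0
```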
108 changes: 68 additions & 40 deletions src/core/run.jl
@@ -1,3 +1,5 @@
+export expected_policy_values

import Base: run

run(agent, env::AbstractEnv, args...) =
@@ -79,59 +81,85 @@ function run(
)
    @assert length(agents) == get_num_players(env)

+    hooks = Dict(get_role(agent) => hook for (agent, hook) in zip(agents, hooks))
+    agents = Dict(get_role(agent) => agent for agent in agents)
    reset!(env)
-    valid_action = rand(get_actions(env)) # init with a dummy value

+    agent = agents[get_current_player(env)]
+    hook = hooks[get_current_player(env)]

-    # async here?
-    for (agent, hook) in zip(agents, hooks)
-        agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_EPISODE_STAGE, agent, env)
-        action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_ACT_STAGE, agent, env, action)
-        # for Sequential environments, only one action is valid
-        if get_current_player(env) == get_role(agent)
-            valid_action = action
-        end
+    for p in get_players(env)
+        agents[p](PRE_EPISODE_STAGE, env)
+        hooks[p](PRE_EPISODE_STAGE, agents[p], env)
    end

-    while true
-        env(valid_action)
+    action = agent(PRE_ACT_STAGE, env)
+    hook(PRE_ACT_STAGE, agent, env, action)

-        for (agent, hook) in zip(agents, hooks)
-            agent(POST_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-            hook(POST_ACT_STAGE, agent, env)
-        end
+    while true
+        env(action)
+        agent(POST_ACT_STAGE, env)
+        hook(POST_ACT_STAGE, agent, env)

        if get_terminal(env)
-            for (agent, hook) in zip(agents, hooks)
-                agent(POST_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(POST_EPISODE_STAGE, agent, env)
+            for p in get_players(env)
+                agents[p](POST_EPISODE_STAGE, env)
+                hooks[p](POST_EPISODE_STAGE, agents[p], env)
            end

-            stop_condition(agents, env) && break
+            stop_condition(agent, env) && break

            reset!(env)
-            # async here?
-            for (agent, hook) in zip(agents, hooks)
-                agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_EPISODE_STAGE, agent, env)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
+
+            for p in get_players(env)
+                agents[p](PRE_EPISODE_STAGE, env)
+                hooks[p](PRE_EPISODE_STAGE, agents[p], env)
            end
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
        else
-            stop_condition(agents, env) && break
-            for (agent, hook) in zip(agents, hooks)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
-            end
+            stop_condition(agent, env) && break
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
        end
    end

    hooks
end

"""
expected_policy_values(agents, env)

Calculate the expected return of each agent.
"""
expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv) = expected_policy_values(Dict(get_role(agent) => agent for agent in agents), env)

expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(agents, env, RewardStyle(env), ChanceStyle(env), DynamicStyle(env))

function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward, ::Union{ExplicitStochastic,Deterministic}, ::Sequential)
if get_terminal(env)
[get_reward(env, get_role(agent)) for agent in values(agents)]
elseif get_current_player(env) == get_chance_player(env)
vals = zeros(length(agents))
for a::ActionProbPair in get_legal_actions(env)
vals .+= a.prob .* expected_policy_values(agents, child(env, a))
end
vals
else
vals = zeros(length(agents))
probs = get_prob(agents[get_current_player(env)].policy, env)
actions = get_actions(env)
for (a, p) in zip(actions, probs)
if p > 0 #= ignore illegal action =#
vals .+= p .* expected_policy_values(agents, child(env, a))
end
end
vals
end
end
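Aside (illustrative, not part of the diff): `expected_policy_values` recurses over the game tree, returning each player's reward at a terminal state, weighting children by `a.prob` at a chance node, and weighting children by the current agent's `get_prob` policy probabilities otherwise. A hedged usage sketch; the environment and agent constructors below are placeholders, not types defined in this repository:

```julia
# Hypothetical usage; `TicTacToeLikeEnv` and `MyAgent` stand in for concrete types
# that satisfy the interface used above: get_role on the agent, a `policy` field
# supporting get_prob, and child / get_legal_actions on a Sequential, terminal-reward env.
env = TicTacToeLikeEnv()
agents = (MyAgent(role = 1), MyAgent(role = 2))

vals = expected_policy_values(agents, env)
# One expected terminal reward per agent (in the iteration order of the agents
# collection), averaging over chance probabilities and policy action probabilities.
```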
1 change: 1 addition & 0 deletions src/extensions/ReinforcementLearningBase.jl
@@ -5,6 +5,7 @@ using Flux
using BSON
using AbstractTrees

+RLBase.update!(p::RandomPolicy, x) = nothing

Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s)
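Aside (illustrative, not part of the diff): making `update!` a no-op for `RandomPolicy` presumably lets generic training code call `update!` on every policy uniformly, without special-casing untrainable ones. A self-contained sketch of that idea with hypothetical types:

```julia
# Hypothetical types; RLBase.update! and RandomPolicy are the real names in the diff above.
abstract type AbstractPolicy end
struct NoLearnPolicy <: AbstractPolicy end        # plays a role similar to RandomPolicy
mutable struct CountingPolicy <: AbstractPolicy
    n::Int
end

update!(::NoLearnPolicy, experience) = nothing    # no-op: nothing to learn
update!(p::CountingPolicy, experience) = (p.n += 1)

policies = (NoLearnPolicy(), CountingPolicy(0))
for p in policies
    update!(p, :dummy_experience)   # uniform call site, no special-casing needed
end
```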
