diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl
index 9d5faa4..453211d 100644
--- a/src/components/learners/tabular_learner.jl
+++ b/src/components/learners/tabular_learner.jl
@@ -10,8 +10,8 @@ struct TabularLearner{S,T} <: AbstractPolicy
 end
 
 TabularLearner() = TabularLearner{Int,Float32}()
-TabularLearner{S}() = TabularLearner{S,Float32}()
-TabularLearner{S,T}() = TabularLearner(Dict{S,Vector{T}}())
+TabularLearner{S}() where S = TabularLearner{S,Float32}()
+TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())
 
 function (p::TabularLearner)(env::AbstractEnv)
     s = get_state(env)
@@ -30,5 +30,5 @@ function (p::TabularLearner)(env::AbstractEnv)
     end
 end
 
-update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
+RLBase.update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
diff --git a/src/core/hooks.jl b/src/core/hooks.jl
index 933241c..55e1d53 100644
--- a/src/core/hooks.jl
+++ b/src/core/hooks.jl
@@ -129,19 +129,19 @@ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
     reward::Float64 = 0.0
 end
 
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env)
-    hook.reward += get_reward(env)
-end
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv)
-    hook.reward += get_reward(env.env)
-end
-
-function (hook::TotalRewardPerEpisode)(
-    ::Union{PostEpisodeStage,PostExperimentStage},
-    agent,
-    env,
-)
+(hook::TotalRewardPerEpisode)(s::AbstractStage, agent, env) = hook(s, agent, env, RewardStyle(env), NumAgentStyle(env))
+(hook::TotalRewardPerEpisode)(::AbstractStage, agent, env, ::Any, ::Any) = nothing
+
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env, get_role(agent))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env.env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env.env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env.env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env.env, get_role(agent))
+
+function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::StepReward, ::Any)
     push!(hook.rewards, hook.reward)
     hook.reward = 0
 end
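Note on the hooks.jl change above: `TotalRewardPerEpisode` no longer hard-codes when to read the reward. The three-argument entry point appends `RewardStyle(env)` and `NumAgentStyle(env)` as extra dispatch arguments, a catch-all does nothing, and the trait-specific methods decide whether to accumulate per step (`StepReward`) or record once at episode end (`TerminalReward`). A minimal sketch of how a call resolves for a single-agent, step-reward environment (`agent` and `env` are placeholder bindings, not part of the patch):

```julia
hook = TotalRewardPerEpisode()

# hook(POST_ACT_STAGE, agent, env) forwards to
# hook(POST_ACT_STAGE, agent, env, RewardStyle(env), NumAgentStyle(env)),
# which with StepReward() and SingleAgent() hits
#     hook.reward += get_reward(env)
hook(POST_ACT_STAGE, agent, env)

# At episode end the StepReward fallback records the running total:
#     push!(hook.rewards, hook.reward); hook.reward = 0
hook(POST_EPISODE_STAGE, agent, env)
```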
diff --git a/src/core/run.jl b/src/core/run.jl
index c001f9d..ca8e3b3 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -1,3 +1,5 @@
+export expected_policy_values
+
 import Base: run
 
 run(agent, env::AbstractEnv, args...) =
@@ -79,59 +81,85 @@ function run(
 )
     @assert length(agents) == get_num_players(env)
 
+    hooks = Dict(get_role(agent) => hook for (agent, hook) in zip(agents, hooks))
+    agents = Dict(get_role(agent) => agent for agent in agents)
+
     reset!(env)
-    valid_action = rand(get_actions(env)) # init with a dummy value
+
+    agent = agents[get_current_player(env)]
+    hook = hooks[get_current_player(env)]
 
-    # async here?
-    for (agent, hook) in zip(agents, hooks)
-        agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_EPISODE_STAGE, agent, env)
-        action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_ACT_STAGE, agent, env, action)
-        # for Sequential environments, only one action is valid
-        if get_current_player(env) == get_role(agent)
-            valid_action = action
-        end
+    for p in get_players(env)
+        agents[p](PRE_EPISODE_STAGE, env)
+        hooks[p](PRE_EPISODE_STAGE, agents[p], env)
     end
 
-    while true
-        env(valid_action)
+    action = agent(PRE_ACT_STAGE, env)
+    hook(PRE_ACT_STAGE, agent, env, action)
 
-        for (agent, hook) in zip(agents, hooks)
-            agent(POST_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-            hook(POST_ACT_STAGE, agent, env)
-        end
+    while true
+        env(action)
+        agent(POST_ACT_STAGE, env)
+        hook(POST_ACT_STAGE, agent, env)
 
         if get_terminal(env)
-            for (agent, hook) in zip(agents, hooks)
-                agent(POST_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(POST_EPISODE_STAGE, agent, env)
+            for p in get_players(env)
+                agents[p](POST_EPISODE_STAGE, env)
+                hooks[p](POST_EPISODE_STAGE, agents[p], env)
             end
 
-            stop_condition(agents, env) && break
+            stop_condition(agent, env) && break
+
             reset!(env)
 
-            # async here?
-            for (agent, hook) in zip(agents, hooks)
-                agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_EPISODE_STAGE, agent, env)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
+
+            for p in get_players(env)
+                agents[p](PRE_EPISODE_STAGE, env)
+                hooks[p](PRE_EPISODE_STAGE, agents[p], env)
             end
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
         else
-            stop_condition(agents, env) && break
-            for (agent, hook) in zip(agents, hooks)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
-            end
+            stop_condition(agent, env) && break
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
         end
     end
+    hooks
 end
+
+"""
+    expected_policy_values(agents, env)
+
+Calculate the expected return of each agent.
+"""
+expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv) = expected_policy_values(Dict(get_role(agent) => agent for agent in agents), env)
+
+expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(agents, env, RewardStyle(env), ChanceStyle(env), DynamicStyle(env))
+
+function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward, ::Union{ExplicitStochastic,Deterministic}, ::Sequential)
+    if get_terminal(env)
+        [get_reward(env, get_role(agent)) for agent in values(agents)]
+    elseif get_current_player(env) == get_chance_player(env)
+        vals = zeros(length(agents))
+        for a::ActionProbPair in get_legal_actions(env)
+            vals .+= a.prob .* expected_policy_values(agents, child(env, a))
+        end
+        vals
+    else
+        vals = zeros(length(agents))
+        probs = get_prob(agents[get_current_player(env)].policy, env)
+        actions = get_actions(env)
+        for (a, p) in zip(actions, probs)
+            if p > 0 #= ignore illegal action =#
+                vals .+= p .* expected_policy_values(agents, child(env, a))
+            end
+        end
+        vals
+    end
+end
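The new `expected_policy_values` walks the full game tree: terminal states return each player's reward, chance nodes weight child values by `a.prob`, and decision nodes weight child values by the acting agent's policy probabilities from `get_prob`. It is exact but exponential in depth, so it only suits small sequential environments. A hedged usage sketch (the environment constructor and agent factory below are illustrative, not part of this patch):

```julia
# Assumes a small sequential env with terminal rewards and an
# EXPLICIT_STOCHASTIC or DETERMINISTIC chance style, e.g. Kuhn poker.
env = KuhnPokerEnv()                       # hypothetical constructor
agents = Dict(
    p => make_agent(p)                     # hypothetical agent factory
    for p in get_players(env)
)

# One expected return per agent, via exhaustive tree expansion.
vals = expected_policy_values(agents, env)
```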
+""" +expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv) = expected_policy_values(Dict(get_role(agent) => agent for agent in agents), env) + +expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(agents, env, RewardStyle(env), ChanceStyle(env), DynamicStyle(env)) + +function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward, ::Union{ExplicitStochastic,Deterministic}, ::Sequential) + if get_terminal(env) + [get_reward(env, get_role(agent)) for agent in values(agents)] + elseif get_current_player(env) == get_chance_player(env) + vals = zeros(length(agents)) + for a::ActionProbPair in get_legal_actions(env) + vals .+= a.prob .* expected_policy_values(agents, child(env, a)) + end + vals + else + vals = zeros(length(agents)) + probs = get_prob(agents[get_current_player(env)].policy, env) + actions = get_actions(env) + for (a, p) in zip(actions, probs) + if p > 0 #= ignore illegal action =# + vals .+= p .* expected_policy_values(agents, child(env, a)) + end + end + vals + end +end diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index 3294940..6032434 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -5,6 +5,7 @@ using Flux using BSON using AbstractTrees +RLBase.update!(p::RandomPolicy, x) = nothing Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s)