From 1823170124ca54c5834b4417da95dac38fa2e564 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Tue, 25 Aug 2020 12:33:00 +0800
Subject: [PATCH 1/6] sync

---
 src/core/run.jl | 75 +++++++++++++++++++++++--------------------------
 1 file changed, 35 insertions(+), 40 deletions(-)

diff --git a/src/core/run.jl b/src/core/run.jl
index c001f9d..897ffbd 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -79,59 +79,54 @@ function run(
 )
     @assert length(agents) == get_num_players(env)
 
+    hooks = Dict(get_role(agent) => hook for (agent, hook) in zip(agents, hooks))
+    agents = Dict(get_role(agent) => agent for agent in agents)
     reset!(env)
-    valid_action = rand(get_actions(env)) # init with a dummy value
+
+    agent = agents[get_current_player(env)]
+    hook = hooks[get_current_player(env)]
 
-    # async here?
-    for (agent, hook) in zip(agents, hooks)
-        agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_EPISODE_STAGE, agent, env)
-        action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-        hook(PRE_ACT_STAGE, agent, env, action)
-        # for Sequential environments, only one action is valid
-        if get_current_player(env) == get_role(agent)
-            valid_action = action
-        end
+    for (A,H) in zip(values(agents), values(hooks))
+        A(PRE_EPISODE_STAGE, env)
+        H(PRE_EPISODE_STAGE, A, env)
     end
 
-    while true
-        env(valid_action)
+    action = agent(PRE_ACT_STAGE, env)
+    hook(PRE_ACT_STAGE, agent, env, action)
 
-        for (agent, hook) in zip(agents, hooks)
-            agent(POST_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-            hook(POST_ACT_STAGE, agent, env)
-        end
+    while true
+        env(action)
+        agent(POST_ACT_STAGE, env)
+        hook(POST_ACT_STAGE, agent, env)
 
         if get_terminal(env)
-            for (agent, hook) in zip(agents, hooks)
-                agent(POST_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(POST_EPISODE_STAGE, agent, env)
+            for (A,H) in zip(values(agents), values(hooks))
+                A(POST_EPISODE_STAGE, env)
+                H(POST_EPISODE_STAGE, A, env)
             end
 
-            stop_condition(agents, env) && break
+            stop_condition(agent, env) && break
+
             reset!(env)
 
-            # async here?
-            for (agent, hook) in zip(agents, hooks)
-                agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_EPISODE_STAGE, agent, env)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
+            for (A,H) in zip(values(agents), values(hooks))
+                A(PRE_EPISODE_STAGE, env)
+                H(PRE_EPISODE_STAGE, A, env)
             end
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
         else
-            stop_condition(agents, env) && break
-            for (agent, hook) in zip(agents, hooks)
-                action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent)))
-                hook(PRE_ACT_STAGE, agent, env, action)
-                # for Sequential environments, only one action is valid
-                if get_current_player(env) == get_role(agent)
-                    valid_action = action
-                end
-            end
+            stop_condition(agent, env) && break
+
+            agent = agents[get_current_player(env)]
+            hook = hooks[get_current_player(env)]
+            action = agent(PRE_ACT_STAGE, env)
+            hook(PRE_ACT_STAGE, agent, env, action)
         end
     end
+
     hooks
 end
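Review note: the rewrite above replaces the broadcast-a-dummy-action loop with Dict lookups keyed by role, so in a Sequential environment only the current player is asked for an action. A minimal sketch of that control flow, assuming the RLBase-style API already used in the diff (get_current_player, get_terminal, reset!, and agents that return an action when called at PRE_ACT_STAGE); run_sequential! is an illustrative name, not part of the patch:

    # Minimal sketch: only the player to move is queried each step.
    function run_sequential!(agents::Dict, env)
        reset!(env)
        while !get_terminal(env)
            agent = agents[get_current_player(env)]  # role => agent lookup
            action = agent(PRE_ACT_STAGE, env)       # ask only this agent
            env(action)                              # advance the environment
        end
    end
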
From 2467ba7271906f0c8f1039ebd196e1ee90139fa2 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Wed, 26 Aug 2020 15:29:35 +0800
Subject: [PATCH 2/6] support MultiAgent with hooks

---
 src/core/hooks.jl | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/core/hooks.jl b/src/core/hooks.jl
index 933241c..8003c89 100644
--- a/src/core/hooks.jl
+++ b/src/core/hooks.jl
@@ -129,14 +129,24 @@ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
     reward::Float64 = 0.0
 end
 
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env)
+(hook::TotalRewardPerEpisode)(s::PostActStage, agent, env) = hook(s, agent, env, NumAgentStyle(env))
+
+function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::SingleAgent)
     hook.reward += get_reward(env)
 end
 
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv)
+function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::SingleAgent)
     hook.reward += get_reward(env.env)
 end
 
+function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::MultiAgent)
+    hook.reward += get_reward(env, get_role(agent))
+end
+
+function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::MultiAgent)
+    hook.reward += get_reward(env.env, get_role(agent))
+end
+
 function (hook::TotalRewardPerEpisode)(
     ::Union{PostEpisodeStage,PostExperimentStage},
     agent,
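Review note: the hook now forwards to a method chosen by the environment's NumAgentStyle trait, so single- and multi-agent environments accumulate reward differently (the latter via get_reward(env, get_role(agent))). A self-contained sketch of this Holy-trait dispatch pattern; all names below (NumAgents, One, Many, SoloEnv, TeamEnv, total_reward) are illustrative, not part of the library:

    abstract type NumAgents end
    struct One <: NumAgents end
    struct Many <: NumAgents end

    struct SoloEnv end
    struct TeamEnv end
    NumAgents(::SoloEnv) = One()   # trait value chosen per environment type
    NumAgents(::TeamEnv) = Many()

    total_reward(env) = total_reward(env, NumAgents(env))  # forward on the trait
    total_reward(env, ::One) = "accumulate the scalar reward"
    total_reward(env, ::Many) = "accumulate this player's reward"
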
From 1d421c8e28a9932567d52a57c2bc0ba5f346a018 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Thu, 27 Aug 2020 22:56:06 +0800
Subject: [PATCH 3/6] fix CFR

---
 src/core/hooks.jl | 36 +++++++++++------------------
 src/core/run.jl   | 59 +++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 63 insertions(+), 32 deletions(-)

diff --git a/src/core/hooks.jl b/src/core/hooks.jl
index 8003c89..55e1d53 100644
--- a/src/core/hooks.jl
+++ b/src/core/hooks.jl
@@ -129,29 +129,19 @@ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
     reward::Float64 = 0.0
 end
 
-(hook::TotalRewardPerEpisode)(s::PostActStage, agent, env) = hook(s, agent, env, NumAgentStyle(env))
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::SingleAgent)
-    hook.reward += get_reward(env)
-end
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::SingleAgent)
-    hook.reward += get_reward(env.env)
-end
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::MultiAgent)
-    hook.reward += get_reward(env, get_role(agent))
-end
-
-function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::MultiAgent)
-    hook.reward += get_reward(env.env, get_role(agent))
-end
-
-function (hook::TotalRewardPerEpisode)(
-    ::Union{PostEpisodeStage,PostExperimentStage},
-    agent,
-    env,
-)
+(hook::TotalRewardPerEpisode)(s::AbstractStage, agent, env) = hook(s, agent, env, RewardStyle(env), NumAgentStyle(env))
+(hook::TotalRewardPerEpisode)(::AbstractStage, agent, env, ::Any, ::Any) = nothing
+
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env, get_role(agent))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::SingleAgent) = push!(hook.rewards, get_reward(env.env))
+(hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env::RewardOverriddenEnv, ::TerminalReward, ::MultiAgent) = push!(hook.rewards, get_reward(env.env, get_role(agent)))
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::SingleAgent) = hook.reward += get_reward(env.env)
+(hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv, ::StepReward, ::MultiAgent) = hook.reward += get_reward(env.env, get_role(agent))
+
+function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::StepReward, ::Any)
     push!(hook.rewards, hook.reward)
     hook.reward = 0
 end
diff --git a/src/core/run.jl b/src/core/run.jl
index 897ffbd..2ea68d1 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -1,3 +1,5 @@
+export expected_policy_values
+
 import Base: run
 
 run(agent, env::AbstractEnv, args...) =
@@ -86,9 +88,9 @@ function run(
     agent = agents[get_current_player(env)]
     hook = hooks[get_current_player(env)]
 
-    for (A,H) in zip(values(agents), values(hooks))
-        A(PRE_EPISODE_STAGE, env)
-        H(PRE_EPISODE_STAGE, A, env)
+    for p in get_players(env)
+        agents[p](PRE_EPISODE_STAGE, env)
+        hooks[p](PRE_EPISODE_STAGE, agents[p], env)
     end
 
     action = agent(PRE_ACT_STAGE, env)
@@ -100,18 +102,18 @@ function run(
         hook(POST_ACT_STAGE, agent, env)
 
         if get_terminal(env)
-            for (A,H) in zip(values(agents), values(hooks))
-                A(POST_EPISODE_STAGE, env)
-                H(POST_EPISODE_STAGE, A, env)
+            for p in get_players(env)
+                agents[p](POST_EPISODE_STAGE, env)
+                hooks[p](POST_EPISODE_STAGE, agents[p], env)
             end
 
             stop_condition(agent, env) && break
 
             reset!(env)
 
-            for (A,H) in zip(values(agents), values(hooks))
-                A(PRE_EPISODE_STAGE, env)
-                H(PRE_EPISODE_STAGE, A, env)
+            for p in get_players(env)
+                agents[p](PRE_EPISODE_STAGE, env)
+                hooks[p](PRE_EPISODE_STAGE, agents[p], env)
            end
 
             agent = agents[get_current_player(env)]
@@ -130,3 +132,42 @@ function run(
 
     hooks
 end
+
+"""
+    expected_policy_values(agents, env)
+
+Calculate the expected return of each agent.
+"""
+expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv) = expected_policy_values(Dict(get_role(agent) => agent for agent in agents), env)
+
+expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(agents, env, RewardStyle(env), ChanceStyle(env), DynamicStyle(env))
+
+function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward, ::Union{ExplicitStochastic,Deterministic}, ::Sequential)
+    if get_terminal(env)
+        [get_reward(env, get_role(agent)) for agent in values(agents)]
+    elseif get_current_player(env) == get_chance_player(env)
+        vals = zeros(length(agents))
+        for a::ActionProbPair in get_legal_actions(env)
+            vals .+= a.prob .* expected_policy_values(agents, child(env, a))
+        end
+        vals
+    else
+        vals = zeros(length(agents))
+        probs = get_prob(agents[get_current_player(env)].policy, env)
+        actions = get_actions(env)
+        for (a, p) in zip(actions, probs)
+            if p > 0 #= ignore illegal action =#
+                vals .+= p .* expected_policy_values(agents, child(env, a))
+            end
+        end
+        vals
+    end
+end
+
+function test_f(ps, env)
+    reset!(env)
+    while !get_terminal(env)
+        env |> ps[get_current_player(env)] |> env
+    end
+    [get_reward(env, p) for p in get_players(env) if p != get_chance_player(env)]
+end
\ No newline at end of file
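Review note: expected_policy_values computes a full-tree expectation. At a terminal node it returns each agent's reward; at a chance node it weights children by the outcome probability a.prob; at a decision node it weights children by the acting agent's policy, i.e. V(s) = sum over a of pi(a | s) * V(child(s, a)). The test_f helper added here is deleted again in PATCH 4. A hedged usage sketch, assuming two already-constructed AbstractAgents and a sequential environment with explicit chance nodes (variable names are illustrative):

    agents = (agent1, agent2)                  # e.g. tabular CFR agents built elsewhere
    vals = expected_policy_values(agents, env) # one expected return per agent
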
From 67b85cbf5b6ca89949b5532927e0efa9813b84e9 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Tue, 1 Sep 2020 11:54:06 +0800
Subject: [PATCH 4/6] remove dirty code

---
 src/core/run.jl | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/core/run.jl b/src/core/run.jl
index 2ea68d1..ca8e3b3 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -163,11 +163,3 @@ function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward
         vals
     end
 end
-
-function test_f(ps, env)
-    reset!(env)
-    while !get_terminal(env)
-        env |> ps[get_current_player(env)] |> env
-    end
-    [get_reward(env, p) for p in get_players(env) if p != get_chance_player(env)]
-end
\ No newline at end of file

From af9047bac9b2057e81b08fd9e6d152967a0446ad Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Tue, 1 Sep 2020 12:14:33 +0800
Subject: [PATCH 5/6] fix bug in TabularLearner

---
 src/components/learners/tabular_learner.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl
index 9d5faa4..3b166ce 100644
--- a/src/components/learners/tabular_learner.jl
+++ b/src/components/learners/tabular_learner.jl
@@ -10,8 +10,8 @@ struct TabularLearner{S,T} <: AbstractPolicy
 end
 
 TabularLearner() = TabularLearner{Int,Float32}()
-TabularLearner{S}() = TabularLearner{S,Float32}()
-TabularLearner{S,T}() = TabularLearner(Dict{S,Vector{T}}())
+TabularLearner{S}() where S = TabularLearner{S,Float32}()
+TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())
 
 function (p::TabularLearner)(env::AbstractEnv)
     s = get_state(env)
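Review note: the TabularLearner fix in PATCH 5 is about parametric constructors. Without a where clause, the left-hand side TabularLearner{S}() tries to look up S as a plain global name, so evaluating the definition throws UndefVarError; the where clause is what binds the type parameter. A self-contained reproduction with an illustrative Box type:

    struct Box{T}
        x::Vector{T}
    end
    Box{T}() where {T} = Box(T[])  # correct: T is bound by the where clause
    # Box{T}() = Box(T[])          # fails: `T` is resolved as an undefined global

    Box{Int}()                     # Box{Int64}(Int64[])
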
From ff312201fd515d7489d9740582871188108d9c51 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Tue, 1 Sep 2020 12:52:22 +0800
Subject: [PATCH 6/6] bugfix

---
 src/components/learners/tabular_learner.jl  | 2 +-
 src/extensions/ReinforcementLearningBase.jl | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl
index 3b166ce..453211d 100644
--- a/src/components/learners/tabular_learner.jl
+++ b/src/components/learners/tabular_learner.jl
@@ -30,5 +30,5 @@ function (p::TabularLearner)(env::AbstractEnv)
     end
 end
 
-update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
+RLBase.update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl
index 3294940..6032434 100644
--- a/src/extensions/ReinforcementLearningBase.jl
+++ b/src/extensions/ReinforcementLearningBase.jl
@@ -5,6 +5,7 @@ using Flux
 using BSON
 using AbstractTrees
 
+RLBase.update!(p::RandomPolicy, x) = nothing
 Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s)
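Review note: PATCH 6 qualifies update! as RLBase.update! because an unqualified definition creates a brand-new function in the current module instead of adding a method to the RLBase generic, so call sites that invoke RLBase.update! never dispatch to it; the added no-op for RandomPolicy likewise lets generic training code call update! on any policy. A self-contained illustration with toy modules (A, B, and greet are illustrative names):

    module A
    greet(x) = "hello, $x"
    end

    module B
    import ..A
    greet(x::Int) = "int!"       # defines B.greet; A.greet is untouched
    A.greet(x::Symbol) = "sym!"  # adds a method to A.greet
    end

    A.greet(:ok)  # "sym!"     -- the qualified method extends the generic
    A.greet(1)    # "hello, 1" -- B.greet(::Int) is invisible to A.greet callers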