From 17569c88a44322cd3550d4803351e657d70d3a2a Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Fri, 9 Oct 2020 01:01:18 +0800
Subject: [PATCH] improve tabular learner

---
 Project.toml                               |  2 +-
 src/components/learners/tabular_learner.jl | 28 ++++++-------
 src/components/policies/Q_based_policy.jl  | 32 ++++++++++++++-
 src/core/run.jl                            | 47 ----------------------
 src/utils/printing.jl                      |  2 +-
 5 files changed, 47 insertions(+), 64 deletions(-)

diff --git a/Project.toml b/Project.toml
index a5a29af..684957a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,7 @@ AbstractTrees = "0.3"
 Adapt = "2"
 BSON = "0.2"
 CUDA = "1"
-Distributions = "0.22, 0.23, 0.24"
+Distributions = "0.24"
 ElasticArrays = "1.2"
 FillArrays = "0.8, 0.9"
 Flux = "0.11.1"
diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl
index f38dfcc..f7d1aa3 100644
--- a/src/components/learners/tabular_learner.jl
+++ b/src/components/learners/tabular_learner.jl
@@ -13,20 +13,20 @@ TabularLearner() = TabularLearner{Int,Float32}()
 TabularLearner{S}() where {S} = TabularLearner{S,Float32}()
 TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())
 
-function (p::TabularLearner)(env::AbstractEnv)
-    s = get_state(env)
-    if haskey(p.table, s)
-        p.table[s]
-    elseif ActionStyle(env) === FULL_ACTION_SET
-        mask = get_legal_actions_mask(env)
-        prob = mask ./ sum(mask)
-        p.table[s] = prob
-        prob
-    elseif ActionStyle(env) === MINIMAL_ACTION_SET
-        n = length(get_actions(env))
-        prob = fill(1 / n, n)
-        p.table[s] = prob
-        prob
+(p::TabularLearner)(env::AbstractEnv) = p(ChanceStyle(env), env)
+
+function (p::TabularLearner)(::ExplicitStochastic, env::AbstractEnv)
+    if get_current_player(env) == get_chance_player(env)
+        [a.prob for a::ActionProbPair in get_actions(env)]
+    else
+        p(DETERMINISTIC, env) # treat it just like a normal one
+    end
+end
+
+function (p::TabularLearner)(::RLBase.AbstractChanceStyle, env::AbstractEnv)
+    get!(p.table, get_state(env)) do
+        n = length(get_legal_actions(env))
+        fill(1/n, n)
     end
 end
 
diff --git a/src/components/policies/Q_based_policy.jl b/src/components/policies/Q_based_policy.jl
index 1770842..c6753ea 100644
--- a/src/components/policies/Q_based_policy.jl
+++ b/src/components/policies/Q_based_policy.jl
@@ -1,4 +1,4 @@
-export QBasedPolicy
+export QBasedPolicy, TabularRandomPolicy
 
 using MacroTools: @forward
 using Flux
@@ -35,3 +35,33 @@ function Flux.testmode!(p::QBasedPolicy, mode = true)
     testmode!(p.learner, mode)
     testmode!(p.explorer, mode)
 end
+
+#####
+# TabularRandomPolicy
+#####
+
+const TabularRandomPolicy = QBasedPolicy{<:TabularLearner, <:WeightedExplorer}
+
+function TabularRandomPolicy(;rng=Random.GLOBAL_RNG, is_normalized=true, table=Dict{String,Vector{Float64}}())
+    QBasedPolicy(;
+        learner = TabularLearner(table),
+        explorer = WeightedExplorer(; is_normalized = is_normalized, rng = rng),
+    )
+end
+
+function (p::TabularRandomPolicy)(env::AbstractEnv)
+    if ChanceStyle(env) === EXPLICIT_STOCHASTIC
+        if get_current_player(env) == get_chance_player(env)
+            # this should be faster. we don't need to allocate memory to store the probability of chance node
+            return rand(p.explorer.rng, get_actions(env))
+        end
+    end
+    p(env, ActionStyle(env)) # fall back to general implementation above
+end
+
+function RLBase.get_prob(p::TabularRandomPolicy, env, ::FullActionSet)
+    m = get_legal_actions_mask(env)
+    prob = zeros(length(m))
+    prob[m] .= p.learner(env)
+    prob
+end
\ No newline at end of file
diff --git a/src/core/run.jl b/src/core/run.jl
index 6cff631..ea46452 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -132,50 +132,3 @@ function run(
 
     hooks
 end
-
-"""
-    expected_policy_values(agents, env)
-
-Calculate the expected return of each agent.
-"""
-function expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv)
-    agents = Dict(get_role(agent) => agent for agent in agents)
-    values = expected_policy_values(agents, env)
-    Dict(zip(get_players(env), values))
-end
-
-expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(
-    agents,
-    env,
-    RewardStyle(env),
-    ChanceStyle(env),
-    DynamicStyle(env),
-)
-
-function expected_policy_values(
-    agents::Dict,
-    env::AbstractEnv,
-    ::TerminalReward,
-    ::Union{ExplicitStochastic,Deterministic},
-    ::Sequential,
-)
-    if get_terminal(env)
-        [get_reward(env, get_role(agents[p])) for p in get_players(env)]
-    elseif get_current_player(env) == get_chance_player(env)
-        vals = zeros(length(agents))
-        for a::ActionProbPair in get_legal_actions(env)
-            vals .+= a.prob .* expected_policy_values(agents, child(env, a))
-        end
-        vals
-    else
-        vals = zeros(length(agents))
-        probs = get_prob(agents[get_current_player(env)].policy, env)
-        actions = get_actions(env)
-        for (a, p) in zip(actions, probs)
-            if p > 0 #= ignore illegal action =#
-                vals .+= p .* expected_policy_values(agents, child(env, a))
-            end
-        end
-        vals
-    end
-end
diff --git a/src/utils/printing.jl b/src/utils/printing.jl
index 81690f8..90190d4 100644
--- a/src/utils/printing.jl
+++ b/src/utils/printing.jl
@@ -14,7 +14,7 @@ AT.children(t::StructTree{X}) where {X} =
     Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X))
 AT.children(
     t::StructTree{T},
-) where {T<:Union{AbstractArray,MersenneTwister,ProgressMeter.Progress,Function}} = ()
+) where {T<:Union{AbstractArray, AbstractDict, MersenneTwister,ProgressMeter.Progress,Function}} = ()
 AT.children(t::Pair{Symbol,<:StructTree}) = children(last(t))
 AT.children(t::StructTree{UnionAll}) = ()
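
A minimal usage sketch of the TabularRandomPolicy introduced by this patch (not part of the patch itself). It assumes the package is loaded as ReinforcementLearningCore, that the exported names match those in the diff above (TabularRandomPolicy, TabularLearner, WeightedExplorer, get_prob), and that some `env::AbstractEnv` exists where indicated:

using Random
using ReinforcementLearningCore  # assumed to export the names added in this patch

# After this patch, TabularRandomPolicy is an alias for
# QBasedPolicy{<:TabularLearner,<:WeightedExplorer}; the keyword constructor
# builds both parts. An explicit RNG makes the sampling reproducible.
p = TabularRandomPolicy(; rng = MersenneTwister(123))

p.learner   # TabularLearner backed by a Dict{String,Vector{Float64}} by default
p.explorer  # WeightedExplorer that samples an action index according to the weights

# Given some env::AbstractEnv (not constructed here):
#   p(env)                              -> a sampled action; for a chance player with
#                                          ExplicitStochastic chance style, the action is
#                                          drawn directly from the ActionProbPair list
#   get_prob(p, env, ActionStyle(env))  -> probabilities over the full action set,
#                                          with zeros on illegal actions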