2 changes: 1 addition & 1 deletion Project.toml
@@ -30,7 +30,7 @@ AbstractTrees = "0.3"
Adapt = "2"
BSON = "0.2"
CUDA = "1"
Distributions = "0.22, 0.23, 0.24"
Distributions = "0.24"
ElasticArrays = "1.2"
FillArrays = "0.8, 0.9"
Flux = "0.11.1"
28 changes: 14 additions & 14 deletions src/components/learners/tabular_learner.jl
@@ -13,20 +13,20 @@ TabularLearner() = TabularLearner{Int,Float32}()
TabularLearner{S}() where {S} = TabularLearner{S,Float32}()
TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())

-function (p::TabularLearner)(env::AbstractEnv)
-    s = get_state(env)
-    if haskey(p.table, s)
-        p.table[s]
-    elseif ActionStyle(env) === FULL_ACTION_SET
-        mask = get_legal_actions_mask(env)
-        prob = mask ./ sum(mask)
-        p.table[s] = prob
-        prob
-    elseif ActionStyle(env) === MINIMAL_ACTION_SET
-        n = length(get_actions(env))
-        prob = fill(1 / n, n)
-        p.table[s] = prob
-        prob
+(p::TabularLearner)(env::AbstractEnv) = p(ChanceStyle(env), env)
+
+function (p::TabularLearner)(::ExplicitStochastic, env::AbstractEnv)
+    if get_current_player(env) == get_chance_player(env)
+        [a.prob for a::ActionProbPair in get_actions(env)]
+    else
+        p(DETERMINISTIC, env) # treat it just like a normal one
+    end
+end
+
+function (p::TabularLearner)(::RLBase.AbstractChanceStyle, env::AbstractEnv)
+    get!(p.table, get_state(env)) do
+        n = length(get_legal_actions(env))
+        fill(1/n, n)
    end
end

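The new implementation dispatches on ChanceStyle(env): at a chance node of an ExplicitStochastic game the learner returns the chance outcomes' own probabilities, and in every other case it lazily caches a uniform distribution over the legal actions via get!. A minimal sketch of the fallback path (hedged: the three-action environment env and the integer state key are illustrative assumptions, not part of this diff):

    learner = TabularLearner()            # default table is a Dict{Int,Vector{Float32}}
    # Suppose env is a deterministic AbstractEnv whose current state is 1 and
    # which exposes three legal actions:
    learner(env)                          # -> uniform weights [1/3, 1/3, 1/3], cached under key 1
    learner.table[1] = [0.2f0, 0.5f0, 0.3f0]   # later calls return the updated weights
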
32 changes: 31 additions & 1 deletion src/components/policies/Q_based_policy.jl
@@ -1,4 +1,4 @@
-export QBasedPolicy
+export QBasedPolicy, TabularRandomPolicy

using MacroTools: @forward
using Flux
@@ -35,3 +35,33 @@ function Flux.testmode!(p::QBasedPolicy, mode = true)
    testmode!(p.learner, mode)
    testmode!(p.explorer, mode)
end
+
+#####
+# TabularRandomPolicy
+#####
+
+const TabularRandomPolicy = QBasedPolicy{<:TabularLearner, <:WeightedExplorer}
+
+function TabularRandomPolicy(;rng=Random.GLOBAL_RNG, is_normalized=true, table=Dict{String,Vector{Float64}}())
+    QBasedPolicy(;
+        learner = TabularLearner(table),
+        explorer = WeightedExplorer(; is_normalized = is_normalized, rng = rng),
+    )
+end
+
+function (p::TabularRandomPolicy)(env::AbstractEnv)
+    if ChanceStyle(env) === EXPLICIT_STOCHASTIC
+        if get_current_player(env) == get_chance_player(env)
+            # this should be faster. we don't need to allocate memory to store the probability of chance node
+            return rand(p.explorer.rng, get_actions(env))
+        end
+    end
+    p(env, ActionStyle(env)) # fall back to general implementation above
+end
+
+function RLBase.get_prob(p::TabularRandomPolicy, env, ::FullActionSet)
+    m = get_legal_actions_mask(env)
+    prob = zeros(length(m))
+    prob[m] .= p.learner(env)
+    prob
+end
47 changes: 0 additions & 47 deletions src/core/run.jl
@@ -132,50 +132,3 @@ function run(

    hooks
end
-
-"""
-    expected_policy_values(agents, env)
-
-Calculate the expected return of each agent.
-"""
-function expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv)
-    agents = Dict(get_role(agent) => agent for agent in agents)
-    values = expected_policy_values(agents, env)
-    Dict(zip(get_players(env), values))
-end
-
-expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(
-    agents,
-    env,
-    RewardStyle(env),
-    ChanceStyle(env),
-    DynamicStyle(env),
-)
-
-function expected_policy_values(
-    agents::Dict,
-    env::AbstractEnv,
-    ::TerminalReward,
-    ::Union{ExplicitStochastic,Deterministic},
-    ::Sequential,
-)
-    if get_terminal(env)
-        [get_reward(env, get_role(agents[p])) for p in get_players(env)]
-    elseif get_current_player(env) == get_chance_player(env)
-        vals = zeros(length(agents))
-        for a::ActionProbPair in get_legal_actions(env)
-            vals .+= a.prob .* expected_policy_values(agents, child(env, a))
-        end
-        vals
-    else
-        vals = zeros(length(agents))
-        probs = get_prob(agents[get_current_player(env)].policy, env)
-        actions = get_actions(env)
-        for (a, p) in zip(actions, probs)
-            if p > 0 #= ignore illegal action =#
-                vals .+= p .* expected_policy_values(agents, child(env, a))
-            end
-        end
-        vals
-    end
-end
2 changes: 1 addition & 1 deletion src/utils/printing.jl
@@ -14,7 +14,7 @@ AT.children(t::StructTree{X}) where {X} =
    Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X))
AT.children(
    t::StructTree{T},
-) where {T<:Union{AbstractArray,MersenneTwister,ProgressMeter.Progress,Function}} = ()
+) where {T<:Union{AbstractArray, AbstractDict, MersenneTwister,ProgressMeter.Progress,Function}} = ()
AT.children(t::Pair{Symbol,<:StructTree}) = children(last(t))
AT.children(t::StructTree{UnionAll}) = ()

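Adding AbstractDict to the union above means that dictionary-valued fields, such as a TabularLearner's table, are now rendered as a single leaf instead of being expanded entry by entry. A small sketch of the effect (hedged: it assumes AT is this file's alias for the AbstractTrees module, as the surrounding methods suggest):

    t = StructTree(Dict("s1" => [0.5, 0.5]))
    AT.children(t)   # -> () after this change, so the Dict is no longer expanded
                     # into one subtree per key/value pair when printed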