diff --git a/src/components/explorers/weighted_explorer.jl b/src/components/explorers/weighted_explorer.jl
index accfbaa..fe260ff 100644
--- a/src/components/explorers/weighted_explorer.jl
+++ b/src/components/explorers/weighted_explorer.jl
@@ -32,3 +32,14 @@ function (s::WeightedExplorer)(values, mask)
     values[.!mask] .= 0
     s(values)
 end
+
+RLBase.get_prob(s::WeightedExplorer{true}, values) = values
+RLBase.get_prob(s::WeightedExplorer{false}, values) = values ./ sum(values)
+
+# assume `values` and `mask` matches and `sum(values) == 1`
+RLBase.get_prob(s::WeightedExplorer{true}, values, mask) = values
+
+function RLBase.get_prob(s::WeightedExplorer{false}, values, mask)
+    s = sum(@view(values[mask]))
+    map((v,m) -> m ? v/s : zero(v), values, mask)
+end
diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl
index 0e25958..f38dfcc 100644
--- a/src/components/learners/tabular_learner.jl
+++ b/src/components/learners/tabular_learner.jl
@@ -5,7 +5,7 @@ export TabularLearner
 
 Use a `Dict{S,Vector{T}}` to store action probabilities.
 """
-struct TabularLearner{S,T} <: AbstractPolicy
+struct TabularLearner{S,T} <: AbstractLearner
     table::Dict{S,Vector{T}}
 end
 
diff --git a/src/components/policies/Q_based_policy.jl b/src/components/policies/Q_based_policy.jl
index 45f6a06..6471883 100644
--- a/src/components/policies/Q_based_policy.jl
+++ b/src/components/policies/Q_based_policy.jl
@@ -19,9 +19,8 @@ end
 Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner
 
 (π::QBasedPolicy)(env) = π(env, ActionStyle(env))
-(π::QBasedPolicy)(env, ::MinimalActionSet) = env |> π.learner |> π.explorer
-(π::QBasedPolicy)(env, ::FullActionSet) =
-    π.explorer(π.learner(env), get_legal_actions_mask(env))
+(π::QBasedPolicy)(env, ::MinimalActionSet) = get_actions(env)[env |> π.learner |> π.explorer]
+(π::QBasedPolicy)(env, ::FullActionSet) = get_actions(env)[π.explorer(π.learner(env), get_legal_actions_mask(env))]
 
 RLBase.get_prob(p::QBasedPolicy, env) = get_prob(p, env, ActionStyle(env))
 RLBase.get_prob(p::QBasedPolicy, env, ::MinimalActionSet) =
diff --git a/src/components/processors.jl b/src/components/processors.jl
index cc071f1..dc030b5 100644
--- a/src/components/processors.jl
+++ b/src/components/processors.jl
@@ -53,7 +53,7 @@ function Base.push!(
     push!(cb, select_last_frame(stacked_data))
 end
 
-function RLBase.reset!(p::StackFrames{T}) where {T}
+function RLBase.reset!(p::StackFrames{T,N}) where {T,N}
     empty!(p.buffer)
     for _ in 1:capacity(p.buffer)
         push!(p.buffer, zeros(T, size(p.buffer)[1:N-1]))
diff --git a/src/core/run.jl b/src/core/run.jl
index 42f8059..034b6bd 100644
--- a/src/core/run.jl
+++ b/src/core/run.jl
@@ -138,26 +138,17 @@ end
 
 Calculate the expected return of each agent.
 """
-expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv) =
-    expected_policy_values(Dict(get_role(agent) => agent for agent in agents), env)
-
-expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(
-    agents,
-    env,
-    RewardStyle(env),
-    ChanceStyle(env),
-    DynamicStyle(env),
-)
+function expected_policy_values(agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv)
+    agents = Dict(get_role(agent) => agent for agent in agents)
+    values = expected_policy_values(agents, env)
+    Dict(zip(get_players(env), values))
+end
 
-function expected_policy_values(
-    agents::Dict,
-    env::AbstractEnv,
-    ::TerminalReward,
-    ::Union{ExplicitStochastic,Deterministic},
-    ::Sequential,
-)
+expected_policy_values(agents::Dict, env::AbstractEnv) = expected_policy_values(agents, env, RewardStyle(env), ChanceStyle(env), DynamicStyle(env))
+
+function expected_policy_values(agents::Dict, env::AbstractEnv, ::TerminalReward, ::Union{ExplicitStochastic,Deterministic}, ::Sequential)
     if get_terminal(env)
-        [get_reward(env, get_role(agent)) for agent in values(agents)]
+        [get_reward(env, get_role(agents[p])) for p in get_players(env)]
     elseif get_current_player(env) == get_chance_player(env)
         vals = zeros(length(agents))
         for a::ActionProbPair in get_legal_actions(env)