diff --git a/src/policies/q_based_policies/learners/abstract_learner.jl b/src/policies/q_based_policies/learners/abstract_learner.jl
index 940b69d..a3f27f3 100644
--- a/src/policies/q_based_policies/learners/abstract_learner.jl
+++ b/src/policies/q_based_policies/learners/abstract_learner.jl
@@ -20,9 +20,18 @@ Base.show(io::IO, p::AbstractLearner) =
     AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10))
 
 function RLBase.update!(
-    p::AbstractLearner,
+    L::AbstractLearner,
     t::AbstractTrajectory,
     e::AbstractEnv,
     s::AbstractStage
 )
+end
+
+function RLBase.update!(
+    L::AbstractLearner,
+    t::AbstractTrajectory,
+    e::AbstractEnv,
+    s::PreActStage
+)
+    update!(L, t)
 end
\ No newline at end of file
diff --git a/src/policies/tabular_random_policy.jl b/src/policies/tabular_random_policy.jl
index 3ee17b8..a802226 100644
--- a/src/policies/tabular_random_policy.jl
+++ b/src/policies/tabular_random_policy.jl
@@ -6,7 +6,7 @@ export TabularRandomPolicy
 Use a `Dict` to store action distribution.
 """
 Base.@kwdef struct TabularRandomPolicy{S,T, R} <: AbstractPolicy
-    table::Dict{S,T} = Dict{Int, Float32}()
+    table::Dict{S,T} = Dict{Any, Vector{Float32}}()
     rng::R = Random.GLOBAL_RNG
 end
 
@@ -19,7 +19,7 @@ function RLBase.prob(p::TabularRandomPolicy, ::ExplicitStochastic, env::AbstractEnv)
     if current_player(env) == chance_player(env)
         prob(env)
     else
-        p(DETERMINISTIC, env) # treat it just like a normal one
+        prob(p, DETERMINISTIC, env) # treat it just like a normal one
     end
 end
 
@@ -41,11 +41,23 @@ function RLBase.prob(t::TabularRandomPolicy, ::MinimalActionSet, env::AbstractEn
     end
 end
 
-function RLBase.prob(t::TabularRandomPolicy, env::AbstractEnv, action::Int)
-    prob(t, state(env))[action]
+function RLBase.prob(t::TabularRandomPolicy, env::AbstractEnv, action)
+    prob(t, env, action_space(env), action)
 end
 
-function RLBase.prob(t::TabularRandomPolicy{S}, state::S, action::Int) where S
+function RLBase.prob(t::TabularRandomPolicy, env::AbstractEnv, action_space, action)
+    prob(t, env)[findfirst(==(action), action_space)]
+end
+
+function RLBase.prob(t::TabularRandomPolicy, env::AbstractEnv, action_space::Base.OneTo, action)
+    prob(t, env)[action]
+end
+
+# function RLBase.prob(t::TabularRandomPolicy, env::AbstractEnv, action_space::ZeroTo, action)
+#     prob(t, env)[action+1]
+# end
+
+function RLBase.prob(t::TabularRandomPolicy, state, action)
     # assume table is already initialized
     t.table[state][action]
 end
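
For context, a minimal standalone sketch (not part of the patch; the `probs` vector and symbolic actions are made up for illustration) of why the generic `prob` method looks up the action with `findfirst` while the `Base.OneTo` specialization can index the stored probability vector directly:

```julia
# Standalone sketch: how the two new action-space methods map an action
# to an index of the stored probability vector.
probs = Float32[0.2, 0.3, 0.5]

# Generic action space: the action's position must be found explicitly,
# mirroring `findfirst(==(action), action_space)` in the patch.
action_space = [:left, :stay, :right]
idx = findfirst(==(:right), action_space)   # returns 3
@assert probs[idx] == 0.5f0

# `Base.OneTo` action space: actions are already valid 1-based indices,
# so the specialized method indexes the probability vector directly.
space = Base.OneTo(3)
@assert probs[space[2]] == 0.3f0
```

Dispatching on the action-space type keeps the common `Base.OneTo` case an O(1) lookup instead of a linear search, while still supporting arbitrary action collections through the generic fallback.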