From fa7b51e4ea098dbbfcd7b64980e57a3b32efab2c Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 1 Sep 2020 11:40:43 +0800 Subject: [PATCH 1/3] add TabularLearner --- src/components/learners/abstract_learner.jl | 21 +++++++++++++ src/components/learners/learners.jl | 22 ++----------- src/components/learners/tabular_learner.jl | 34 +++++++++++++++++++++ 3 files changed, 57 insertions(+), 20 deletions(-) create mode 100644 src/components/learners/abstract_learner.jl create mode 100644 src/components/learners/tabular_learner.jl diff --git a/src/components/learners/abstract_learner.jl b/src/components/learners/abstract_learner.jl new file mode 100644 index 0000000..618f07f --- /dev/null +++ b/src/components/learners/abstract_learner.jl @@ -0,0 +1,21 @@ +export AbstractLearner, extract_experience + +using Flux + +""" + (learner::AbstractLearner)(env) + +A learner is usually used to estimate state values, state-action values or distributional values based on experiences. +""" +abstract type AbstractLearner end + +function (learner::AbstractLearner)(env) end + +""" + get_priority(p::AbstractLearner, experience) +""" +function RLBase.get_priority(p::AbstractLearner, experience) end + +# TODO: deprecate this default function +Flux.testmode!(learner::AbstractLearner, mode = true) = + Flux.testmode!(learner.approximator, mode) diff --git a/src/components/learners/learners.jl b/src/components/learners/learners.jl index 4071182..eabeff5 100644 --- a/src/components/learners/learners.jl +++ b/src/components/learners/learners.jl @@ -1,20 +1,2 @@ -export AbstractLearner, extract_experience - -using Flux - -""" - (learner::AbstractLearner)(env) - -A learner is usually used to estimate state values, state-action values or distributional values based on experiences. -""" -abstract type AbstractLearner end - -function (learner::AbstractLearner)(env) end - -""" - get_priority(p::AbstractLearner, experience) -""" -function RLBase.get_priority(p::AbstractLearner, experience) end - -Flux.testmode!(learner::AbstractLearner, mode = true) = - Flux.testmode!(learner.approximator, mode) +include("abstract_learner.jl") +include("tabular_learner.jl") \ No newline at end of file diff --git a/src/components/learners/tabular_learner.jl b/src/components/learners/tabular_learner.jl new file mode 100644 index 0000000..9d5faa4 --- /dev/null +++ b/src/components/learners/tabular_learner.jl @@ -0,0 +1,34 @@ +export TabularLearner + +""" + TabularLearner{S, T} + +Use a `Dict{S,Vector{T}}` to store action probabilities. +""" +struct TabularLearner{S,T} <: AbstractPolicy + table::Dict{S,Vector{T}} +end + +TabularLearner() = TabularLearner{Int,Float32}() +TabularLearner{S}() = TabularLearner{S,Float32}() +TabularLearner{S,T}() = TabularLearner(Dict{S,Vector{T}}()) + +function (p::TabularLearner)(env::AbstractEnv) + s = get_state(env) + if haskey(p.table, s) + p.table[s] + elseif ActionStyle(env) === FULL_ACTION_SET + mask = get_legal_actions_mask(env) + prob = mask ./ sum(mask) + p.table[s] = prob + prob + elseif ActionStyle(env) === MINIMAL_ACTION_SET + n = length(get_actions(env)) + prob = fill(1 / n, n) + p.table[s] = prob + prob + end +end + +update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience) + From 903463b235ac8e6a4bc269606f255c5496cd6b1d Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 1 Sep 2020 11:42:15 +0800 Subject: [PATCH 2/3] remove extract_experience in learner --- src/components/learners/abstract_learner.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/components/learners/abstract_learner.jl b/src/components/learners/abstract_learner.jl index 618f07f..aa88c81 100644 --- a/src/components/learners/abstract_learner.jl +++ b/src/components/learners/abstract_learner.jl @@ -1,4 +1,4 @@ -export AbstractLearner, extract_experience +export AbstractLearner using Flux From 0dec0739ea5540ced1d5c816d1f19d3ba552b169 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 1 Sep 2020 11:44:13 +0800 Subject: [PATCH 3/3] remove extract_experience in DynaAgent --- src/components/agents/dyna_agent.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/components/agents/dyna_agent.jl b/src/components/agents/dyna_agent.jl index 639d703..57f3f2a 100644 --- a/src/components/agents/dyna_agent.jl +++ b/src/components/agents/dyna_agent.jl @@ -65,7 +65,7 @@ end RLBase.update!(model::AbstractEnvironmentModel, t::AbstractTrajectory, π::AbstractPolicy) = update!(model, t) -function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory) - transitions = extract_experience(buffer, model) - isnothing(transitions) || update!(model, transitions) -end +# function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory) +# transitions = extract_experience(buffer, model) +# isnothing(transitions) || update!(model, transitions) +# end