8 changes: 4 additions & 4 deletions src/components/agents/dyna_agent.jl
@@ -65,7 +65,7 @@ end
 RLBase.update!(model::AbstractEnvironmentModel, t::AbstractTrajectory, π::AbstractPolicy) =
     update!(model, t)

-function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory)
-    transitions = extract_experience(buffer, model)
-    isnothing(transitions) || update!(model, transitions)
-end
+# function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory)
+#     transitions = extract_experience(buffer, model)
+#     isnothing(transitions) || update!(model, transitions)
+# end
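With the generic two-argument fallback commented out, each concrete environment model is now expected to supply its own `update!(model, trajectory)` method; the three-argument method above still just delegates to it. Below is a hedged sketch of what such a method could look like for a hypothetical counting model; the type, its field, and the shape of the extracted transitions are assumptions, not part of this PR.

```julia
# Hypothetical sketch, not part of this PR: a concrete model providing the
# two-argument update! that the three-argument method delegates to.
struct VisitCountModel <: AbstractEnvironmentModel
    counts::Dict{Any,Int}
end

function RLBase.update!(model::VisitCountModel, t::AbstractTrajectory)
    transitions = extract_experience(t, model)   # assumed to return `nothing` or a named tuple
    isnothing(transitions) && return
    for (s, a) in zip(transitions.states, transitions.actions)  # assumed field names
        model.counts[(s, a)] = get(model.counts, (s, a), 0) + 1
    end
end
```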
21 changes: 21 additions & 0 deletions src/components/learners/abstract_learner.jl
@@ -0,0 +1,21 @@
export AbstractLearner

using Flux

"""
    (learner::AbstractLearner)(env)

A learner is usually used to estimate state values, state-action values or distributional values based on experiences.
"""
abstract type AbstractLearner end

function (learner::AbstractLearner)(env) end

"""
    get_priority(p::AbstractLearner, experience)
"""
function RLBase.get_priority(p::AbstractLearner, experience) end

# TODO: deprecate this default function
Flux.testmode!(learner::AbstractLearner, mode = true) =
    Flux.testmode!(learner.approximator, mode)
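For context, here is a minimal sketch (not part of the diff) of a concrete learner that fits this interface. `MyQLearner` and its `Chain` model are hypothetical; keeping the Flux model in an `approximator` field means the `Flux.testmode!` fallback above applies without a specialized method.

```julia
using Flux

# Hypothetical concrete learner (illustration only): estimates action values
# for the current state with a Flux model stored in `approximator`.
struct MyQLearner <: AbstractLearner
    approximator::Chain
end

# Calling the learner on an environment returns its value estimates;
# `get_state(env)` is assumed to yield a vector the model accepts.
(learner::MyQLearner)(env) = learner.approximator(get_state(env))
```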
22 changes: 2 additions & 20 deletions src/components/learners/learners.jl
@@ -1,20 +1,2 @@
-export AbstractLearner, extract_experience
-
-using Flux
-
-"""
-    (learner::AbstractLearner)(env)
-
-A learner is usually used to estimate state values, state-action values or distributional values based on experiences.
-"""
-abstract type AbstractLearner end
-
-function (learner::AbstractLearner)(env) end
-
-"""
-    get_priority(p::AbstractLearner, experience)
-"""
-function RLBase.get_priority(p::AbstractLearner, experience) end
-
-Flux.testmode!(learner::AbstractLearner, mode = true) =
-    Flux.testmode!(learner.approximator, mode)
+include("abstract_learner.jl")
+include("tabular_learner.jl")
34 changes: 34 additions & 0 deletions src/components/learners/tabular_learner.jl
@@ -0,0 +1,34 @@
export TabularLearner

"""
    TabularLearner{S, T}

Use a `Dict{S,Vector{T}}` to store action probabilities.
"""
struct TabularLearner{S,T} <: AbstractPolicy
    table::Dict{S,Vector{T}}
end

TabularLearner() = TabularLearner{Int,Float32}()
TabularLearner{S}() where {S} = TabularLearner{S,Float32}()
TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())

function (p::TabularLearner)(env::AbstractEnv)
    s = get_state(env)
    if haskey(p.table, s)
        p.table[s]
    elseif ActionStyle(env) === FULL_ACTION_SET
        mask = get_legal_actions_mask(env)
        prob = mask ./ sum(mask)
        p.table[s] = prob
        prob
    elseif ActionStyle(env) === MINIMAL_ACTION_SET
        n = length(get_actions(env))
        prob = fill(1 / n, n)
        p.table[s] = prob
        prob
    end
end

RLBase.update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
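A quick usage sketch (not part of the diff), assuming a hypothetical `env::AbstractEnv` with `Int` states, `ActionStyle(env) === MINIMAL_ACTION_SET`, and three actions:

```julia
# Hypothetical usage; `env` and its three-action, Int-state setup are assumed.
p = TabularLearner()                        # backed by a Dict{Int,Vector{Float32}}

p(env)                                      # first visit: caches and returns a uniform fill(1/3, 3)
RLBase.update!(p, get_state(env) => [0.8f0, 0.1f0, 0.1f0])  # overwrite the cached probabilities
p(env)                                      # now returns Float32[0.8, 0.1, 0.1]
```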