This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 450d430

add TabularLearner (#115)

* add TabularLearner
* remove extract_experience in learner
* remove extract_experience in DynaAgent
1 parent 2c1d72e commit 450d430

File tree

4 files changed: +61 -24 lines

src/components/agents/dyna_agent.jl

Lines changed: 4 additions & 4 deletions
@@ -65,7 +65,7 @@ end
 RLBase.update!(model::AbstractEnvironmentModel, t::AbstractTrajectory, π::AbstractPolicy) =
     update!(model, t)

-function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory)
-    transitions = extract_experience(buffer, model)
-    isnothing(transitions) || update!(model, transitions)
-end
+# function RLBase.update!(model::AbstractEnvironmentModel, buffer::AbstractTrajectory)
+#     transitions = extract_experience(buffer, model)
+#     isnothing(transitions) || update!(model, transitions)
+# end
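
A hedged sketch of what the surviving fallback implies (the `CountingModel` below is hypothetical and not part of this commit; `AbstractEnvironmentModel`, `AbstractTrajectory`, and `RLBase.update!` are the names used in the diff above): since `update!(model, t, π)` now only forwards to `update!(model, t)`, a concrete model specializes the two-argument method directly on the trajectory instead of going through `extract_experience`.

# Hypothetical model, for illustration only: it just counts updates.
# A real model would read transitions out of the trajectory `t`.
mutable struct CountingModel <: AbstractEnvironmentModel
    n_updates::Int
end

CountingModel() = CountingModel(0)

# Only the two-argument method is needed; the generic fallback drops π.
function RLBase.update!(m::CountingModel, t::AbstractTrajectory)
    m.n_updates += 1
    m
end
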
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+export AbstractLearner
+
+using Flux
+
+"""
+    (learner::AbstractLearner)(env)
+
+A learner is usually used to estimate state values, state-action values or distributional values based on experiences.
+"""
+abstract type AbstractLearner end
+
+function (learner::AbstractLearner)(env) end
+
+"""
+    get_priority(p::AbstractLearner, experience)
+"""
+function RLBase.get_priority(p::AbstractLearner, experience) end
+
+# TODO: deprecate this default function
+Flux.testmode!(learner::AbstractLearner, mode = true) =
+    Flux.testmode!(learner.approximator, mode)
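
A hedged sketch of how this interface is typically implemented (the `ChainLearner` below is hypothetical, not part of this commit): a concrete learner subtypes `AbstractLearner`, is callable on an environment, and keeps its network in an `approximator` field so that the `Flux.testmode!` fallback above applies to it.

using Flux

# Hypothetical learner, for illustration only.
struct ChainLearner <: AbstractLearner
    approximator::Chain   # field name expected by the testmode! fallback above
end

# Estimate values from the environment's state; assumes `get_state(env)`
# returns something the chain can consume (e.g. a feature vector).
(learner::ChainLearner)(env) = learner.approximator(get_state(env))
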
Lines changed: 2 additions & 20 deletions
@@ -1,20 +1,2 @@
-export AbstractLearner, extract_experience
-
-using Flux
-
-"""
-    (learner::AbstractLearner)(env)
-
-A learner is usually used to estimate state values, state-action values or distributional values based on experiences.
-"""
-abstract type AbstractLearner end
-
-function (learner::AbstractLearner)(env) end
-
-"""
-    get_priority(p::AbstractLearner, experience)
-"""
-function RLBase.get_priority(p::AbstractLearner, experience) end
-
-Flux.testmode!(learner::AbstractLearner, mode = true) =
-    Flux.testmode!(learner.approximator, mode)
+include("abstract_learner.jl")
+include("tabular_learner.jl")
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+export TabularLearner
+
+"""
+    TabularLearner{S, T}
+
+Use a `Dict{S,Vector{T}}` to store action probabilities.
+"""
+struct TabularLearner{S,T} <: AbstractPolicy
+    table::Dict{S,Vector{T}}
+end
+
+TabularLearner() = TabularLearner{Int,Float32}()
+TabularLearner{S}() where {S} = TabularLearner{S,Float32}()
+TabularLearner{S,T}() where {S,T} = TabularLearner(Dict{S,Vector{T}}())
+
+function (p::TabularLearner)(env::AbstractEnv)
+    s = get_state(env)
+    if haskey(p.table, s)
+        p.table[s]
+    elseif ActionStyle(env) === FULL_ACTION_SET
+        mask = get_legal_actions_mask(env)
+        prob = mask ./ sum(mask)
+        p.table[s] = prob
+        prob
+    elseif ActionStyle(env) === MINIMAL_ACTION_SET
+        n = length(get_actions(env))
+        prob = fill(1 / n, n)
+        p.table[s] = prob
+        prob
+    end
+end
+
+update!(p::TabularLearner, experience::Pair) = p.table[first(experience)] = last(experience)
+
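
A short usage sketch (a hypothetical REPL-style snippet, not part of this commit): the table is filled lazily the first time an unseen state is queried, and the `Pair`-based `update!` overwrites the stored probabilities for a state.

p = TabularLearner()               # backed by a Dict{Int,Vector{Float32}}

# Overwrite the action probabilities of state 1.
update!(p, 1 => [0.2f0, 0.8f0])

p.table[1]                         # Float32[0.2, 0.8]

# Calling `p(env)` on a state not yet in the table falls back to a uniform
# (or legal-action-masked) distribution and caches it, as in the diff above.
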
