This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit dec7391

Update dependencies (#90)

* sync
* update RLBase to newer version
* switch to CUDA.jl
* fix more
* fix tests
* better printing
* rename obs -> env
* sync changes
* allow setting max_depth when printing struct
* minor fix
* decrease max_depth
* simplify printing
* resolve comments
* support RandomStartPolicy
* fix #73
* ignore gpu related tests when CUDA is not available

1 parent (c03fe0b), commit dec7391

38 files changed: +344, -306 lines
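The most user-facing change in this commit is the `obs -> env` rename: agents, policies, and approximators are now called on the environment itself rather than on an observation extracted from it. A minimal sketch of the new calling convention, assuming a policy and an environment constructed elsewhere in the ReinforcementLearning.jl ecosystem (the names `policy` and `env` below are placeholders):

    # Placeholders: `policy` is any AbstractPolicy, `env` is any environment
    # supported by ReinforcementLearningBase 0.8.
    agent = Agent(policy = policy)      # `trajectory` now defaults to DummyTrajectory()

    action = agent(env)                 # before this commit: agent(obs)
    agent(PRE_ACT_STAGE, env)           # stage hooks also take the environment now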

.travis.yml

Lines changed: 1 addition & 1 deletion

@@ -19,4 +19,4 @@ after_success:
 
 ## uncomment the following lines to override the default test script
 script:
-  - travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'
+  - travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test(coverage=true)'

Project.toml

Lines changed: 5 additions & 10 deletions

@@ -1,13 +1,13 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 authors = ["Jun Tian <[email protected]>"]
-version = "0.3.3"
+version = "0.4.0"
 
 [deps]
+AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
-CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
-CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -25,22 +25,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-Adapt = "1, 2"
 BSON = "0.2"
-CUDAapi = "3, 4"
-CuArrays = "1.7, 2"
 Distributions = "0.22, 0.23"
 FillArrays = "0.8"
-Flux = "0.10"
-GPUArrays = "2, 3, 4.0"
+Flux = "0.11"
 ImageTransformations = "0.8"
 JLD = "0.10"
 MacroTools = "0.5"
 ProgressMeter = "1.2"
-ReinforcementLearningBase = "0.7"
+ReinforcementLearningBase = "0.8"
 Setfield = "0.6"
 StatsBase = "0.32, 0.33"
-Zygote = "0.4"
 julia = "1.3"
 
 [extras]

src/ReinforcementLearningCore.jl

Lines changed: 1 addition & 1 deletion

@@ -11,8 +11,8 @@ provides some standard and reusable components defined by [**RLBase**](https://g
 
 export RLCore
 
-include("extensions/extensions.jl")
 include("utils/utils.jl")
+include("extensions/extensions.jl")
 include("components/components.jl")
 include("core/core.jl")
 

src/components/agents/abstract_agent.jl

Lines changed: 7 additions & 5 deletions

@@ -16,8 +16,8 @@ export AbstractAgent,
     Testing
 
 """
-    (agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action
-    (agent::AbstractAgent)(stage::AbstractStage, obs)
+    (agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) -> action
+    (agent::AbstractAgent)(stage::AbstractStage, env)
 
 Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action.
 The main difference is that, we divide an experiment into the following stages:
@@ -43,7 +43,7 @@ PRE_EXPERIMENT_STAGE |      PRE_ACT_STAGE           POST_ACT_STAGE
      |               |          |      |       |                      |
      v               |  +-----+ v  +-------+   v   +-----+            |        v
--------------------->+ env +------>+ agent +------->+ env +---> ... ------->......
-     |               ^  +-----+ obs +-------+ action +-----+      ^            |
+     |               ^  +-----+     +-------+ action +-----+      ^            |
      |               |                                            |            |
      |               +--PRE_EPISODE_STAGE     POST_EPISODE_STAGE--+            |
      |                                                                         |
@@ -66,10 +66,12 @@ const POST_EPISODE_STAGE = PostEpisodeStage()
 const PRE_ACT_STAGE = PreActStage()
 const POST_ACT_STAGE = PostActStage()
 
-(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs)
-function (agent::AbstractAgent)(stage::AbstractStage, obs) end
+(agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env)
+function (agent::AbstractAgent)(stage::AbstractStage, env) end
 
 struct Training{T<:AbstractStage} end
 Training(s::T) where {T<:AbstractStage} = Training{T}()
 struct Testing{T<:AbstractStage} end
 Testing(s::T) where {T<:AbstractStage} = Testing{T}()
+
+Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent), get(io, :max_depth, 10))
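The added `Base.show` method renders an agent as a tree of its fields via `AbstractTrees.print_tree`, reading the maximum depth from the `IO` object with a default of 10. The depth can therefore be controlled through an `IOContext`; a small sketch, assuming `agent` is any already constructed `AbstractAgent`:

    show(stdout, agent)                               # struct tree, up to 10 levels deep
    show(IOContext(stdout, :max_depth => 3), agent)   # truncate the tree at depth 3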

src/components/agents/agent.jl

Lines changed: 42 additions & 35 deletions

@@ -16,17 +16,17 @@ Generally speaking, it does nothing but update the trajectory and policy appropr
 
 - `policy`::[`AbstractPolicy`](@ref): the policy to use
 - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment
-- `role=:DEFAULT_PLAYER`: used to distinguish different agents
+- `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents
 """
 Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
     policy::P
-    trajectory::T
-    role::R = :DEFAULT_PLAYER
+    trajectory::T = DummyTrajectory()
+    role::R = RLBase.DEFAULT_PLAYER
     is_training::Bool = true
 end
 
 # avoid polluting trajectory
-(agent::Agent)(obs) = agent.policy(obs)
+(agent::Agent)(env) = agent.policy(env)
 
 Flux.functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy
 
@@ -69,47 +69,54 @@ function Flux.testmode!(agent::Agent, mode = true)
     testmode!(agent.policy, mode)
 end
 
-(agent::Agent)(stage::AbstractStage, obs) =
-    agent.is_training ? agent(Training(stage), obs) : agent(Testing(stage), obs)
+(agent::Agent)(stage::AbstractStage, env) =
+    agent.is_training ? agent(Training(stage), env) : agent(Testing(stage), env)
 
-(agent::Agent)(::Testing, obs) = nothing
-(agent::Agent)(::Testing{PreActStage}, obs) = agent.policy(obs)
+(agent::Agent)(::Testing, env) = nothing
+(agent::Agent)(::Testing{PreActStage}, env) = agent.policy(env)
+
+#####
+# DummyTrajectory
+#####
+
+(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::AbstractStage, env) = nothing
+(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::PreActStage, env) = agent.policy(env)
 
 #####
 # EpisodicCompactSARTSATrajectory
 #####
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     empty!(agent.trajectory)
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
@@ -125,7 +132,7 @@ function (
     }
 )(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     if length(agent.trajectory) > 0
         pop!(agent.trajectory, :state, :action)
@@ -140,10 +147,10 @@ function (
     }
 )(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
@@ -155,9 +162,9 @@ function (
     }
 )(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
@@ -168,10 +175,10 @@ function (
     }
 )(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
@@ -182,7 +189,7 @@ end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PreEpisodeStage},
-    obs,
+    env,
 )
     if length(agent.trajectory) > 0
         pop!(agent.trajectory, :state, :action)
@@ -192,28 +199,28 @@ end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PreActStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PostActStage},
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
 function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
     ::Training{PostEpisodeStage},
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.policy, agent.trajectory)
     action
 end
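Since `Agent` is a `Base.@kwdef` struct, only `policy` has no default; `trajectory`, `role`, and `is_training` can be omitted. A rough construction and stage-dispatch sketch, assuming a policy and an environment exist and that the trajectory's own constructor arguments are filled in appropriately (they are omitted here):

    # Placeholders: `policy`, `env`, and the trajectory's element types are assumed.
    agent = Agent(
        policy = policy,
        trajectory = EpisodicCompactSARTSATrajectory(),  # role and is_training keep their defaults
    )

    agent(PRE_EPISODE_STAGE, env)    # empties the trajectory
    a = agent(PRE_ACT_STAGE, env)    # records state/action, updates the policy, returns an action
    agent(POST_ACT_STAGE, env)       # records reward/terminal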

src/components/agents/dyna_agent.jl

Lines changed: 9 additions & 9 deletions

@@ -35,18 +35,18 @@ get_role(agent::DynaAgent) = agent.role
 
 function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PreEpisodeStage,
-    obs,
+    env,
 )
     empty!(agent.trajectory)
     nothing
 end
 
 function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PreActStage,
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.model, agent.trajectory, agent.policy) # model learning
     update!(agent.policy, agent.trajectory) # direct learning
     update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning
@@ -55,18 +55,18 @@ end
 
 function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PostActStage,
-    obs,
+    env,
 )
-    push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
+    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
     nothing
 end
 
 function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
     ::PostEpisodeStage,
-    obs,
+    env,
 )
-    action = agent.policy(obs)
-    push!(agent.trajectory; state = get_state(obs), action = action)
+    action = agent.policy(env)
+    push!(agent.trajectory; state = get_state(env), action = action)
     update!(agent.model, agent.trajectory, agent.policy) # model learning
     update!(agent.policy, agent.trajectory) # direct learning
     update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning

src/components/approximators/abstract_approximator.jl

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ export AbstractApproximator,
     ApproximatorStyle, Q_APPROXIMATOR, QApproximator, V_APPROXIMATOR, VApproximator
 
 """
-    (app::AbstractApproximator)(obs)
+    (app::AbstractApproximator)(env)
 
 An approximator is a functional object for value estimation.
 It serves as a black box to provides an abstraction over different
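Because an approximator is simply a callable object, defining a new one only requires a struct that subtypes `AbstractApproximator` together with a call method (and an `update!` method if it is learnable). A hypothetical sketch, not part of this commit, using a table-lookup state-value approximator:

    # Hypothetical illustration only.
    struct TabularVApproximator <: AbstractApproximator
        table::Vector{Float64}    # one estimated value per state index
    end

    # value estimation: look the state up in the table
    (app::TabularVApproximator)(s::Int) = app.table[s]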

src/components/approximators/neural_network_approximator.jl

Lines changed: 3 additions & 3 deletions

@@ -10,11 +10,11 @@ Use a DNN model for value estimation.
 # Keyword arguments
 
 - `model`, a Flux based DNN model.
-- `optimizer=Descent()`
+- `optimizer=nothing`
 """
 Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator
     model::M
-    optimizer::O = Descent()
+    optimizer::O = nothing
 end
 
 (app::NeuralNetworkApproximator)(x) = app.model(x)
@@ -42,7 +42,7 @@ Flux.testmode!(app::NeuralNetworkApproximator, mode = true) = testmode!(app.mode
 
 The `actor` part must return logits (*Do not use softmax in the last layer!*), and the `critic` part must return a state value.
 """
-Base.@kwdef struct ActorCritic{A,C,O}
+Base.@kwdef struct ActorCritic{A,C,O} <: AbstractApproximator
     actor::A
     critic::C
     optimizer::O = ADAM()
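A rough construction sketch, assuming Flux is available and a hypothetical problem with a 4-dimensional state and 2 actions; since the `optimizer` now defaults to `nothing`, it is passed explicitly when the approximator is meant to be trained:

    using Flux

    # Layer sizes are made up for illustration.
    q_network = NeuralNetworkApproximator(
        model = Chain(Dense(4, 32, relu), Dense(32, 2)),
        optimizer = ADAM(),
    )

    q_values = q_network(rand(Float32, 4))   # forwards to app.model(x)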

src/components/components.jl

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-include("preprocessors.jl")
+include("processors.jl")
 include("trajectories/trajectories.jl")
 include("approximators/approximators.jl")
 include("explorers/explorers.jl")

src/components/explorers/UCB_explorer.jl

Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@ Flux.testmode!(p::UCBExplorer, mode = true) = p.is_training = !mode
 - `seed`, set the seed of inner RNG.
 - `is_training=true`, in training mode, time step and counter will not be updated.
 """
-UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed = nothing, is_training = true) =
-    UCBExplorer(c, fill(ϵ, na), 1, MersenneTwister(seed), is_training)
+UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, rng = Random.GLOBAL_RNG, is_training = true) =
+    UCBExplorer(c, fill(ϵ, na), 1, rng, is_training)
 
 @doc raw"""
     (ucb::UCBExplorer)(values::AbstractArray)
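After this change the explorer takes a random number generator directly instead of a seed. A minimal usage sketch, assuming 4 candidate actions and a made-up vector of estimated action values:

    using Random

    explorer = UCBExplorer(4; rng = MersenneTwister(123))   # reproducible exploration
    action = explorer([0.1, 0.5, 0.3, 0.2])                 # select an action from UCB scores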
