2 changes: 1 addition & 1 deletion .travis.yml
@@ -19,4 +19,4 @@ after_success:

## uncomment the following lines to override the default test script
script:
- travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()'
- travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test(coverage=true)'
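
For local verification, the same coverage-enabled test run can be reproduced with the standard Pkg API. A minimal sketch, assuming the package's project directory is the working directory:

```julia
# Reproduce the CI test invocation locally, with coverage enabled.
using Pkg

Pkg.activate(".")            # activate the project in the current directory
Pkg.instantiate()            # install the dependencies recorded in the manifest
Pkg.test(coverage = true)    # run the test suite and emit *.cov files
```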
15 changes: 5 additions & 10 deletions Project.toml
@@ -1,13 +1,13 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
authors = ["Jun Tian <[email protected]>"]
version = "0.3.3"
version = "0.4.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -25,22 +25,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "1, 2"
BSON = "0.2"
CUDAapi = "3, 4"
CuArrays = "1.7, 2"
Distributions = "0.22, 0.23"
FillArrays = "0.8"
Flux = "0.10"
GPUArrays = "2, 3, 4.0"
Flux = "0.11"
ImageTransformations = "0.8"
JLD = "0.10"
MacroTools = "0.5"
ProgressMeter = "1.2"
ReinforcementLearningBase = "0.7"
ReinforcementLearningBase = "0.8"
Setfield = "0.6"
StatsBase = "0.32, 0.33"
Zygote = "0.4"
julia = "1.3"

[extras]
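
The dependency changes above (CUDAapi and CuArrays replaced by CUDA, Flux bumped to 0.11, RLBase to 0.8) mostly affect how downstream code detects and uses a GPU. A minimal, hypothetical sketch of the post-migration style; the model and sizes are placeholders, not package code:

```julia
using CUDA    # single package replacing the former CUDAapi/CuArrays pair
using Flux

# Build a small network and move it to the GPU only if one is actually usable.
model = Chain(Dense(4, 32, relu), Dense(32, 2))
device = CUDA.functional() ? gpu : cpu
model = device(model)

x = device(rand(Float32, 4, 16))   # a fake batch of 16 four-dimensional states
q = model(x)                       # forward pass on CPU or GPU transparently
```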
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore.jl
@@ -11,8 +11,8 @@ provides some standard and reusable components defined by [**RLBase**](https://g

export RLCore

include("extensions/extensions.jl")
include("utils/utils.jl")
include("extensions/extensions.jl")
include("components/components.jl")
include("core/core.jl")

12 changes: 7 additions & 5 deletions src/components/agents/abstract_agent.jl
@@ -16,8 +16,8 @@ export AbstractAgent,
Testing

"""
(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action
(agent::AbstractAgent)(stage::AbstractStage, obs)
(agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) -> action
(agent::AbstractAgent)(stage::AbstractStage, env)

Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action.
The main difference is that we divide an experiment into the following stages:
@@ -43,7 +43,7 @@ PRE_EXPERIMENT_STAGE | PRE_ACT_STAGE POST_ACT_STAGE
| | | | | |
v | +-----+ v +-------+ v +-----+ | v
--------------------->+ env +------>+ agent +------->+ env +---> ... ------->......
| ^ +-----+ obs +-------+ action +-----+ ^ |
| ^ +-----+ +-------+ action +-----+ ^ |
| | | |
| +--PRE_EPISODE_STAGE POST_EPISODE_STAGE----+ |
| |
@@ -66,10 +66,12 @@ const POST_EPISODE_STAGE = PostEpisodeStage()
const PRE_ACT_STAGE = PreActStage()
const POST_ACT_STAGE = PostActStage()

(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs)
function (agent::AbstractAgent)(stage::AbstractStage, obs) end
(agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env)
function (agent::AbstractAgent)(stage::AbstractStage, env) end

struct Training{T<:AbstractStage} end
Training(s::T) where {T<:AbstractStage} = Training{T}()
struct Testing{T<:AbstractStage} end
Testing(s::T) where {T<:AbstractStage} = Testing{T}()

Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),get(io, :max_depth, 10))
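
To make the stage protocol concrete, here is a minimal, hypothetical agent that ignores every stage except the pre-act stage. `MinimalAgent` and `RandomSelector` are illustrative names, not part of the package, and the sketch assumes the stage types are exported as the export list above suggests:

```julia
using ReinforcementLearningCore   # brings AbstractAgent and the stage types into scope

# A toy policy: pick uniformly from a fixed set of action indices.
struct RandomSelector
    actions::Vector{Int}
end
(p::RandomSelector)(env) = rand(p.actions)

# A toy agent that only reacts when asked for an action.
struct MinimalAgent{P} <: AbstractAgent
    policy::P
end
(agent::MinimalAgent)(::AbstractStage, env) = nothing          # ignore all stages...
(agent::MinimalAgent)(::PreActStage, env) = agent.policy(env)  # ...except pre-act

# `agent(env)` falls back to `agent(PRE_ACT_STAGE, env)` as documented above.
```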
77 changes: 42 additions & 35 deletions src/components/agents/agent.jl
@@ -16,17 +16,17 @@ Generally speaking, it does nothing but update the trajectory and policy appropr

- `policy`::[`AbstractPolicy`](@ref): the policy to use
- `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment
- `role=:DEFAULT_PLAYER`: used to distinguish different agents
- `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents
"""
Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
policy::P
trajectory::T
role::R = :DEFAULT_PLAYER
trajectory::T = DummyTrajectory()
role::R = RLBase.DEFAULT_PLAYER
is_training::Bool = true
end

# avoid polluting trajectory
(agent::Agent)(obs) = agent.policy(obs)
(agent::Agent)(env) = agent.policy(env)

Flux.functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy
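
Because `trajectory` now defaults to `DummyTrajectory()` and `role` to `RLBase.DEFAULT_PLAYER`, an `Agent` can wrap a bare policy with no replay bookkeeping at all. A hypothetical construction sketch; `RandomPolicy` and `env` are placeholders for whatever policy/environment pair is actually used:

```julia
# Hypothetical usage; relies only on the keyword constructor generated
# by Base.@kwdef for the Agent struct above.
agent = Agent(policy = RandomPolicy(env))

agent(env)                   # delegates straight to agent.policy(env)
agent(POST_ACT_STAGE, env)   # with a DummyTrajectory this is a no-op (see below)
```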

@@ -69,47 +69,54 @@ function Flux.testmode!(agent::Agent, mode = true)
testmode!(agent.policy, mode)
end

(agent::Agent)(stage::AbstractStage, obs) =
agent.is_training ? agent(Training(stage), obs) : agent(Testing(stage), obs)
(agent::Agent)(stage::AbstractStage, env) =
agent.is_training ? agent(Training(stage), env) : agent(Testing(stage), env)

(agent::Agent)(::Testing, obs) = nothing
(agent::Agent)(::Testing{PreActStage}, obs) = agent.policy(obs)
(agent::Agent)(::Testing, env) = nothing
(agent::Agent)(::Testing{PreActStage}, env) = agent.policy(env)
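
The `is_training` flag above decides whether a stage call is routed through the `Training` or the `Testing` wrapper, and in testing mode only the pre-act stage touches the policy. A hypothetical helper built on that behaviour:

```julia
# Hypothetical helper: query an action without recording anything, by
# temporarily routing stage calls through the Testing wrappers.
function act_for_evaluation(agent::Agent, env)
    was_training = agent.is_training
    agent.is_training = false
    action = agent(PRE_ACT_STAGE, env)   # Testing{PreActStage} -> agent.policy(env)
    agent(POST_ACT_STAGE, env)           # Testing -> nothing; trajectory untouched
    agent.is_training = was_training
    return action
end
```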

#####
# DummyTrajectory
#####

(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::AbstractStage, env) = nothing
(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::PreActStage, env) = agent.policy(env)

#####
# EpisodicCompactSARTSATrajectory
#####
function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PreEpisodeStage},
obs,
env,
)
empty!(agent.trajectory)
nothing
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PreActStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PostActStage},
obs,
env,
)
push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PostEpisodeStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end
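
Read together, these four methods define the bookkeeping order for an episodic SARTSA trajectory: state/action are pushed before acting, reward/terminal after acting, and one final bootstrap state/action pair at the end of the episode. A hypothetical trace of a two-step episode, written as comments with the same keyword names:

```julia
# Hypothetical trace of one two-step episode (s = state, a = action, r = reward):
#
#   PRE_EPISODE_STAGE  -> empty!(trajectory)
#   PRE_ACT_STAGE      -> push!(trajectory; state = s1, action = a1)
#   POST_ACT_STAGE     -> push!(trajectory; reward = r1, terminal = false)
#   PRE_ACT_STAGE      -> push!(trajectory; state = s2, action = a2)
#   POST_ACT_STAGE     -> push!(trajectory; reward = r2, terminal = true)
#   POST_EPISODE_STAGE -> push!(trajectory; state = s3, action = a3)   # bootstrap pair
```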
@@ -125,7 +132,7 @@ function (
}
)(
::Training{PreEpisodeStage},
obs,
env,
)
if length(agent.trajectory) > 0
pop!(agent.trajectory, :state, :action)
@@ -140,10 +147,10 @@ function (
}
)(
::Training{PreActStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end
@@ -155,9 +162,9 @@ function (
}
)(
::Training{PostActStage},
obs,
env,
)
push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

@@ -168,10 +175,10 @@ function (
}
)(
::Training{PostEpisodeStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end
@@ -182,7 +189,7 @@ end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PreEpisodeStage},
obs,
env,
)
if length(agent.trajectory) > 0
pop!(agent.trajectory, :state, :action)
@@ -192,28 +199,28 @@

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PreActStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PostActStage},
obs,
env,
)
push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PostEpisodeStage},
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end
18 changes: 9 additions & 9 deletions src/components/agents/dyna_agent.jl
@@ -35,18 +35,18 @@ get_role(agent::DynaAgent) = agent.role

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::PreEpisodeStage,
obs,
env,
)
empty!(agent.trajectory)
nothing
end

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::PreActStage,
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.model, agent.trajectory, agent.policy) # model learning
update!(agent.policy, agent.trajectory) # direct learning
update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning
@@ -55,18 +55,18 @@

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::PostActStage,
obs,
env,
)
push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs))
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::PostEpisodeStage,
obs,
env,
)
action = agent.policy(obs)
push!(agent.trajectory; state = get_state(obs), action = action)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.model, agent.trajectory, agent.policy) # model learning
update!(agent.policy, agent.trajectory) # direct learning
update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning
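
For orientation, the `DynaAgent` handlers above interleave three `update!` calls at every acting step: model learning from the trajectory, direct RL from the trajectory, and planning from the learned model using `plan_step`. Condensed into one hypothetical helper, which simply names the sequence already written out above:

```julia
# Hypothetical condensation of the update sequence used in the handlers above.
function dyna_update!(agent::DynaAgent)
    update!(agent.model, agent.trajectory, agent.policy)                   # model learning
    update!(agent.policy, agent.trajectory)                                # direct learning
    update!(agent.policy, agent.model, agent.trajectory, agent.plan_step)  # planning
    return nothing
end
```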
2 changes: 1 addition & 1 deletion src/components/approximators/abstract_approximator.jl
@@ -2,7 +2,7 @@ export AbstractApproximator,
ApproximatorStyle, Q_APPROXIMATOR, QApproximator, V_APPROXIMATOR, VApproximator

"""
(app::AbstractApproximator)(obs)
(app::AbstractApproximator)(env)

An approximator is a functional object for value estimation.
It serves as a black box that provides an abstraction over different
6 changes: 3 additions & 3 deletions src/components/approximators/neural_network_approximator.jl
@@ -10,11 +10,11 @@ Use a DNN model for value estimation.
# Keyword arguments

- `model`, a Flux based DNN model.
- `optimizer=Descent()`
- `optimizer=nothing`
"""
Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator
model::M
optimizer::O = Descent()
optimizer::O = nothing
end

(app::NeuralNetworkApproximator)(x) = app.model(x)
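
With `optimizer` defaulting to `nothing`, an approximator used purely for inference (for example a frozen target network) no longer needs to carry an optimiser; one is passed only when the approximator is trained. A hedged construction sketch with arbitrary layer sizes:

```julia
using Flux
using ReinforcementLearningCore

# Inference-only approximator: the optimizer field stays `nothing`.
target_app = NeuralNetworkApproximator(model = Chain(Dense(4, 32, relu), Dense(32, 2)))

# Trainable approximator: supply an explicit Flux optimiser.
app = NeuralNetworkApproximator(
    model = Chain(Dense(4, 32, relu), Dense(32, 2)),
    optimizer = ADAM(1e-3),
)

q = app(rand(Float32, 4))   # forward pass, i.e. app.model(x) as defined above
```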
@@ -42,7 +42,7 @@ Flux.testmode!(app::NeuralNetworkApproximator, mode = true) = testmode!(app.mode

The `actor` part must return logits (*Do not use softmax in the last layer!*), and the `critic` part must return a state value.
"""
Base.@kwdef struct ActorCritic{A,C,O}
Base.@kwdef struct ActorCritic{A,C,O} <: AbstractApproximator
actor::A
critic::C
optimizer::O = ADAM()
2 changes: 1 addition & 1 deletion src/components/components.jl
@@ -1,4 +1,4 @@
include("preprocessors.jl")
include("processors.jl")
include("trajectories/trajectories.jl")
include("approximators/approximators.jl")
include("explorers/explorers.jl")
4 changes: 2 additions & 2 deletions src/components/explorers/UCB_explorer.jl
@@ -23,8 +23,8 @@ Flux.testmode!(p::UCBExplorer, mode = true) = p.is_training = !mode
- `seed`, set the seed of inner RNG.
- `is_training=true`, in training mode, time step and counter will not be updated.
"""
UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed = nothing, is_training = true) =
UCBExplorer(c, fill(ϵ, na), 1, MersenneTwister(seed), is_training)
UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, rng = Random.GLOBAL_RNG, is_training = true) =
UCBExplorer(c, fill(ϵ, na), 1, rng, is_training)

@doc raw"""
(ucb::UCBExplorer)(values::AbstractArray)
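
The constructor now takes an explicit `rng` instead of a `seed`, in line with the RNG-passing style adopted elsewhere in this release. A hypothetical usage sketch:

```julia
using Random
using ReinforcementLearningCore

# Reproducible exploration: pass a seeded RNG via the new `rng` keyword
# (the old `seed` keyword was removed).
explorer = UCBExplorer(4; c = 2.0, rng = MersenneTwister(123))

# As documented above, the explorer is called on a vector of estimated
# action values and (per the explorer interface) returns the chosen action index.
action = explorer([0.0, 0.5, 0.2, 0.1])
```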