Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ BSON = "0.2"
CUDA = "1"
Distributions = "0.22, 0.23"
FillArrays = "0.8, 0.9"
Flux = "0.11"
Flux = "0.11.1"
GPUArrays = "5"
ImageTransformations = "0.8"
JLD = "0.10"
Expand Down
139 changes: 24 additions & 115 deletions src/components/agents/agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Generally speaking, it does nothing but update the trajectory and policy appropr
"""
Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent
policy::P
trajectory::T = DummyTrajectory()
trajectory::T = DUMMY_TRAJECTORY
role::R = RLBase.DEFAULT_PLAYER
is_training::Bool = true
end
Expand Down Expand Up @@ -84,144 +84,53 @@ end
agent.policy(env)

#####
# EpisodicCompactSARTSATrajectory
# default behavior
#####
function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PreEpisodeStage},
env,
)
empty!(agent.trajectory)
nothing
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PreActStage},
env,
)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PostActStage},
env,
)
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
::Training{PostEpisodeStage},
env,
)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end

#####
# Union{CircularCompactSARTSATrajectory, CircularCompactPSARTSATrajectory}
#####

function (
agent::Agent{
<:AbstractPolicy,
<:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory},
}
)(
::Training{PreEpisodeStage},
env,
)
if length(agent.trajectory) > 0
pop!(agent.trajectory, :state, :action)
function (agent::Agent)(::Training{PreEpisodeStage}, env)
if nframes(agent.trajectory[:full_state]) > 0
pop!(agent.trajectory, :full_state)
end
if nframes(agent.trajectory[:full_action]) > 0
pop!(agent.trajectory, :full_action)
end
if ActionStyle(env) === FULL_ACTION_SET && nframes(agent.trajectory[:full_legal_actions_mask]) > 0
pop!(agent.trajectory, :full_legal_actions_mask)
end
nothing
end

function (
agent::Agent{
<:AbstractPolicy,
<:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory},
}
)(
::Training{PreActStage},
env,
)
function (agent::Agent)(::Training{PreActStage}, env)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
if ActionStyle(env) === FULL_ACTION_SET
push!(agent.trajectory; legal_actions_mask=get_legal_actions_mask(env))
end
update!(agent.policy, agent.trajectory)
action
end

function (
agent::Agent{
<:AbstractPolicy,
<:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory},
}
)(
::Training{PostActStage},
env,
)
function (agent::Agent)(::Training{PostActStage}, env)
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (
agent::Agent{
<:AbstractPolicy,
<:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory},
}
)(
::Training{PostEpisodeStage},
env,
)
function (agent::Agent)(::Training{PostEpisodeStage}, env)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
if ActionStyle(env) === FULL_ACTION_SET
push!(agent.trajectory; legal_actions_mask=get_legal_actions_mask(env))
end
update!(agent.policy, agent.trajectory)
action
end

#####
# VectorialCompactSARTSATrajectory
# EpisodicCompactSARTSATrajectory
#####

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
function (agent::Agent{<:AbstractPolicy,<:EpisodicTrajectory})(
::Training{PreEpisodeStage},
env,
)
if length(agent.trajectory) > 0
pop!(agent.trajectory, :state, :action)
end
nothing
end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PreActStage},
env,
)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PostActStage},
env,
)
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
empty!(agent.trajectory)
nothing
end

function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})(
::Training{PostEpisodeStage},
env,
)
action = agent.policy(env)
push!(agent.trajectory; state = get_state(env), action = action)
update!(agent.policy, agent.trajectory)
action
end
end
8 changes: 4 additions & 4 deletions src/components/agents/dyna_agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ end

get_role(agent::DynaAgent) = agent.role

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(
::PreEpisodeStage,
env,
)
empty!(agent.trajectory)
nothing
end

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(
::PreActStage,
env,
)
Expand All @@ -53,15 +53,15 @@ function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
action
end

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(
::PostActStage,
env,
)
push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
nothing
end

function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})(
function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})(
::PostEpisodeStage,
env,
)
Expand Down
97 changes: 27 additions & 70 deletions src/components/trajectories/abstract_trajectory.jl
Original file line number Diff line number Diff line change
@@ -1,100 +1,57 @@
export AbstractTrajectory, get_trace, RTSA, SARTSA
export AbstractTrajectory

"""
AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1}
AbstractTrajectory

A trajectory is used to record some useful information
A trace is used to record some useful information
during the interactions between agents and environments.

# Parameters
- `names`::`NTuple{Symbol}`, indicate what fields to be recorded.
- `types`::`Tuple{DataType...}`, the datatypes of `names`.

The length of `names` and `types` must match.

Required Methods:

- [`get_trace`](@ref)
- `Base.haskey(t::AbstractTrajectory, s::Symbol)`
- `Base.getindex(t::AbstractTrajectory, s::Symbol)`
- `Base.keys(t::AbstractTrajectory)`
- `Base.push!(t::AbstractTrajectory, kv::Pair{Symbol})`
- `Base.pop!(t::AbstractTrajectory, s::Symbol)`
- `Base.empty!(t::AbstractTrajectory)`

Optional Methods:

- `Base.length`
- `Base.size`
- `Base.lastindex`
- `Base.isempty`
- `Base.empty!`
"""
abstract type AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1} end

# some typical trace names
"An alias of `(:reward, :terminal, :state, :action)`"
const RTSA = (:reward, :terminal, :state, :action)

"An alias of `(:state, :action, :reward, :terminal, :next_state, :next_action)`"
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)

"""
get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N}
"""
get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N} =
NamedTuple{s}(get_trace(t, x) for x in s)

"""
get_trace(t::AbstractTrajectory, s::Symbol...)
"""
get_trace(t::AbstractTrajectory, s::Symbol...) = get_trace(t, s)
- `isfull`

"""
get_trace(t::AbstractTrajectory{names}) where {names}
"""
get_trace(t::AbstractTrajectory{names}) where {names} =
NamedTuple{names}(get_trace(t, x) for x in names)
abstract type AbstractTrajectory end

Base.length(t::AbstractTrajectory) = maximum(length(x) for x in get_trace(t))
Base.size(t::AbstractTrajectory) = (length(t),)
Base.lastindex(t::AbstractTrajectory) = length(t)
Base.getindex(t::AbstractTrajectory{names,types}, i::Int) where {names,types} =
NamedTuple{names,types}(Tuple(x[i] for x in get_trace(t)))

Base.isempty(t::AbstractTrajectory) = all(isempty(t) for t in get_trace(t))

function Base.empty!(t::AbstractTrajectory)
for x in get_trace(t)
empty!(x)
end
end

"""
Base.push!(t::AbstractTrajectory; kwargs...)
"""
function Base.push!(t::AbstractTrajectory; kwargs...)
function Base.push!(t::AbstractTrajectory;kwargs...)
for kv in kwargs
push!(t, kv)
end
end

"""
Base.pop!(t::AbstractTrajectory{names}) where {names}
`pop!` out one element of each trace in `t`
"""
function Base.pop!(t::AbstractTrajectory{names}) where {names}
pop!(t, names...)
end

"""
Base.pop!(t::AbstractTrajectory, s::Symbol...)

`pop!` out one element of the traces specified in `s`
"""
function Base.pop!(t::AbstractTrajectory, s::Symbol...)
function Base.pop!(t::AbstractTrajectory, s::Tuple{Vararg{Symbol}})
NamedTuple{s}(pop!(t, x) for x in s)
end

Base.pop!(t::AbstractTrajectory) = pop!(t, keys(t))

function Base.empty!(t::AbstractTrajectory)
for s in keys(t)
empty!(t[s])
end
end

#####
# patch code
#####

# avoid showing the inner structure
function AbstractTrees.children(t::StructTree{<:AbstractTrajectory})
traces = get_trace(t.x)
Tuple(k => StructTree(v) for (k, v) in pairs(traces))
Tuple(k => StructTree(t.x[k]) for k in keys(t.x))
end

Base.summary(io::IO, t::T) where {T<:AbstractTrajectory} =
print(io, "$(length(t))-element $(T.name)")
@deprecate get_trace(t::AbstractTrajectory, s::Symbol) t[s]
Loading