From 0b9d25efa3c7a7b3593649def2f4c323bcea20b2 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 11 Aug 2020 16:10:49 +0800 Subject: [PATCH 1/9] simplify traces --- src/components/traces/abstract_trace.jl | 46 +++ src/components/traces/trace.jl | 291 ++++++++++++++++++ src/components/traces/traces.jl | 2 + .../trajectories/abstract_trajectory.jl | 100 ------ .../circular_compact_PSARTSA_buffer.jl | 86 ------ .../circular_compact_SARTSA_buffer.jl | 55 ---- .../trajectories/circular_trajectory.jl | 68 ---- src/components/trajectories/common.jl | 78 ----- .../trajectories/dummy_trajectory.jl | 5 - .../episodic_compact_SARTSA_buffer.jl | 30 -- src/components/trajectories/trajectories.jl | 10 - src/components/trajectories/trajectory.jl | 23 -- .../vectorial_compact_SARTSA_buffer.jl | 66 ---- .../trajectories/vectorial_trajectory.jl | 47 --- src/utils/circular_array_buffer.jl | 2 + src/utils/sum_tree.jl | 2 +- test/components/traces.jl | 129 ++++++++ test/components/trajectories.jl | 241 --------------- 18 files changed, 471 insertions(+), 810 deletions(-) create mode 100644 src/components/traces/abstract_trace.jl create mode 100644 src/components/traces/trace.jl create mode 100644 src/components/traces/traces.jl delete mode 100644 src/components/trajectories/abstract_trajectory.jl delete mode 100644 src/components/trajectories/circular_compact_PSARTSA_buffer.jl delete mode 100644 src/components/trajectories/circular_compact_SARTSA_buffer.jl delete mode 100644 src/components/trajectories/circular_trajectory.jl delete mode 100644 src/components/trajectories/common.jl delete mode 100644 src/components/trajectories/dummy_trajectory.jl delete mode 100644 src/components/trajectories/episodic_compact_SARTSA_buffer.jl delete mode 100644 src/components/trajectories/trajectories.jl delete mode 100644 src/components/trajectories/trajectory.jl delete mode 100644 src/components/trajectories/vectorial_compact_SARTSA_buffer.jl delete mode 100644 src/components/trajectories/vectorial_trajectory.jl create mode 100644 test/components/traces.jl delete mode 100644 test/components/trajectories.jl diff --git a/src/components/traces/abstract_trace.jl b/src/components/traces/abstract_trace.jl new file mode 100644 index 0000000..2d05faa --- /dev/null +++ b/src/components/traces/abstract_trace.jl @@ -0,0 +1,46 @@ +export AbstractTrace + +""" + AbstractTrace + +A trace is used to record some useful information +during the interactions between agents and environments. + +Required Methods: + +- `Base.haskey(t::AbstractTrace, s::Symbol)` +- `Base.getproperty(t::AbstractTrace, s::Symbol)` +- `Base.keys(t::AbstractTrace)` +- `Base.push!(t::AbstractTrace, kv::Pair{Symbol})` +- `Base.pop!(t::AbstractTrace, s::Symbol)` + +Optional Methods: + +- `isfull` +- `empty!` + +""" +abstract type AbstractTrace end + +function Base.push!(t::AbstractTrace;kwargs...) + for kv in kwargs + push!(t, kv) + end +end + +""" + Base.pop!(t::AbstractTrace, s::Symbol...) + +`pop!` out one element of the traces specified in `s` +""" +function Base.pop!(t::AbstractTrace, s::Tuple{Vararg{Symbol}}) + NamedTuple{s}(pop!(t, x) for x in s) +end + +Base.pop!(t::AbstractTrace) = pop!(t, keys(t)) + +function Base.empty!(t::AbstractTrace) + for s in keys(t) + empty!(t[s]) + end +end \ No newline at end of file diff --git a/src/components/traces/trace.jl b/src/components/traces/trace.jl new file mode 100644 index 0000000..c651697 --- /dev/null +++ b/src/components/traces/trace.jl @@ -0,0 +1,291 @@ +using MacroTools:@forward + +##### +# Trace +##### + +""" + Trace(;[trace_name=trace_container]...) + +Simply a wrapper of `NamedTuple`. +Define our own type here to avoid type piracy with `NamedTuple` +""" +struct Trace{T} <: AbstractTrace + traces::T +end + +Trace(;kwargs...) = Trace(kwargs.data) + +@forward Trace.traces Base.keys, Base.haskey, Base.getindex + +Base.push!(t::Trace, kv::Pair{Symbol}) = push!(t[first(kv)], last(kv)) +Base.pop!(t::Trace, s::Symbol) = pop!(t[s]) + +##### +# SharedTrace +##### + +struct SharedTraceMeta + start_shift::Int + end_shift::Int +end + +""" + SharedTrace(trace;[trace_name=start_shift:end_shift]...) + +Create multiple traces sharing the same underlying container. +""" +struct SharedTrace{X,M} <: AbstractTrace + x::X + meta::M +end + +function SharedTrace(x, s::Symbol) + SharedTrace( + x, + (; + s=>SharedTraceMeta(1, -1), + Symbol(:next_, s)=>SharedTraceMeta(2, 0), + Symbol(:full_, s) => SharedTraceMeta(1,0) + ) + ) +end + +@forward SharedTrace.meta Base.keys, Base.haskey + +function Base.getindex(t::SharedTrace, s::Symbol) + m = t.meta[s] + select_last_dim(t.x, m.start_shift:(nframes(t.x)+m.end_shift)) +end + +Base.push!(t::SharedTrace, kv::Pair{Symbol}) = push!(t.x, last(kv)) +Base.empty!(t::SharedTrace) = empty!(t.x) +Base.pop!(t::SharedTrace, s::Symbol) = pop!(t.x) + +function Base.pop!(t::SharedTrace) + s = first(keys(t)) + (;s => pop!(t.x)) +end + +##### +# EpisodicTrace +##### + +""" +Assuming that the `flag_trace` is in `traces` and it's an `AbstractVector{Bool}`, +meaning whether an environment reaches terminal or not. The last element in +`flag_trace` will be used to determine whether the whole trace is full or not. +""" +struct EpisodicTrace{T, flag_trace} <: AbstractTrace + traces::T +end + +EpisodicTrace(traces::T, flag_trace=:terminal) where T = EpisodicTrace{T, flag_trace}(traces) + +@forward EpisodicTrace.traces Base.keys, Base.haskey, Base.getindex, Base.push!, Base.pop!, Base.empty! + +function isfull(t::EpisodicTrace{<:Any, F}) where F + x = t.traces[F] + (nframes(x) > 0) && select_last_frame(x) +end + +##### +# CombinedTrace +##### + +struct CombinedTrace{T1, T2} <: AbstractTrace + t1::T1 + t2::T2 +end + +Base.haskey(t::CombinedTrace, s::Symbol) = haskey(t.t1, s) || haskey(t.t2, s) +Base.getindex(t::CombinedTrace, s::Symbol) = if haskey(t.t1, s) + getindex(t.t1, s) +elseif haskey(t.t2, s) + getindex(t.t2, s) +else + throw(ArgumentError("unknown key: $s")) +end + +Base.keys(t::CombinedTrace) = (keys(t.t1)..., keys(t.t2)...) + +Base.push!(t::CombinedTrace, kv::Pair{Symbol}) = if haskey(t.t1, first(kv)) + push!(t.t1, kv) +elseif haskey(t.t2, first(kv)) + push!(t.t2, kv) +else + throw(ArgumentError("unknown kv: $kv")) +end + +Base.pop!(t::CombinedTrace, s::Symbol) = if haskey(t.t1, s) + pop!(t.t1, s) +elseif haskey(t.t2, s) + pop!(t.t2, s) +else + throw(ArgumentError("unknown key: $s")) +end + +Base.pop!(t::CombinedTrace) = merge(pop!(t.t1), pop!(t.t2)) + +function Base.empty!(t::CombinedTrace) + empty!(t.t1) + empty!(t.t2) +end + +##### +# CircularCompactSATrace +##### + +const CircularCompactSATrace = CombinedTrace{ + <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:state, :next_state, :full_state)}}, + <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:action, :next_action, :full_action)}}, +} + +function CircularCompactSATrace(; + capacity, + state_type = Int, + state_size = (), + action_type = Int, + action_size = (), +) + CombinedTrace( + SharedTrace( + CircularArrayBuffer{state_type}(state_size..., capacity+1), + :state), + SharedTrace( + CircularArrayBuffer{action_type}(action_size..., capacity+1), + :action + ), + ) +end + +##### +# CircularCompactSALTrace +##### + +const CircularCompactSALTrace = CombinedTrace{ + <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:legal_actions_mask, :next_legal_actions_mask, :full_legal_actions_mask)}}, + <:CircularCompactSATrace +} + +function CircularCompactSALTrace(; + capacity, + legal_actions_mask_size, + legal_actions_mask_type=Bool, + kw... +) + CombinedTrace( + SharedTrace( + CircularArrayBuffer{legal_actions_mask_type}(legal_actions_mask_size..., capacity+1), + :legal_actions_mask + ), + CircularCompactSATrace(;capacity=capacity, kw...) + ) +end +##### +# CircularCompactSARTSATrace +##### + +const CircularCompactSARTSATrace = CombinedTrace{ + <:Trace{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, + <:CircularCompactSATrace +} + +function CircularCompactSARTSATrace(; + capacity, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrace( + Trace( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + ), + CircularCompactSATrace(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactSALRTSALTrace +##### + +const CircularCompactSALRTSALTrace = CombinedTrace{ + <:Trace{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, + <:CircularCompactSALTrace +} + +function CircularCompactSALRTSALTrace(; + capacity, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... + ) + CombinedTrace( + Trace( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + ), + CircularCompactSALTrace(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactPSARTSATrace +##### + +const CircularCompactPSARTSATrace = CombinedTrace{ + <:Trace{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, + <:CircularCompactSATrace +} + +function CircularCompactPSARTSATrace(; + capacity, + priority_type=Float32, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrace( + Trace( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + priority=SumTree(priority_type, capacity) + ), + CircularCompactSATrace(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactPSALRTSALTrace +##### + +const CircularCompactPSALRTSALTrace = CombinedTrace{ + <:Trace{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, + <:CircularCompactSALTrace +} + +function CircularCompactPSALRTSALTrace(; + capacity, + priority_type=Float32, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrace( + Trace( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + priority=SumTree(priority_type, capacity) + ), + CircularCompactSALTrace(;capacity=capacity, kw...), + ) +end \ No newline at end of file diff --git a/src/components/traces/traces.jl b/src/components/traces/traces.jl new file mode 100644 index 0000000..1bc7e12 --- /dev/null +++ b/src/components/traces/traces.jl @@ -0,0 +1,2 @@ +include("abstract_trace.jl") +include("trace.jl") \ No newline at end of file diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl deleted file mode 100644 index 0dc6d08..0000000 --- a/src/components/trajectories/abstract_trajectory.jl +++ /dev/null @@ -1,100 +0,0 @@ -export AbstractTrajectory, get_trace, RTSA, SARTSA - -""" - AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1} - -A trajectory is used to record some useful information -during the interactions between agents and environments. - -# Parameters -- `names`::`NTuple{Symbol}`, indicate what fields to be recorded. -- `types`::`Tuple{DataType...}`, the datatypes of `names`. - -The length of `names` and `types` must match. - -Required Methods: - -- [`get_trace`](@ref) -- `Base.push!(t::AbstractTrajectory, kv::Pair{Symbol})` -- `Base.pop!(t::AbstractTrajectory, s::Symbol)` - -Optional Methods: - -- `Base.length` -- `Base.size` -- `Base.lastindex` -- `Base.isempty` -- `Base.empty!` -""" -abstract type AbstractTrajectory{names,types} <: AbstractArray{NamedTuple{names,types},1} end - -# some typical trace names -"An alias of `(:reward, :terminal, :state, :action)`" -const RTSA = (:reward, :terminal, :state, :action) - -"An alias of `(:state, :action, :reward, :terminal, :next_state, :next_action)`" -const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) - -""" - get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N} -""" -get_trace(t::AbstractTrajectory, s::NTuple{N,Symbol}) where {N} = - NamedTuple{s}(get_trace(t, x) for x in s) - -""" - get_trace(t::AbstractTrajectory, s::Symbol...) -""" -get_trace(t::AbstractTrajectory, s::Symbol...) = get_trace(t, s) - -""" - get_trace(t::AbstractTrajectory{names}) where {names} -""" -get_trace(t::AbstractTrajectory{names}) where {names} = - NamedTuple{names}(get_trace(t, x) for x in names) - -Base.length(t::AbstractTrajectory) = maximum(length(x) for x in get_trace(t)) -Base.size(t::AbstractTrajectory) = (length(t),) -Base.lastindex(t::AbstractTrajectory) = length(t) -Base.getindex(t::AbstractTrajectory{names,types}, i::Int) where {names,types} = - NamedTuple{names,types}(Tuple(x[i] for x in get_trace(t))) - -Base.isempty(t::AbstractTrajectory) = all(isempty(t) for t in get_trace(t)) - -function Base.empty!(t::AbstractTrajectory) - for x in get_trace(t) - empty!(x) - end -end - -""" - Base.push!(t::AbstractTrajectory; kwargs...) -""" -function Base.push!(t::AbstractTrajectory; kwargs...) - for kv in kwargs - push!(t, kv) - end -end - -""" - Base.pop!(t::AbstractTrajectory{names}) where {names} -`pop!` out one element of each trace in `t` -""" -function Base.pop!(t::AbstractTrajectory{names}) where {names} - pop!(t, names...) -end - -""" - Base.pop!(t::AbstractTrajectory, s::Symbol...) -`pop!` out one element of the traces specified in `s` -""" -function Base.pop!(t::AbstractTrajectory, s::Symbol...) - NamedTuple{s}(pop!(t, x) for x in s) -end - -function AbstractTrees.children(t::StructTree{<:AbstractTrajectory}) - traces = get_trace(t.x) - Tuple(k => StructTree(v) for (k, v) in pairs(traces)) -end - -Base.summary(io::IO, t::T) where {T<:AbstractTrajectory} = - print(io, "$(length(t))-element $(T.name)") diff --git a/src/components/trajectories/circular_compact_PSARTSA_buffer.jl b/src/components/trajectories/circular_compact_PSARTSA_buffer.jl deleted file mode 100644 index 861a09a..0000000 --- a/src/components/trajectories/circular_compact_PSARTSA_buffer.jl +++ /dev/null @@ -1,86 +0,0 @@ -export CircularCompactPSARTSATrajectory - -using MacroTools: @forward - -struct CircularCompactPSARTSATrajectory{T<:CircularCompactSARTSATrajectory,P,names,types} <: - AbstractTrajectory{names,types} - trajectory::T - priority::P -end - -""" - CircularCompactPSARTSATrajectory(;kwargs) - -Similar to [`CircularCompactSARTSATrajectory`](@ref), except that another trace named `priority` is added. - -# Key word arguments - -- `capacity::Int`, the maximum length of each trace. -- `state_type = Int` -- `state_size = ()` -- `action_type = Int` -- `action_size = ()` -- `reward_type = Float32` -- `reward_size = ()` -- `terminal_type = Bool` -- `terminal_size = ()` -- `priority_type = Float32` -""" -function CircularCompactPSARTSATrajectory(; priority_type = Float32, kw...) - t = CircularCompactSARTSATrajectory(; kw...) - p = SumTree(priority_type, kw.data.capacity) - names = typeof(t).parameters[1] - types = typeof(t).parameters[2] - CircularCompactPSARTSATrajectory{ - typeof(t), - typeof(p), - (names..., :priority), - Tuple{types.parameters...,eltype(p)}, - }( - t, - p, - ) -end - -@forward CircularCompactPSARTSATrajectory.trajectory Base.length, Base.isempty - -get_trace(t::CircularCompactPSARTSATrajectory, s::Symbol) = - s == :priority ? t.priority : get_trace(t.trajectory, s) - -function Base.getindex(b::CircularCompactPSARTSATrajectory, i::Int) - ( - state = select_last_dim(b.trajectory[:state], i), - action = select_last_dim(b.trajectory[:action], i), - reward = select_last_dim(b.trajectory[:reward], i), - terminal = select_last_dim(b.trajectory[:terminal], i), - next_state = select_last_dim(b.trajectory[:state], i + 1), - next_action = select_last_dim(b.trajectory[:action], i + 1), - priority = select_last_dim(b.priority, i), - ) -end - -function Base.empty!(b::CircularCompactPSARTSATrajectory) - empty!(b.priority) - empty!(b.trajectory) -end - -function Base.push!(b::CircularCompactPSARTSATrajectory, kv::Pair{Symbol}) - k, v = kv - if k == :priority - push!(b.priority, v) - else - push!(b.trajectory, kv) - end -end - -function Base.pop!(t::CircularCompactPSARTSATrajectory, s::Symbol) - if s == :priority - pop!(t.priority) - else - pop!(t.trajectory, s) - end -end - -function Base.pop!(t::CircularCompactPSARTSATrajectory) - (pop!(t.trajectory)..., priority = pop!(t.priority)) -end diff --git a/src/components/trajectories/circular_compact_SARTSA_buffer.jl b/src/components/trajectories/circular_compact_SARTSA_buffer.jl deleted file mode 100644 index 42a1eea..0000000 --- a/src/components/trajectories/circular_compact_SARTSA_buffer.jl +++ /dev/null @@ -1,55 +0,0 @@ -export CircularCompactSARTSATrajectory - -const CircularCompactSARTSATrajectory = Trajectory{ - SARTSA, - T1, - NamedTuple{RTSA,T2}, -} where {T1,T2<:Tuple{Vararg{<:CircularArrayBuffer}}} - -""" - CircularCompactSARTSATrajectory(;kwargs...) - -Similar to [`VectorialCompactSARTSATrajectory`](@ref), -instead of using `Vector`s as containers, [`CircularArrayBuffer`](@ref)s are used here. - -# Key word arguments - -- `capacity`::Int, the maximum length of each trace. -- `state_type` = Int -- `state_size` = () -- `action_type` = Int -- `action_size` = () -- `reward_type` = Float32 -- `reward_size` = () -- `terminal_type` = Bool -- `terminal_size` = () -""" -function CircularCompactSARTSATrajectory(; - capacity, - state_type = Int, - state_size = (), - action_type = Int, - action_size = (), - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), -) - capacity > 0 || throw(ArgumentError("capacity must > 0")) - reward = CircularArrayBuffer{reward_type}(reward_size..., capacity) - terminal = CircularArrayBuffer{terminal_type}(terminal_size..., capacity) - state = CircularArrayBuffer{state_type}(state_size..., capacity + 1) - action = CircularArrayBuffer{action_type}(action_size..., capacity + 1) - ts = NamedTuple{RTSA}((reward, terminal, state, action)) - - CircularCompactSARTSATrajectory{ - Tuple{frame_type(state),frame_type(action),map(frame_type, ts)...}, - typeof(ts).parameters[2], - }( - ts, - ) -end - -isfull(t::CircularCompactSARTSATrajectory) = isfull(t[:action]) - -Base.length(t::CircularCompactSARTSATrajectory) = nframes(t[:terminal]) diff --git a/src/components/trajectories/circular_trajectory.jl b/src/components/trajectories/circular_trajectory.jl deleted file mode 100644 index 37eb5df..0000000 --- a/src/components/trajectories/circular_trajectory.jl +++ /dev/null @@ -1,68 +0,0 @@ -export CircularTrajectory - -const CircularTrajectory = Trajectory{ - names, - types, - NamedTuple{names,trace_types}, -} where {names,types,trace_types<:Tuple{Vararg{<:CircularArrayBuffer}}} - - -""" - CircularTrajectory(; capacity, trace_name=eltype=>size...) - -Similar to [`VectorialTrajectory`](@ref), but we use the -[`CircularArrayBuffer`](@ref) to store the traces. The `capacity` -here is used to specify the maximum length of the trajectory. - -# Example - -```julia-repl -julia> t = CircularTrajectory(capacity=10, state=Float64=>(3,3), reward=Int=>tuple()) -0-element Trajectory{(:state, :reward),Tuple{Float64,Int64},NamedTuple{(:state, :reward),Tuple{CircularArrayBuffer{Float64,3},CircularArrayBuffer{Int64,1}}}} - -julia> push!(t,state=rand(3,3), reward=1) - -julia> push!(t,state=rand(3,3), reward=2) - -julia> get_trace(t, :reward) -2-element CircularArrayBuffer{Int64,1}: - 1 - 2 - -julia> get_trace(t, :state) -3×3×2 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 0.699906 0.382396 0.927411 - 0.269807 0.0581324 0.239609 - 0.222304 0.514408 0.318905 - -[:, :, 2] = - 0.956228 0.992505 0.109743 - 0.763497 0.381387 0.540566 - 0.223081 0.834308 0.634759 - -julia> pop!(t) - -julia> get_trace(t, :state) -3×3×1 CircularArrayBuffer{Float64,3}: -[:, :, 1] = - 0.699906 0.382396 0.927411 - 0.269807 0.0581324 0.239609 - 0.222304 0.514408 0.318905 - -``` -""" -function CircularTrajectory(; capacity, kwargs...) - names = keys(kwargs.data) - types_and_sizes = values(kwargs.data) - types = (t for (t, s) in types_and_sizes) - sizes = (s for (t, s) in types_and_sizes) - trajectories = merge( - NamedTuple(), - (name, CircularArrayBuffer{t}(s..., capacity)) - for (name, (t, s)) in zip(names, types_and_sizes) - ) - CircularTrajectory{names,Tuple{types...},typeof(values(trajectories))}(trajectories) -end - -Base.length(t::CircularTrajectory) = maximum(nframes(x) for x in get_trace(t)) diff --git a/src/components/trajectories/common.jl b/src/components/trajectories/common.jl deleted file mode 100644 index babd0aa..0000000 --- a/src/components/trajectories/common.jl +++ /dev/null @@ -1,78 +0,0 @@ -const CompactSARTSATrajectory = - Union{CircularCompactSARTSATrajectory,VectorialCompactSARTSATrajectory} - -function get_trace(b::CompactSARTSATrajectory, s::Symbol) - if s == :state || s == :action - select_last_dim(b[s], 1:(nframes(b[s]) > 1 ? nframes(b[s]) - 1 : nframes(b[s]))) - elseif s == :reward || s == :terminal - b[s] - elseif s == :next_state - select_last_dim(b[:state], 2:nframes(b[:state])) - elseif s == :next_action - select_last_dim(b[:action], 2:nframes(b[:action])) - else - throw(ArgumentError("unknown trace name: $s")) - end -end - -Base.length(b::CompactSARTSATrajectory) = length(b[:terminal]) -Base.isempty(b::CompactSARTSATrajectory) = all(isempty(b[s]) for s in RTSA) - -function Base.getindex(b::CompactSARTSATrajectory, i::Int) - ( - state = select_last_dim(b[:state], i), - action = select_last_dim(b[:action], i), - reward = select_last_dim(b[:reward], i), - terminal = select_last_dim(b[:terminal], i), - next_state = select_last_dim(b[:state], i + 1), - next_action = select_last_dim(b[:action], i + 1), - ) -end - -function Base.empty!(b::CompactSARTSATrajectory) - for s in RTSA - empty!(b[s]) - end - b -end - -function Base.push!(b::CompactSARTSATrajectory, kv::Pair{Symbol}) - k, v = kv - if k == :state || k == :next_state - push!(b[:state], v) - elseif k == :action || k == :next_action - push!(b[:action], v) - elseif k == :reward || k == :terminal - push!(b[k], v) - else - throw(ArgumentError("unknown trace name: $k")) - end - b -end - -function Base.pop!(t::CompactSARTSATrajectory, s::Symbol) - if s == :state || s == :next_state - pop!(t[:state]) - elseif s == :action || s == :next_action - pop!(t[:action]) - elseif s == :reward || s == :terminal - pop!(t[s]) - else - throw(ArgumentError("unknown trace name: $s")) - end - t -end - -function Base.pop!(t::CompactSARTSATrajectory) - if length(t) <= 0 - throw(ArgumentError("can not pop! from an empty trajectory")) - else - NamedTuple{RTSA}(( - pop!(t, :reward), - pop!(t, :terminal), - pop!(t, :state), - pop!(t, :action), - )) - end - t -end diff --git a/src/components/trajectories/dummy_trajectory.jl b/src/components/trajectories/dummy_trajectory.jl deleted file mode 100644 index d67f770..0000000 --- a/src/components/trajectories/dummy_trajectory.jl +++ /dev/null @@ -1,5 +0,0 @@ -export DummyTrajectory - -struct DummyTrajectory <: AbstractTrajectory{(),Tuple{}} end - -Base.length(t::DummyTrajectory) = 0 diff --git a/src/components/trajectories/episodic_compact_SARTSA_buffer.jl b/src/components/trajectories/episodic_compact_SARTSA_buffer.jl deleted file mode 100644 index b9d0771..0000000 --- a/src/components/trajectories/episodic_compact_SARTSA_buffer.jl +++ /dev/null @@ -1,30 +0,0 @@ -export EpisodicCompactSARTSATrajectory - -using MacroTools: @forward - -""" - EpisodicCompactSARTSATrajectory(; state_type = Int, action_type = Int, reward_type = Float32, terminal_type = Bool) - -Exactly the same with [`VectorialCompactSARTSATrajectory`](@ref). It only exists for multiple dispatch purpose. - -!!! warning - The `EpisodicCompactSARTSATrajectory` will not be automatically emptified when reaching the end of an episode. -""" -struct EpisodicCompactSARTSATrajectory{types,trace_types} <: - AbstractTrajectory{SARTSA,types} - trajectories::VectorialCompactSARTSATrajectory{types,trace_types} -end - -EpisodicCompactSARTSATrajectory(; kwargs...) = - EpisodicCompactSARTSATrajectory(VectorialCompactSARTSATrajectory(; kwargs...)) - -@forward EpisodicCompactSARTSATrajectory.trajectories Base.length, -Base.isempty, -Base.empty!, -Base.push!, -Base.pop! - -# avoid method ambiguous -get_trace(t::EpisodicCompactSARTSATrajectory, s::Symbol) = get_trace(t.trajectories, s) -Base.getindex(t::EpisodicCompactSARTSATrajectory, i::Int) = getindex(t.trajectories, i) -Base.pop!(t::EpisodicCompactSARTSATrajectory, s::Symbol...) = pop!(t.trajectories, s...) diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl deleted file mode 100644 index a64e21d..0000000 --- a/src/components/trajectories/trajectories.jl +++ /dev/null @@ -1,10 +0,0 @@ -include("abstract_trajectory.jl") -include("dummy_trajectory.jl") -include("trajectory.jl") -include("vectorial_trajectory.jl") -include("circular_trajectory.jl") -include("vectorial_compact_SARTSA_buffer.jl") -include("circular_compact_SARTSA_buffer.jl") -include("episodic_compact_SARTSA_buffer.jl") -include("common.jl") -include("circular_compact_PSARTSA_buffer.jl") diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl deleted file mode 100644 index 0282835..0000000 --- a/src/components/trajectories/trajectory.jl +++ /dev/null @@ -1,23 +0,0 @@ -export Trajectory - -""" - Trajectory{names,types,Tbs}(trajectories::Tbs) - -A container of different `trajectories`. -Usually you won't use it directly. -""" -struct Trajectory{names,types,Tbs} <: AbstractTrajectory{names,types} - trajectories::Tbs -end - -"A helper function to access inner fields" -Base.getindex(t::Trajectory, s::Symbol) = getproperty(t.trajectories, s) - -get_trace(t::Trajectory, s::Symbol) = t[s] - -function Base.push!(t::Trajectory, kv::Pair{Symbol}) - k, v = kv - push!(t[k], v) -end - -Base.pop!(t::Trajectory, s::Symbol) = pop!(t[s]) diff --git a/src/components/trajectories/vectorial_compact_SARTSA_buffer.jl b/src/components/trajectories/vectorial_compact_SARTSA_buffer.jl deleted file mode 100644 index a731790..0000000 --- a/src/components/trajectories/vectorial_compact_SARTSA_buffer.jl +++ /dev/null @@ -1,66 +0,0 @@ -export VectorialCompactSARTSATrajectory - -const VectorialCompactSARTSATrajectory = Trajectory{ - SARTSA, - types, - NamedTuple{RTSA,trace_types}, -} where {types,trace_types<:Tuple{Vararg{<:Vector}}} - -""" - VectorialCompactSARTSATrajectory(; state_type = Int, action_type = Int, reward_type = Float32, terminal_type = Bool) - -This function creates a [`VectorialTrajectory`](@ref) of [`RTSA`](@ref) fields. Here the **Compact** in the function name means that, `state` and `next_state`, `action` and `next_action` reuse a same vector underlying. - -# Example - -```julia-repl -julia> t = VectorialCompactSARTSATrajectory() -0-element Trajectory{(:state, :action, :reward, :terminal, :next_state, :next_action),Tuple{Int64,Int64,Float32,Bool,Int64,Int64},NamedTuple{(:reward, :terminal, :state, :action),Tuple{Array{Float32,1},Array{Bool,1},Array{Int64,1},Array{Int64,1}}}} - -julia> push!(t, state=0, action=0) - -julia> push!(t, reward=0.f0, terminal=false, state=1, action=1) - -julia> t -1-element Trajectory{(:state, :action, :reward, :terminal, :next_state, :next_action),Tuple{Int64,Int64,Float32,Bool,Int64,Int64},NamedTuple{(:reward, :terminal, :state, :action),Tuple{Array{Float32,1},Array{Bool,1},Array{Int64,1},Array{Int64,1}}}}: - (state = 0, action = 0, reward = 0.0, terminal = 0, next_state = 1, next_action = 1) - -julia> push!(t, reward=1.f0, terminal=true, state=2, action=2) - -julia> t -2-element Trajectory{(:state, :action, :reward, :terminal, :next_state, :next_action),Tuple{Int64,Int64,Float32,Bool,Int64,Int64},NamedTuple{(:reward, :terminal, :state, :action),Tuple{Array{Float32,1},Array{Bool,1},Array{Int64,1},Array{Int64,1}}}}: - (state = 0, action = 0, reward = 0.0, terminal = 0, next_state = 1, next_action = 1) - (state = 1, action = 1, reward = 1.0, terminal = 1, next_state = 2, next_action = 2) - -julia> get_trace(t, :state, :action) -(state = [0, 1], action = [0, 1]) - -julia> get_trace(t, :next_state, :next_action) -(next_state = [1, 2], next_action = [1, 2]) - -julia> pop!(t) -1-element Trajectory{(:state, :action, :reward, :terminal, :next_state, :next_action),Tuple{Int64,Int64,Float32,Bool,Int64,Int64},NamedTuple{(:reward, :terminal, :state, :action),Tuple{Array{Float32,1},Array{Bool,1},Array{Int64,1},Array{Int64,1}}}}: - (state = 0, action = 0, reward = 0.0, terminal = 0, next_state = 1, next_action = 1) -``` -""" -function VectorialCompactSARTSATrajectory(; - state_type = Int, - action_type = Int, - reward_type = Float32, - terminal_type = Bool, -) - VectorialCompactSARTSATrajectory{ - Tuple{state_type,action_type,reward_type,terminal_type,state_type,action_type}, - Tuple{ - Vector{reward_type}, - Vector{terminal_type}, - Vector{state_type}, - Vector{action_type}, - }, - }(( - reward = Vector{reward_type}(), - terminal = Vector{terminal_type}(), - state = Vector{state_type}(), - action = Vector{action_type}(), - )) -end diff --git a/src/components/trajectories/vectorial_trajectory.jl b/src/components/trajectories/vectorial_trajectory.jl deleted file mode 100644 index 1a568dc..0000000 --- a/src/components/trajectories/vectorial_trajectory.jl +++ /dev/null @@ -1,47 +0,0 @@ -export VectorialTrajectory - -const VectorialTrajectory = Trajectory{ - names, - types, - NamedTuple{names,trace_types}, -} where {names,types,trace_types<:Tuple{Vararg{<:Vector}}} - -""" - VectorialTrajectory(;trace_name=trace_type ...) - -Use `Vector` to store the traces. - -# Example - -```julia-repl -julia> t = VectorialTrajectory(;a=Int, b=Symbol) -0-element Trajectory{(:a, :b),Tuple{Int64,Symbol},NamedTuple{(:a, :b),Tuple{Array{Int64,1},Array{Symbol,1}}}} - -julia> push!(t, a=0, b=:x) - -julia> push!(t, a=1, b=:y) - -julia> t -2-element Trajectory{(:a, :b),Tuple{Int64,Symbol},NamedTuple{(:a, :b),Tuple{Array{Int64,1},Array{Symbol,1}}}}: - (a = 0, b = :x) - (a = 1, b = :y) - -julia> get_trace(t, :b) -2-element Array{Symbol,1}: - :x - :y - -julia> pop!(t) - -julia> t -1-element Trajectory{(:a, :b),Tuple{Int64,Symbol},NamedTuple{(:a, :b),Tuple{Array{Int64,1},Array{Symbol,1}}}}: - (a = 0, b = :x) -``` -""" -function VectorialTrajectory(; kwargs...) - names = keys(kwargs.data) - types = values(kwargs.data) - trajectories = - merge(NamedTuple(), (name, Vector{type}()) for (name, type) in zip(names, types)) - VectorialTrajectory{names,Tuple{types...},typeof(values(trajectories))}(trajectories) -end diff --git a/src/utils/circular_array_buffer.jl b/src/utils/circular_array_buffer.jl index 5df7336..d0f3c7c 100644 --- a/src/utils/circular_array_buffer.jl +++ b/src/utils/circular_array_buffer.jl @@ -1,5 +1,7 @@ export CircularArrayBuffer, capacity, isfull +using ReinforcementLearningBase + """ CircularArrayBuffer{T}(d::Integer...) -> CircularArrayBuffer{T, N} diff --git a/src/utils/sum_tree.jl b/src/utils/sum_tree.jl index 744d8c1..4aa1674 100644 --- a/src/utils/sum_tree.jl +++ b/src/utils/sum_tree.jl @@ -63,7 +63,7 @@ mutable struct SumTree{T} <: AbstractArray{Int,1} length::Int nparents::Int tree::Vector{T} - SumTree(capacity::Int) = SumTree(Float64, capacity) + SumTree(capacity::Int) = SumTree(Float32, capacity) function SumTree(T, capacity) nparents = 2^ceil(Int, log2(capacity)) - 1 new{T}(capacity, 1, 0, nparents, zeros(T, nparents + capacity)) diff --git a/test/components/traces.jl b/test/components/traces.jl new file mode 100644 index 0000000..058b8d2 --- /dev/null +++ b/test/components/traces.jl @@ -0,0 +1,129 @@ +@testset "traces" begin + @testset "Trace" begin + t = Trace(;state=Vector{Int}(), reward=Vector{Bool}()) + @test (:state, :reward) == keys(t) + @test haskey(t, :state) + @test haskey(t, :reward) + push!(t; state=3, reward=true) + push!(t; state=4, reward=false) + @test t[:state] == [3,4] + @test t[:reward] == [true,false] + pop!(t) + @test t[:state] == [3] + @test t[:reward] == [true] + empty!(t) + @test t[:state] == Int[] + @test t[:reward] == Bool[] + end + + @testset "SharedTrace" begin + t = SharedTrace(Int[], :state) + @test (:state, :next_state, :full_state) == keys(t) + @test haskey(t, :state) + @test haskey(t, :next_state) + @test haskey(t, :full_state) + @test t[:state] == Int[] + @test t[:next_state] == Int[] + @test t[:full_state] == Int[] + push!(t;state=1,next_state=2) + @test t[:state] == [1] + @test t[:next_state] == [2] + @test t[:full_state] == [1, 2] + empty!(t) + @test t[:state] == Int[] + @test t[:next_state] == Int[] + @test t[:full_state] == Int[] + end + + @testset "EpisodicTrace" begin + t = EpisodicTrace( + Trace(;state=Vector{Int}(), reward=Vector{Bool}()), + :reward + ) + + @test isfull(t) == false + + @test (:state, :reward) == keys(t) + @test haskey(t, :state) + @test haskey(t, :reward) + push!(t; state=3, reward=true) + + @test isfull(t) == true + + push!(t; state=4, reward=false) + @test t[:state] == [3,4] + @test t[:reward] == [true,false] + pop!(t) + @test t[:state] == [3] + @test t[:reward] == [true] + empty!(t) + @test t[:state] == Int[] + @test t[:reward] == Bool[] + end + + @testset "CombinedTrace" begin + t = CircularCompactPSALRTSALTrace(;capacity=3, legal_actions_mask_size=(2,)) + push!(t; state=1, action=1, legal_actions_mask=[false, false]) + push!(t;reward=0.f0, terminal=false, priority=100, state=2, action=2, legal_actions_mask=[false, true]) + + @test t[:state] == [1] + @test t[:action] == [1] + @test t[:legal_actions_mask] == [false false]' + @test t[:reward] == [0.f0] + @test t[:terminal] == [false] + @test t[:priority] == [100] + @test t[:next_state] == [2] + @test t[:next_action] == [2] + @test t[:next_legal_actions_mask] == [false true]' + @test t[:full_state] == [1,2] + @test t[:full_action] == [1,2] + @test t[:full_legal_actions_mask] == [false false + false true] + + push!(t;reward=1.f0, terminal=true, priority=200, state=3, action=3, legal_actions_mask=[true, true]) + + @test t[:state] == [1,2] + @test t[:action] == [1,2] + @test t[:legal_actions_mask] == [false false + false true] + @test t[:reward] == [0.f0, 1.f0] + @test t[:terminal] == [false, true] + @test t[:priority] == [100,200] + @test t[:next_state] == [2, 3] + @test t[:next_action] == [2,3] + @test t[:next_legal_actions_mask] == [false true + true true] + @test t[:full_state] == [1,2,3] + @test t[:full_action] == [1,2,3] + @test t[:full_legal_actions_mask] == [false false true + false true true] + + pop!(t) + + @test t[:state] == [1] + @test t[:action] == [1] + @test t[:legal_actions_mask] == [false false]' + @test t[:reward] == [0.f0] + @test t[:terminal] == [false] + @test t[:priority] == [100] + @test t[:next_state] == [2] + @test t[:next_action] == [2] + @test t[:next_legal_actions_mask] == [false true]' + @test t[:full_state] == [1,2] + @test t[:full_action] == [1,2] + @test t[:full_legal_actions_mask] == [false false + false true] + + + empty!(t) + + @test t[:state] == [] + @test t[:action] == [] + @test t[:reward] == [] + @test t[:terminal] == [] + @test t[:next_state] == [] + @test t[:next_action] == [] + @test t[:full_state] == [] + @test t[:full_action] == [] + end +end diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl deleted file mode 100644 index dbd13b6..0000000 --- a/test/components/trajectories.jl +++ /dev/null @@ -1,241 +0,0 @@ -@testset "trajectories" begin - - @testset "VectorialTrajectory" begin - b = VectorialTrajectory(; state = Vector{Int}, reward = Float64) - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == true - - t1 = (state = [1, 2], reward = 0.0) - push!(b; t1...) - - @test length(b) == 1 - @test size(b) == (1,) - @test b[1] == b[end] == t1 - @test isempty(b) == false - @test get_trace(b, :state) == [t1.state] - @test get_trace(b, :reward) == [t1.reward] - - t2 = (state = [3, 4], reward = -1) - push!(b; t2...) - - @test length(b) == 2 - @test size(b) == (2,) - @test b[2] == b[end] == t2 - @test isempty(b) == false - @test get_trace(b, :state) == [t1.state, t2.state] - @test get_trace(b, :reward) == [t1.reward, t2.reward] - - empty!(b) - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == true - end - - @testset "VectorialCompactSARTSATrajectory" begin - b = VectorialCompactSARTSATrajectory() - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == true - - t1 = (state = 1, action = 2) - push!(b; t1...) - t2 = (reward = 1.0, terminal = false, state = 2, action = 3) - push!(b; t2...) - - @test length(b) == 1 - end - - @testset "EpisodicCompactSARTSATrajectory" begin - b = EpisodicCompactSARTSATrajectory() - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == true - - t1 = (state = 1, action = 2) - push!(b; t1...) - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == false - @test get_trace(b, :state) == [1] - @test get_trace(b, :action) == [2] - @test get_trace(b, :reward) == [] - @test get_trace(b, :terminal) == [] - @test get_trace(b, :next_state) == [] - @test get_trace(b, :next_action) == [] - - t2 = (reward = 1.0, terminal = false, state = 2, action = 3) - push!(b; t2...) - - @test length(b) == 1 - @test size(b) == (1,) - @test isempty(b) == false - @test get_trace(b, :state) == [1] - @test get_trace(b, :action) == [2] - @test get_trace(b, :reward) == [1.0] - @test get_trace(b, :terminal) == [false] - @test get_trace(b, :next_state) == [2] - @test get_trace(b, :next_action) == [3] - @test b[1] == - b[end] == - ( - state = 1, - action = 2, - reward = 1.0f0, - terminal = false, - next_state = 2, - next_action = 3, - ) - - t3 = (reward = 2.0, terminal = true, state = 3, action = 4) - push!(b; t3...) - - @test length(b) == 2 - @test size(b) == (2,) - @test isempty(b) == false - @test get_trace(b, :state) == [1, 2] - @test get_trace(b, :action) == [2, 3] - @test get_trace(b, :reward) == [1.0, 2.0] - @test get_trace(b, :terminal) == [false, true] - @test get_trace(b, :next_state) == [2, 3] - @test get_trace(b, :next_action) == [3, 4] - @test b[end] == ( - state = 2, - action = 3, - reward = 2.0f0, - terminal = true, - next_state = 3, - next_action = 4, - ) - - pop!(b) - - @test length(b) == 1 - @test size(b) == (1,) - @test isempty(b) == false - @test get_trace(b, :state) == [1] - @test get_trace(b, :action) == [2] - @test get_trace(b, :reward) == [1.0] - @test get_trace(b, :terminal) == [false] - @test get_trace(b, :next_state) == [2] - @test get_trace(b, :next_action) == [3] - end - - @testset "CircularCompactSARTSATrajectory" begin - b = CircularCompactSARTSATrajectory(; capacity = 3) - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == true - - t1 = (state = 1, action = 2) - push!(b; t1...) - - @test length(b) == 0 - @test size(b) == (0,) - @test isempty(b) == false - @test get_trace(b, :state) == [1] - @test get_trace(b, :action) == [2] - @test get_trace(b, :reward) == [] - @test get_trace(b, :terminal) == [] - @test get_trace(b, :next_state) == [] - @test get_trace(b, :next_action) == [] - - t2 = (reward = 1.0, terminal = false, state = 2, action = 3) - push!(b; t2...) - - @test length(b) == 1 - @test size(b) == (1,) - @test isempty(b) == false - @test get_trace(b, :state) == [1] - @test get_trace(b, :action) == [2] - @test get_trace(b, :reward) == [1.0] - @test get_trace(b, :terminal) == [false] - @test get_trace(b, :next_state) == [2] - @test get_trace(b, :next_action) == [3] - @test b[1] == - b[end] == - ( - state = 1, - action = 2, - reward = 1.0f0, - terminal = false, - next_state = 2, - next_action = 3, - ) - - t3 = (reward = 2.0, terminal = true, state = 3, action = 4) - push!(b; t3...) - - @test length(b) == 2 - @test size(b) == (2,) - @test isempty(b) == false - @test get_trace(b, :state) == [1, 2] - @test get_trace(b, :action) == [2, 3] - @test get_trace(b, :reward) == [1.0, 2.0] - @test get_trace(b, :terminal) == [false, true] - @test get_trace(b, :next_state) == [2, 3] - @test get_trace(b, :next_action) == [3, 4] - @test b[end] == ( - state = 2, - action = 3, - reward = 2.0f0, - terminal = true, - next_state = 3, - next_action = 4, - ) - - pop!(b, :state, :action) - push!(b, state = 4, action = 5) - - @test b[end] == ( - state = 2, - action = 3, - reward = 2.0f0, - terminal = true, - next_state = 4, - next_action = 5, - ) - - pop!(b) - - @test b[1] == - b[end] == - ( - state = 1, - action = 2, - reward = 1.0f0, - terminal = false, - next_state = 2, - next_action = 3, - ) - @test length(b) == 1 - - t4 = (reward = 3.0, terminal = false, state = 4, action = 5) - push!(b; t4...) - - @test length(b) == 2 - @test size(b) == (2,) - @test isempty(b) == false - @test get_trace(b, :state) == [1, 2] - @test get_trace(b, :action) == [2, 3] - @test get_trace(b, :reward) == [1.0, 3.0] - @test get_trace(b, :terminal) == [false, false] - @test get_trace(b, :next_state) == [2, 4] - @test get_trace(b, :next_action) == [3, 5] - @test b[end] == ( - state = 2, - action = 3, - reward = 3.0f0, - terminal = false, - next_state = 4, - next_action = 5, - ) - end - -end From 4425a6dda0157c7d560b8911583ca87e9343c229 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 11 Aug 2020 17:18:13 +0800 Subject: [PATCH 2/9] remove get_trace name --- src/components/agents/agent.jl | 139 ++------- src/components/traces/abstract_trace.jl | 46 --- src/components/traces/trace.jl | 291 ----------------- src/components/traces/traces.jl | 2 - .../trajectories/abstract_trajectory.jl | 46 +++ src/components/trajectories/trajectories.jl | 2 + src/components/trajectories/trajectory.jl | 294 ++++++++++++++++++ test/components/traces.jl | 18 +- 8 files changed, 375 insertions(+), 463 deletions(-) delete mode 100644 src/components/traces/abstract_trace.jl delete mode 100644 src/components/traces/trace.jl delete mode 100644 src/components/traces/traces.jl create mode 100644 src/components/trajectories/abstract_trajectory.jl create mode 100644 src/components/trajectories/trajectories.jl create mode 100644 src/components/trajectories/trajectory.jl diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index 11373f3..a7b2b65 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -20,7 +20,7 @@ Generally speaking, it does nothing but update the trajectory and policy appropr """ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent policy::P - trajectory::T = DummyTrajectory() + trajectory::T = DUMMY_TRAJECTORY role::R = RLBase.DEFAULT_PLAYER is_training::Bool = true end @@ -84,144 +84,53 @@ end agent.policy(env) ##### -# EpisodicCompactSARTSATrajectory +# default behavior ##### -function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( - ::Training{PreEpisodeStage}, - env, -) - empty!(agent.trajectory) - nothing -end -function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( - ::Training{PreActStage}, - env, -) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.policy, agent.trajectory) - action -end - -function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( - ::Training{PostActStage}, - env, -) - push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) - nothing -end - -function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( - ::Training{PostEpisodeStage}, - env, -) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.policy, agent.trajectory) - action -end - -##### -# Union{CircularCompactSARTSATrajectory, CircularCompactPSARTSATrajectory} -##### - -function ( - agent::Agent{ - <:AbstractPolicy, - <:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory}, - } -)( - ::Training{PreEpisodeStage}, - env, -) - if length(agent.trajectory) > 0 - pop!(agent.trajectory, :state, :action) +function (agent::Agent)(::Training{PreEpisodeStage}, env) + if nframes(agent.trajectory[:full_state]) > 0 + pop!(agent.trajectory, :full_state) + end + if nframes(agent.trajectory[:full_action]) + pop!(agent.trajectory, :full_action) + end + if ActionStyle(env) === FULL_ACTION_SET && nframes(agent.trajectory[:full_legal_actions_mask]) + pop!(agent.trajectory, :full_legal_actions_mask) end - nothing end -function ( - agent::Agent{ - <:AbstractPolicy, - <:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory}, - } -)( - ::Training{PreActStage}, - env, -) +function (agent::Agent)(::Training{PreActStage}, env) action = agent.policy(env) push!(agent.trajectory; state = get_state(env), action = action) + if ActionStyle(env) === FULL_ACTION_SET + push!(agent.trajectory; legal_actions_mask=get_legal_actions_mask(env)) + end update!(agent.policy, agent.trajectory) action end -function ( - agent::Agent{ - <:AbstractPolicy, - <:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory}, - } -)( - ::Training{PostActStage}, - env, -) +function (agent::Agent)(::Training{PostActStage}, env) push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) nothing end -function ( - agent::Agent{ - <:AbstractPolicy, - <:Union{CircularCompactSARTSATrajectory,CircularCompactPSARTSATrajectory}, - } -)( - ::Training{PostEpisodeStage}, - env, -) +function (agent::Agent)(::Training{PostEpisodeStage}, env) action = agent.policy(env) push!(agent.trajectory; state = get_state(env), action = action) + if ActionStyle(env) === FULL_ACTION_SET + push!(agent.trajectory; legal_actions_mask=get_legal_actions_mask(env)) + end update!(agent.policy, agent.trajectory) action end ##### -# VectorialCompactSARTSATrajectory +# EpisodicCompactSARTSATrajectory ##### - -function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( +function (agent::Agent{<:AbstractPolicy,<:EpisodicTrajectory})( ::Training{PreEpisodeStage}, env, ) - if length(agent.trajectory) > 0 - pop!(agent.trajectory, :state, :action) - end - nothing -end - -function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( - ::Training{PreActStage}, - env, -) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.policy, agent.trajectory) - action -end - -function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( - ::Training{PostActStage}, - env, -) - push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) + empty!(agent.trajectory) nothing -end - -function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( - ::Training{PostEpisodeStage}, - env, -) - action = agent.policy(env) - push!(agent.trajectory; state = get_state(env), action = action) - update!(agent.policy, agent.trajectory) - action -end +end \ No newline at end of file diff --git a/src/components/traces/abstract_trace.jl b/src/components/traces/abstract_trace.jl deleted file mode 100644 index 2d05faa..0000000 --- a/src/components/traces/abstract_trace.jl +++ /dev/null @@ -1,46 +0,0 @@ -export AbstractTrace - -""" - AbstractTrace - -A trace is used to record some useful information -during the interactions between agents and environments. - -Required Methods: - -- `Base.haskey(t::AbstractTrace, s::Symbol)` -- `Base.getproperty(t::AbstractTrace, s::Symbol)` -- `Base.keys(t::AbstractTrace)` -- `Base.push!(t::AbstractTrace, kv::Pair{Symbol})` -- `Base.pop!(t::AbstractTrace, s::Symbol)` - -Optional Methods: - -- `isfull` -- `empty!` - -""" -abstract type AbstractTrace end - -function Base.push!(t::AbstractTrace;kwargs...) - for kv in kwargs - push!(t, kv) - end -end - -""" - Base.pop!(t::AbstractTrace, s::Symbol...) - -`pop!` out one element of the traces specified in `s` -""" -function Base.pop!(t::AbstractTrace, s::Tuple{Vararg{Symbol}}) - NamedTuple{s}(pop!(t, x) for x in s) -end - -Base.pop!(t::AbstractTrace) = pop!(t, keys(t)) - -function Base.empty!(t::AbstractTrace) - for s in keys(t) - empty!(t[s]) - end -end \ No newline at end of file diff --git a/src/components/traces/trace.jl b/src/components/traces/trace.jl deleted file mode 100644 index c651697..0000000 --- a/src/components/traces/trace.jl +++ /dev/null @@ -1,291 +0,0 @@ -using MacroTools:@forward - -##### -# Trace -##### - -""" - Trace(;[trace_name=trace_container]...) - -Simply a wrapper of `NamedTuple`. -Define our own type here to avoid type piracy with `NamedTuple` -""" -struct Trace{T} <: AbstractTrace - traces::T -end - -Trace(;kwargs...) = Trace(kwargs.data) - -@forward Trace.traces Base.keys, Base.haskey, Base.getindex - -Base.push!(t::Trace, kv::Pair{Symbol}) = push!(t[first(kv)], last(kv)) -Base.pop!(t::Trace, s::Symbol) = pop!(t[s]) - -##### -# SharedTrace -##### - -struct SharedTraceMeta - start_shift::Int - end_shift::Int -end - -""" - SharedTrace(trace;[trace_name=start_shift:end_shift]...) - -Create multiple traces sharing the same underlying container. -""" -struct SharedTrace{X,M} <: AbstractTrace - x::X - meta::M -end - -function SharedTrace(x, s::Symbol) - SharedTrace( - x, - (; - s=>SharedTraceMeta(1, -1), - Symbol(:next_, s)=>SharedTraceMeta(2, 0), - Symbol(:full_, s) => SharedTraceMeta(1,0) - ) - ) -end - -@forward SharedTrace.meta Base.keys, Base.haskey - -function Base.getindex(t::SharedTrace, s::Symbol) - m = t.meta[s] - select_last_dim(t.x, m.start_shift:(nframes(t.x)+m.end_shift)) -end - -Base.push!(t::SharedTrace, kv::Pair{Symbol}) = push!(t.x, last(kv)) -Base.empty!(t::SharedTrace) = empty!(t.x) -Base.pop!(t::SharedTrace, s::Symbol) = pop!(t.x) - -function Base.pop!(t::SharedTrace) - s = first(keys(t)) - (;s => pop!(t.x)) -end - -##### -# EpisodicTrace -##### - -""" -Assuming that the `flag_trace` is in `traces` and it's an `AbstractVector{Bool}`, -meaning whether an environment reaches terminal or not. The last element in -`flag_trace` will be used to determine whether the whole trace is full or not. -""" -struct EpisodicTrace{T, flag_trace} <: AbstractTrace - traces::T -end - -EpisodicTrace(traces::T, flag_trace=:terminal) where T = EpisodicTrace{T, flag_trace}(traces) - -@forward EpisodicTrace.traces Base.keys, Base.haskey, Base.getindex, Base.push!, Base.pop!, Base.empty! - -function isfull(t::EpisodicTrace{<:Any, F}) where F - x = t.traces[F] - (nframes(x) > 0) && select_last_frame(x) -end - -##### -# CombinedTrace -##### - -struct CombinedTrace{T1, T2} <: AbstractTrace - t1::T1 - t2::T2 -end - -Base.haskey(t::CombinedTrace, s::Symbol) = haskey(t.t1, s) || haskey(t.t2, s) -Base.getindex(t::CombinedTrace, s::Symbol) = if haskey(t.t1, s) - getindex(t.t1, s) -elseif haskey(t.t2, s) - getindex(t.t2, s) -else - throw(ArgumentError("unknown key: $s")) -end - -Base.keys(t::CombinedTrace) = (keys(t.t1)..., keys(t.t2)...) - -Base.push!(t::CombinedTrace, kv::Pair{Symbol}) = if haskey(t.t1, first(kv)) - push!(t.t1, kv) -elseif haskey(t.t2, first(kv)) - push!(t.t2, kv) -else - throw(ArgumentError("unknown kv: $kv")) -end - -Base.pop!(t::CombinedTrace, s::Symbol) = if haskey(t.t1, s) - pop!(t.t1, s) -elseif haskey(t.t2, s) - pop!(t.t2, s) -else - throw(ArgumentError("unknown key: $s")) -end - -Base.pop!(t::CombinedTrace) = merge(pop!(t.t1), pop!(t.t2)) - -function Base.empty!(t::CombinedTrace) - empty!(t.t1) - empty!(t.t2) -end - -##### -# CircularCompactSATrace -##### - -const CircularCompactSATrace = CombinedTrace{ - <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:state, :next_state, :full_state)}}, - <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:action, :next_action, :full_action)}}, -} - -function CircularCompactSATrace(; - capacity, - state_type = Int, - state_size = (), - action_type = Int, - action_size = (), -) - CombinedTrace( - SharedTrace( - CircularArrayBuffer{state_type}(state_size..., capacity+1), - :state), - SharedTrace( - CircularArrayBuffer{action_type}(action_size..., capacity+1), - :action - ), - ) -end - -##### -# CircularCompactSALTrace -##### - -const CircularCompactSALTrace = CombinedTrace{ - <:SharedTrace{<:CircularArrayBuffer, <:NamedTuple{(:legal_actions_mask, :next_legal_actions_mask, :full_legal_actions_mask)}}, - <:CircularCompactSATrace -} - -function CircularCompactSALTrace(; - capacity, - legal_actions_mask_size, - legal_actions_mask_type=Bool, - kw... -) - CombinedTrace( - SharedTrace( - CircularArrayBuffer{legal_actions_mask_type}(legal_actions_mask_size..., capacity+1), - :legal_actions_mask - ), - CircularCompactSATrace(;capacity=capacity, kw...) - ) -end -##### -# CircularCompactSARTSATrace -##### - -const CircularCompactSARTSATrace = CombinedTrace{ - <:Trace{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, - <:CircularCompactSATrace -} - -function CircularCompactSARTSATrace(; - capacity, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw... -) - CombinedTrace( - Trace( - reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - ), - CircularCompactSATrace(;capacity=capacity, kw...), - ) -end - -##### -# CircularCompactSALRTSALTrace -##### - -const CircularCompactSALRTSALTrace = CombinedTrace{ - <:Trace{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, - <:CircularCompactSALTrace -} - -function CircularCompactSALRTSALTrace(; - capacity, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw... - ) - CombinedTrace( - Trace( - reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - ), - CircularCompactSALTrace(;capacity=capacity, kw...), - ) -end - -##### -# CircularCompactPSARTSATrace -##### - -const CircularCompactPSARTSATrace = CombinedTrace{ - <:Trace{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, - <:CircularCompactSATrace -} - -function CircularCompactPSARTSATrace(; - capacity, - priority_type=Float32, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw... -) - CombinedTrace( - Trace( - reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - priority=SumTree(priority_type, capacity) - ), - CircularCompactSATrace(;capacity=capacity, kw...), - ) -end - -##### -# CircularCompactPSALRTSALTrace -##### - -const CircularCompactPSALRTSALTrace = CombinedTrace{ - <:Trace{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, - <:CircularCompactSALTrace -} - -function CircularCompactPSALRTSALTrace(; - capacity, - priority_type=Float32, - reward_type = Float32, - reward_size = (), - terminal_type = Bool, - terminal_size = (), - kw... -) - CombinedTrace( - Trace( - reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), - priority=SumTree(priority_type, capacity) - ), - CircularCompactSALTrace(;capacity=capacity, kw...), - ) -end \ No newline at end of file diff --git a/src/components/traces/traces.jl b/src/components/traces/traces.jl deleted file mode 100644 index 1bc7e12..0000000 --- a/src/components/traces/traces.jl +++ /dev/null @@ -1,2 +0,0 @@ -include("abstract_trace.jl") -include("trace.jl") \ No newline at end of file diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl new file mode 100644 index 0000000..079e3cd --- /dev/null +++ b/src/components/trajectories/abstract_trajectory.jl @@ -0,0 +1,46 @@ +export AbstractTrajectory + +""" + AbstractTrajectory + +A trace is used to record some useful information +during the interactions between agents and environments. + +Required Methods: + +- `Base.haskey(t::AbstractTrajectory, s::Symbol)` +- `Base.getindex(t::AbstractTrajectory, s::Symbol)` +- `Base.keys(t::AbstractTrajectory)` +- `Base.push!(t::AbstractTrajectory, kv::Pair{Symbol})` +- `Base.pop!(t::AbstractTrajectory, s::Symbol)` +- `Base.empty!(t::AbstractTrajectory)` + +Optional Methods: + +- `isfull` + +""" +abstract type AbstractTrajectory end + +function Base.push!(t::AbstractTrajectory;kwargs...) + for kv in kwargs + push!(t, kv) + end +end + +""" + Base.pop!(t::AbstractTrajectory, s::Symbol...) + +`pop!` out one element of the traces specified in `s` +""" +function Base.pop!(t::AbstractTrajectory, s::Tuple{Vararg{Symbol}}) + NamedTuple{s}(pop!(t, x) for x in s) +end + +Base.pop!(t::AbstractTrajectory) = pop!(t, keys(t)) + +function Base.empty!(t::AbstractTrajectory) + for s in keys(t) + empty!(t[s]) + end +end \ No newline at end of file diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl new file mode 100644 index 0000000..9e6c97c --- /dev/null +++ b/src/components/trajectories/trajectories.jl @@ -0,0 +1,2 @@ +include("abstract_trajectory.jl") +include("trajectory.jl") \ No newline at end of file diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl new file mode 100644 index 0000000..521c30e --- /dev/null +++ b/src/components/trajectories/trajectory.jl @@ -0,0 +1,294 @@ +using MacroTools:@forward + +##### +# Trajectory +##### + +""" + Trajectory(;[trace_name=trace_container]...) + +Simply a wrapper of `NamedTuple`. +Define our own type here to avoid type piracy with `NamedTuple` +""" +struct Trajectory{T} <: AbstractTrajectory + traces::T +end + +Trajectory(;kwargs...) = Trajectory(kwargs.data) + +const DummyTrajectory = Trajectory{NamedTuple{(),Tuple{}}} +const DUMMY_TRAJECTORY = Trajectory() + +@forward Trajectory.traces Base.keys, Base.haskey, Base.getindex + +Base.push!(t::Trajectory, kv::Pair{Symbol}) = push!(t[first(kv)], last(kv)) +Base.pop!(t::Trajectory, s::Symbol) = pop!(t[s]) + +##### +# SharedTrajectory +##### + +struct SharedTrajectoryMeta + start_shift::Int + end_shift::Int +end + +""" + SharedTrajectory(trace;[trace_name=start_shift:end_shift]...) + +Create multiple traces sharing the same underlying container. +""" +struct SharedTrajectory{X,M} <: AbstractTrajectory + x::X + meta::M +end + +function SharedTrajectory(x, s::Symbol) + SharedTrajectory( + x, + (; + s=>SharedTrajectoryMeta(1, -1), + Symbol(:next_, s)=>SharedTrajectoryMeta(2, 0), + Symbol(:full_, s) => SharedTrajectoryMeta(1,0) + ) + ) +end + +@forward SharedTrajectory.meta Base.keys, Base.haskey + +function Base.getindex(t::SharedTrajectory, s::Symbol) + m = t.meta[s] + select_last_dim(t.x, m.start_shift:(nframes(t.x)+m.end_shift)) +end + +Base.push!(t::SharedTrajectory, kv::Pair{Symbol}) = push!(t.x, last(kv)) +Base.empty!(t::SharedTrajectory) = empty!(t.x) +Base.pop!(t::SharedTrajectory, s::Symbol) = pop!(t.x) + +function Base.pop!(t::SharedTrajectory) + s = first(keys(t)) + (;s => pop!(t.x)) +end + +##### +# EpisodicTrajectory +##### + +""" +Assuming that the `flag_trace` is in `traces` and it's an `AbstractVector{Bool}`, +meaning whether an environment reaches terminal or not. The last element in +`flag_trace` will be used to determine whether the whole trace is full or not. +""" +struct EpisodicTrajectory{T, flag_trace} <: AbstractTrajectory + traces::T +end + +EpisodicTrajectory(traces::T, flag_trace=:terminal) where T = EpisodicTrajectory{T, flag_trace}(traces) + +@forward EpisodicTrajectory.traces Base.keys, Base.haskey, Base.getindex, Base.push!, Base.pop!, Base.empty! + +function isfull(t::EpisodicTrajectory{<:Any, F}) where F + x = t.traces[F] + (nframes(x) > 0) && select_last_frame(x) +end + +##### +# CombinedTrajectory +##### + +struct CombinedTrajectory{T1, T2} <: AbstractTrajectory + t1::T1 + t2::T2 +end + +Base.haskey(t::CombinedTrajectory, s::Symbol) = haskey(t.t1, s) || haskey(t.t2, s) +Base.getindex(t::CombinedTrajectory, s::Symbol) = if haskey(t.t1, s) + getindex(t.t1, s) +elseif haskey(t.t2, s) + getindex(t.t2, s) +else + throw(ArgumentError("unknown key: $s")) +end + +Base.keys(t::CombinedTrajectory) = (keys(t.t1)..., keys(t.t2)...) + +Base.push!(t::CombinedTrajectory, kv::Pair{Symbol}) = if haskey(t.t1, first(kv)) + push!(t.t1, kv) +elseif haskey(t.t2, first(kv)) + push!(t.t2, kv) +else + throw(ArgumentError("unknown kv: $kv")) +end + +Base.pop!(t::CombinedTrajectory, s::Symbol) = if haskey(t.t1, s) + pop!(t.t1, s) +elseif haskey(t.t2, s) + pop!(t.t2, s) +else + throw(ArgumentError("unknown key: $s")) +end + +Base.pop!(t::CombinedTrajectory) = merge(pop!(t.t1), pop!(t.t2)) + +function Base.empty!(t::CombinedTrajectory) + empty!(t.t1) + empty!(t.t2) +end + +##### +# CircularCompactSATrajectory +##### + +const CircularCompactSATrajectory = CombinedTrajectory{ + <:SharedTrajectory{<:CircularArrayBuffer, <:NamedTuple{(:state, :next_state, :full_state)}}, + <:SharedTrajectory{<:CircularArrayBuffer, <:NamedTuple{(:action, :next_action, :full_action)}}, +} + +function CircularCompactSATrajectory(; + capacity, + state_type = Int, + state_size = (), + action_type = Int, + action_size = (), +) + CombinedTrajectory( + SharedTrajectory( + CircularArrayBuffer{state_type}(state_size..., capacity+1), + :state), + SharedTrajectory( + CircularArrayBuffer{action_type}(action_size..., capacity+1), + :action + ), + ) +end + +##### +# CircularCompactSALTrajectory +##### + +const CircularCompactSALTrajectory = CombinedTrajectory{ + <:SharedTrajectory{<:CircularArrayBuffer, <:NamedTuple{(:legal_actions_mask, :next_legal_actions_mask, :full_legal_actions_mask)}}, + <:CircularCompactSATrajectory +} + +function CircularCompactSALTrajectory(; + capacity, + legal_actions_mask_size, + legal_actions_mask_type=Bool, + kw... +) + CombinedTrajectory( + SharedTrajectory( + CircularArrayBuffer{legal_actions_mask_type}(legal_actions_mask_size..., capacity+1), + :legal_actions_mask + ), + CircularCompactSATrajectory(;capacity=capacity, kw...) + ) +end +##### +# CircularCompactSARTSATrajectory +##### + +const CircularCompactSARTSATrajectory = CombinedTrajectory{ + <:Trajectory{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, + <:CircularCompactSATrajectory +} + +function CircularCompactSARTSATrajectory(; + capacity, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrajectory( + Trajectory( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + ), + CircularCompactSATrajectory(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactSALRTSALTrajectory +##### + +const CircularCompactSALRTSALTrajectory = CombinedTrajectory{ + <:Trajectory{<:NamedTuple{(:reward, :terminal), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer}}}, + <:CircularCompactSALTrajectory +} + +function CircularCompactSALRTSALTrajectory(; + capacity, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... + ) + CombinedTrajectory( + Trajectory( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + ), + CircularCompactSALTrajectory(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactPSARTSATrajectory +##### + +const CircularCompactPSARTSATrajectory = CombinedTrajectory{ + <:Trajectory{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, + <:CircularCompactSATrajectory +} + +function CircularCompactPSARTSATrajectory(; + capacity, + priority_type=Float32, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrajectory( + Trajectory( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + priority=SumTree(priority_type, capacity) + ), + CircularCompactSATrajectory(;capacity=capacity, kw...), + ) +end + +##### +# CircularCompactPSALRTSALTrajectory +##### + +const CircularCompactPSALRTSALTrajectory = CombinedTrajectory{ + <:Trajectory{<:NamedTuple{(:reward, :terminal,:priority), <:Tuple{<:CircularArrayBuffer, <:CircularArrayBuffer, <:SumTree}}}, + <:CircularCompactSALTrajectory +} + +function CircularCompactPSALRTSALTrajectory(; + capacity, + priority_type=Float32, + reward_type = Float32, + reward_size = (), + terminal_type = Bool, + terminal_size = (), + kw... +) + CombinedTrajectory( + Trajectory( + reward=CircularArrayBuffer{reward_type}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_type}(terminal_size..., capacity), + priority=SumTree(priority_type, capacity) + ), + CircularCompactSALTrajectory(;capacity=capacity, kw...), + ) +end \ No newline at end of file diff --git a/test/components/traces.jl b/test/components/traces.jl index 058b8d2..ee47790 100644 --- a/test/components/traces.jl +++ b/test/components/traces.jl @@ -1,6 +1,6 @@ @testset "traces" begin - @testset "Trace" begin - t = Trace(;state=Vector{Int}(), reward=Vector{Bool}()) + @testset "Trajectory" begin + t = Trajectory(;state=Vector{Int}(), reward=Vector{Bool}()) @test (:state, :reward) == keys(t) @test haskey(t, :state) @test haskey(t, :reward) @@ -16,8 +16,8 @@ @test t[:reward] == Bool[] end - @testset "SharedTrace" begin - t = SharedTrace(Int[], :state) + @testset "SharedTrajectory" begin + t = SharedTrajectory(Int[], :state) @test (:state, :next_state, :full_state) == keys(t) @test haskey(t, :state) @test haskey(t, :next_state) @@ -35,9 +35,9 @@ @test t[:full_state] == Int[] end - @testset "EpisodicTrace" begin - t = EpisodicTrace( - Trace(;state=Vector{Int}(), reward=Vector{Bool}()), + @testset "EpisodicTrajectory" begin + t = EpisodicTrajectory( + Trajectory(;state=Vector{Int}(), reward=Vector{Bool}()), :reward ) @@ -61,8 +61,8 @@ @test t[:reward] == Bool[] end - @testset "CombinedTrace" begin - t = CircularCompactPSALRTSALTrace(;capacity=3, legal_actions_mask_size=(2,)) + @testset "CombinedTrajectory" begin + t = CircularCompactPSALRTSALTrajectory(;capacity=3, legal_actions_mask_size=(2,)) push!(t; state=1, action=1, legal_actions_mask=[false, false]) push!(t;reward=0.f0, terminal=false, priority=100, state=2, action=2, legal_actions_mask=[false, true]) From 1d1656d927ce1b141c1e2f80b476bfe09b8ebca9 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 11 Aug 2020 17:58:50 +0800 Subject: [PATCH 3/9] fix tests --- src/components/agents/agent.jl | 2 +- src/components/agents/dyna_agent.jl | 8 ++++---- src/components/trajectories/trajectory.jl | 11 +++++++++++ test/components/agents.jl | 2 +- test/components/{traces.jl => trajectories.jl} | 0 test/core/core.jl | 2 +- 6 files changed, 18 insertions(+), 7 deletions(-) rename test/components/{traces.jl => trajectories.jl} (100%) diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index a7b2b65..6557647 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -91,7 +91,7 @@ function (agent::Agent)(::Training{PreEpisodeStage}, env) if nframes(agent.trajectory[:full_state]) > 0 pop!(agent.trajectory, :full_state) end - if nframes(agent.trajectory[:full_action]) + if nframes(agent.trajectory[:full_action]) > 0 pop!(agent.trajectory, :full_action) end if ActionStyle(env) === FULL_ACTION_SET && nframes(agent.trajectory[:full_legal_actions_mask]) diff --git a/src/components/agents/dyna_agent.jl b/src/components/agents/dyna_agent.jl index c0fd2d9..31e581d 100644 --- a/src/components/agents/dyna_agent.jl +++ b/src/components/agents/dyna_agent.jl @@ -33,7 +33,7 @@ end get_role(agent::DynaAgent) = agent.role -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( +function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})( ::PreEpisodeStage, env, ) @@ -41,7 +41,7 @@ function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( nothing end -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( +function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})( ::PreActStage, env, ) @@ -53,7 +53,7 @@ function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( action end -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( +function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})( ::PostActStage, env, ) @@ -61,7 +61,7 @@ function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( nothing end -function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( +function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicTrajectory})( ::PostEpisodeStage, env, ) diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 521c30e..b135e35 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -1,3 +1,14 @@ +export Trajectory, + SharedTrajectory, + EpisodicTrajectory, + CombinedTrajectory, + CircularCompactSATrajectory, + CircularCompactSALTrajectory, + CircularCompactSARTSATrajectory, + CircularCompactPSARTSATrajectory, + CircularCompactSALRTSALTrajectory, + CircularCompactPSALRTSALTrajectory + using MacroTools:@forward ##### diff --git a/test/components/agents.jl b/test/components/agents.jl index 876db55..8a28d29 100644 --- a/test/components/agents.jl +++ b/test/components/agents.jl @@ -2,7 +2,7 @@ action_space = DiscreteSpace(3) agent = Agent(; policy = RandomPolicy(action_space), - trajectory = VectorialCompactSARTSATrajectory(), + trajectory = CircularCompactSARTSATrajectory(;capacity=10_000, state_type = Float32, state_size=(4,)), ) @testset "loading/saving Agent" begin diff --git a/test/components/traces.jl b/test/components/trajectories.jl similarity index 100% rename from test/components/traces.jl rename to test/components/trajectories.jl diff --git a/test/core/core.jl b/test/core/core.jl index 91e8d5c..3de6e97 100644 --- a/test/core/core.jl +++ b/test/core/core.jl @@ -2,7 +2,7 @@ env = CartPoleEnv{Float32}() |> StateOverriddenEnv(deepcopy) agent = Agent(; policy = RandomPolicy(env), - trajectory = VectorialCompactSARTSATrajectory(; state_type = Vector{Float32}), + trajectory = CircularCompactSARTSATrajectory(;capacity=10_000, state_type = Float32, state_size=(4,)), ) N_EPISODE = 10000 hook = TotalRewardPerEpisode() From 696a0d1b56ee42bb63ab5087f05977bff85ed79f Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 11 Aug 2020 23:48:00 +0800 Subject: [PATCH 4/9] implement isfull --- src/components/trajectories/trajectory.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index b135e35..9e1a128 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -35,6 +35,8 @@ const DUMMY_TRAJECTORY = Trajectory() Base.push!(t::Trajectory, kv::Pair{Symbol}) = push!(t[first(kv)], last(kv)) Base.pop!(t::Trajectory, s::Symbol) = pop!(t[s]) +isfull(t::Trajectory) = all(isfull, t.traces) + ##### # SharedTrajectory ##### @@ -81,6 +83,8 @@ function Base.pop!(t::SharedTrajectory) (;s => pop!(t.x)) end +isfull(t::SharedTrajectory) = isfull(t.x) + ##### # EpisodicTrajectory ##### @@ -146,6 +150,8 @@ function Base.empty!(t::CombinedTrajectory) empty!(t.t2) end +isfull(t::CombinedTrajectory) = isfull(t.t1) && isfull(t.t2) + ##### # CircularCompactSATrajectory ##### From 5713c1f3496d32eea7c3f4102ce750dd8c778fee Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 11 Aug 2020 23:53:04 +0800 Subject: [PATCH 5/9] deprecate get_trace --- src/components/trajectories/abstract_trajectory.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl index 079e3cd..c4eb285 100644 --- a/src/components/trajectories/abstract_trajectory.jl +++ b/src/components/trajectories/abstract_trajectory.jl @@ -43,4 +43,6 @@ function Base.empty!(t::AbstractTrajectory) for s in keys(t) empty!(t[s]) end -end \ No newline at end of file +end + +@deprecate get_trace(t::AbstractTrajectory, s::Symbol) t[s] \ No newline at end of file From 7ea56ca5d9c7e751252f3734f23e213284ea1d1a Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 12 Aug 2020 10:36:27 +0800 Subject: [PATCH 6/9] resolve comments --- src/components/agents/agent.jl | 2 +- .../trajectories/abstract_trajectory.jl | 9 +++++++++ src/components/trajectories/trajectory.jl | 18 ++++++++++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index 6557647..63c780f 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -94,7 +94,7 @@ function (agent::Agent)(::Training{PreEpisodeStage}, env) if nframes(agent.trajectory[:full_action]) > 0 pop!(agent.trajectory, :full_action) end - if ActionStyle(env) === FULL_ACTION_SET && nframes(agent.trajectory[:full_legal_actions_mask]) + if ActionStyle(env) === FULL_ACTION_SET && nframes(agent.trajectory[:full_legal_actions_mask]) > 0 pop!(agent.trajectory, :full_legal_actions_mask) end end diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl index c4eb285..8a0b1c6 100644 --- a/src/components/trajectories/abstract_trajectory.jl +++ b/src/components/trajectories/abstract_trajectory.jl @@ -45,4 +45,13 @@ function Base.empty!(t::AbstractTrajectory) end end +##### +# patch code +##### + +# avoid showing the inner structure +function AbstractTrees.children(t::StructTree{<:AbstractTrajectory}) + Tuple(k => StructTree(t.x[k]) for k in keys(t.x)) +end + @deprecate get_trace(t::AbstractTrajectory, s::Symbol) t[s] \ No newline at end of file diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl index 9e1a128..f495471 100644 --- a/src/components/trajectories/trajectory.jl +++ b/src/components/trajectories/trajectory.jl @@ -27,8 +27,8 @@ end Trajectory(;kwargs...) = Trajectory(kwargs.data) -const DummyTrajectory = Trajectory{NamedTuple{(),Tuple{}}} const DUMMY_TRAJECTORY = Trajectory() +const DummyTrajectory = typeof(DUMMY_TRAJECTORY) @forward Trajectory.traces Base.keys, Base.haskey, Base.getindex @@ -47,7 +47,7 @@ struct SharedTrajectoryMeta end """ - SharedTrajectory(trace;[trace_name=start_shift:end_shift]...) + SharedTrajectory(trace_container, meta::NamedTuple{([trace_name::Symbol],...), Tuple{[SharedTrajectoryMeta]...}}) Create multiple traces sharing the same underlying container. """ @@ -56,6 +56,15 @@ struct SharedTrajectory{X,M} <: AbstractTrajectory meta::M end +""" + SharedTrajectory(trace_container, s::Symbol) + +Automatically create the following three traces: + +- `s`, share the data in `trace_container` in the range of `1:end-1` +- `s` with a prefix of `next_`, share the data in `trace_container` in the range of `2:end` +- `s` with a prefix of `full_`, a view of `trace_container` +""" function SharedTrajectory(x, s::Symbol) SharedTrajectory( x, @@ -90,6 +99,8 @@ isfull(t::SharedTrajectory) = isfull(t.x) ##### """ + EpisodicTrajectory(traces::T, flag_trace=:terminal) + Assuming that the `flag_trace` is in `traces` and it's an `AbstractVector{Bool}`, meaning whether an environment reaches terminal or not. The last element in `flag_trace` will be used to determine whether the whole trace is full or not. @@ -111,6 +122,9 @@ end # CombinedTrajectory ##### +""" + CombinedTrajectory(t1::AbstractTrajectory, t2::AbstractTrajectory) +""" struct CombinedTrajectory{T1, T2} <: AbstractTrajectory t1::T1 t2::T2 From 34f1af765c0f511afb76407a51370ac9f49a827f Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 12 Aug 2020 10:53:01 +0800 Subject: [PATCH 7/9] do not show internals of environments --- src/extensions/ReinforcementLearningBase.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index 4535f11..3294940 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -19,6 +19,8 @@ end Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10)) +AbstractTrees.children(t::StructTree{<:AbstractEnv}) = () + AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = print( io, "$(RLBase.get_name(t.x)): $(join([f(t.x) for f in RLBase.get_env_traits()], ","))", From 4d79d8cdbf1b0599d6de93d95c931f2faeddd3e8 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 13 Aug 2020 00:46:23 +0800 Subject: [PATCH 8/9] use Flux@v0.11.1 --- Project.toml | 2 +- src/extensions/Flux.jl | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 27f2492..91db577 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ BSON = "0.2" CUDA = "1" Distributions = "0.22, 0.23" FillArrays = "0.8" -Flux = "0.11" +Flux = "0.11.1" GPUArrays = "5" ImageTransformations = "0.8" JLD = "0.10" diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 2fe992b..0475ba1 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -5,15 +5,6 @@ import Flux: glorot_uniform, glorot_normal using Random using LinearAlgebra -# watch https://github.com/FluxML/Flux.jl/issues/1274 -glorot_uniform(rng::AbstractRNG, dims...) = - (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(Flux.nfan(dims...))) -glorot_normal(rng::AbstractRNG, dims...) = - randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(Flux.nfan(dims...))) - -glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...) -glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...) - # https://github.com/FluxML/Flux.jl/pull/1171/ # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/Orthogonal function orthogonal_matrix(rng::AbstractRNG, nrow, ncol) From c2059e4cdf55404b04d3d4c1266b0b8b8bb917a6 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 13 Aug 2020 00:47:25 +0800 Subject: [PATCH 9/9] automatic copy view/reshaped-view to GPU without copying parent --- src/utils/device.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/utils/device.jl b/src/utils/device.jl index f695b8b..4463815 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -12,8 +12,11 @@ send_to_device(::Val{:cpu}, x) = x # cpu(x) is not very efficient! So by defaul send_to_device(::Val{:cpu}, x::CuArray) = adapt(Array, x) send_to_device(::Val{:gpu}, x) = Flux.fmap(a -> adapt(CuArray{Float32}, a), x) -send_to_device(::Val{:gpu}, x::SubArray{T,N,<:CircularArrayBuffer}) where {T,N} = - CuArray{T}(x) +send_to_device(::Val{:gpu}, x::Union{ + SubArray{<:Any,<:Any,<:CircularArrayBuffer}, + Base.ReshapedArray{<:Any, <:Any, <:SubArray{<:Any, <:Any, <:CircularArrayBuffer}}, + SubArray{<:Any,<:Any,<:Base.ReshapedArray{<:Any, <:Any, <:SubArray{<:Any, <:Any, <:CircularArrayBuffer}}} + }) = CuArray(x) """ device(model)