Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/policies/agents/agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreAct
RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) = nothing

function RLBase.update!(
trajectory::CircularArraySARTTrajectory,
trajectory::Union{CircularArraySARTTrajectory, PrioritizedTrajectory{<:CircularArraySARTTrajectory}},
::AbstractPolicy,
::AbstractEnv,
::PreEpisodeStage,
Expand All @@ -54,7 +54,7 @@ function RLBase.update!(
end

function RLBase.update!(
trajectory::CircularArraySLARTTrajectory,
trajectory::Union{CircularArraySLARTTrajectory, PrioritizedTrajectory{<:CircularArraySLARTTrajectory}},
::AbstractPolicy,
::AbstractEnv,
::PreEpisodeStage,
Expand All @@ -67,7 +67,7 @@ function RLBase.update!(
end

function RLBase.update!(
trajectory::CircularArraySARTTrajectory,
trajectory::Union{CircularArraySARTTrajectory,PrioritizedTrajectory{<:CircularArraySARTTrajectory}},
policy::AbstractPolicy,
env::AbstractEnv,
::Union{PreActStage, PostEpisodeStage},
Expand All @@ -79,7 +79,7 @@ function RLBase.update!(
end

function RLBase.update!(
trajectory::CircularArraySLARTTrajectory,
trajectory::Union{CircularArraySLARTTrajectory,PrioritizedTrajectory{<:CircularArraySLARTTrajectory}},
policy::AbstractPolicy,
env::AbstractEnv,
::Union{PreActStage, PostEpisodeStage},
Expand Down
6 changes: 3 additions & 3 deletions src/policies/agents/trajectories/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ end

#####

Base.@kwdef struct PrioritizedTrajectory{P,T} <: AbstractTrajectory
priority::P
Base.@kwdef struct PrioritizedTrajectory{T,P} <: AbstractTrajectory
traj::T
priority::P
end

Base.keys(t::PrioritizedTrajectory) = (:priority, keys(t.traj)...)
Expand All @@ -221,8 +221,8 @@ const CircularArrayPSARTTrajectory =
PrioritizedTrajectory{<:SumTree,<:CircularArraySARTTrajectory}

CircularArrayPSARTTrajectory(; capacity, kwargs...) = PrioritizedTrajectory(
SumTree(capacity),
CircularArraySARTTrajectory(; capacity = capacity, kwargs...),
SumTree(capacity),
)

#####
Expand Down
5 changes: 3 additions & 2 deletions src/policies/agents/trajectories/trajectory_extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ Base.@kwdef struct NStepBatchSampler{traces} <: AbstractSampler{traces}
end

function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler)
inds = rand(rng, 1:(length(t)-s.n+1), s.batch_size)
valid_range = isnothing(s.stack_size) ? (1:(length(t)-s.n+1)) : (s.stack_size:(length(t)-s.n+1))
inds = rand(rng, valid_range, s.batch_size)
inds, select(inds, t, s)
end

function StatsBase.sample(
rng::AbstractRNG,
t::PrioritizedTrajectory{<:SumTree},
t::PrioritizedTrajectory,
s::NStepBatchSampler,
)
bz, sz = s.batch_size, s.stack_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator
optimizer::O = nothing
end

(app::NeuralNetworkApproximator)(x) = app.model(x)
# some model may accept multiple inputs
(app::NeuralNetworkApproximator)(args...; kwargs...) = app.model(args...; kwargs...)


functor(x::NeuralNetworkApproximator) =
(model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer)
Expand Down
13 changes: 9 additions & 4 deletions src/utils/processors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@ function (p::StackFrames{T,N})(state::AbstractArray) where {T,N}
p
end

function Base.push!(cb::CircularArrayBuffer, p::StackFrames)
push!(cb, select_last_frame(p.buffer))
end

function RLBase.reset!(p::StackFrames{T,N}) where {T,N}
fill!(p.buffer, zero(T))
p
end

"""
When pushing a `StackFrames` into a `CircularArrayBuffer` of the same dimension,
only the latest frame is pushed. If the `StackFrames` is one dimension lower,
then it is treated as a general `AbstractArray` and is pushed in as a frame.
"""
function Base.push!(cb::CircularArrayBuffer{T,N}, p::StackFrames{T,N}) where {T,N}
push!(cb, select_last_frame(p.buffer))
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using CircularArrayBuffers
using ReinforcementLearningBase
using ReinforcementLearningCore
using Random
Expand Down
4 changes: 2 additions & 2 deletions test/utils/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
@testset "sum_tree" begin
t = SumTree(8)

@test capacity(t) == 8
@test RLCore.capacity(t) == 8

for i in 1:4
push!(t, i)
Expand All @@ -74,7 +74,7 @@
@test all([get(t, v)[1] == i for (i, v) in enumerate(0.5:1.0:8)])

empty!(t)
@test capacity(t) == 8
@test RLCore.capacity(t) == 8
@test length(t) == 0
end

Expand Down
23 changes: 23 additions & 0 deletions test/utils/processors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@testset "processors" begin
@testset "StackFrames" begin
cb = CircularArrayBuffer{Float32}(2,3,4)
s = StackFrames(2,3, 2)
push!(cb, s)
@test size(cb) == (2,3,1)

s(ones(Float32, 2,3))
@test s[:, :, 1] == zeros(2,3)
@test s[:, :, 2] == ones(2,3)

push!(cb, s)
@test size(cb) == (2,3,2)

s = StackFrames(2,3) # one dimension lower
s(ones(2))
s(2 * ones(2))
s(3 * ones(2))

push!(cb, s)
@test cb[:, :, end] == [1 2 3; 1 2 3]
end
end
1 change: 1 addition & 0 deletions test/utils/utils.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("base.jl")
include("device.jl")
include("processors.jl")