diff --git a/src/policies/agents/agent.jl b/src/policies/agents/agent.jl index 5d38b3b..cb8fffd 100644 --- a/src/policies/agents/agent.jl +++ b/src/policies/agents/agent.jl @@ -42,7 +42,7 @@ RLBase.update!(p::AbstractPolicy, t::AbstractTrajectory, ::AbstractEnv, ::PreAct RLBase.update!(::AbstractTrajectory, ::AbstractPolicy, ::AbstractEnv, ::AbstractStage) = nothing function RLBase.update!( - trajectory::CircularArraySARTTrajectory, + trajectory::Union{CircularArraySARTTrajectory, PrioritizedTrajectory{<:CircularArraySARTTrajectory}}, ::AbstractPolicy, ::AbstractEnv, ::PreEpisodeStage, @@ -54,7 +54,7 @@ function RLBase.update!( end function RLBase.update!( - trajectory::CircularArraySLARTTrajectory, + trajectory::Union{CircularArraySLARTTrajectory, PrioritizedTrajectory{<:CircularArraySLARTTrajectory}}, ::AbstractPolicy, ::AbstractEnv, ::PreEpisodeStage, @@ -67,7 +67,7 @@ function RLBase.update!( end function RLBase.update!( - trajectory::CircularArraySARTTrajectory, + trajectory::Union{CircularArraySARTTrajectory,PrioritizedTrajectory{<:CircularArraySARTTrajectory}}, policy::AbstractPolicy, env::AbstractEnv, ::Union{PreActStage, PostEpisodeStage}, @@ -79,7 +79,7 @@ function RLBase.update!( end function RLBase.update!( - trajectory::CircularArraySLARTTrajectory, + trajectory::Union{CircularArraySLARTTrajectory,PrioritizedTrajectory{<:CircularArraySLARTTrajectory}}, policy::AbstractPolicy, env::AbstractEnv, ::Union{PreActStage, PostEpisodeStage}, diff --git a/src/policies/agents/trajectories/trajectory.jl b/src/policies/agents/trajectories/trajectory.jl index 37de65a..a95ab6a 100644 --- a/src/policies/agents/trajectories/trajectory.jl +++ b/src/policies/agents/trajectories/trajectory.jl @@ -201,9 +201,9 @@ end ##### -Base.@kwdef struct PrioritizedTrajectory{P,T} <: AbstractTrajectory - priority::P +Base.@kwdef struct PrioritizedTrajectory{T,P} <: AbstractTrajectory traj::T + priority::P end Base.keys(t::PrioritizedTrajectory) = (:priority, keys(t.traj)...) 
@@ -221,8 +221,8 @@ const CircularArrayPSARTTrajectory = - PrioritizedTrajectory{<:SumTree,<:CircularArraySARTTrajectory} + PrioritizedTrajectory{<:CircularArraySARTTrajectory,<:SumTree} CircularArrayPSARTTrajectory(; capacity, kwargs...) = PrioritizedTrajectory( - SumTree(capacity), CircularArraySARTTrajectory(; capacity = capacity, kwargs...), + SumTree(capacity), ) ##### diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 129e676..9634525 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -88,13 +88,14 @@ Base.@kwdef struct NStepBatchSampler{traces} <: AbstractSampler{traces} end function StatsBase.sample(rng::AbstractRNG, t::AbstractTrajectory, s::NStepBatchSampler) - inds = rand(rng, 1:(length(t)-s.n+1), s.batch_size) + valid_range = isnothing(s.stack_size) ? (1:(length(t)-s.n+1)) : (s.stack_size:(length(t)-s.n+1)) + inds = rand(rng, valid_range, s.batch_size) inds, select(inds, t, s) end function StatsBase.sample( rng::AbstractRNG, - t::PrioritizedTrajectory{<:SumTree}, + t::PrioritizedTrajectory, s::NStepBatchSampler, ) bz, sz = s.batch_size, s.stack_size diff --git a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index fcca965..23b3595 100644 --- a/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -18,7 +18,9 @@ Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator optimizer::O = nothing end -(app::NeuralNetworkApproximator)(x) = app.model(x) +# some models may accept multiple inputs +(app::NeuralNetworkApproximator)(args...; kwargs...) = app.model(args...; kwargs...) 
+ functor(x::NeuralNetworkApproximator) = (model = x.model,), y -> NeuralNetworkApproximator(y.model, x.optimizer) diff --git a/src/utils/processors.jl b/src/utils/processors.jl index c5599f6..7d5a09d 100644 --- a/src/utils/processors.jl +++ b/src/utils/processors.jl @@ -52,11 +52,16 @@ function (p::StackFrames{T,N})(state::AbstractArray) where {T,N} p end -function Base.push!(cb::CircularArrayBuffer, p::StackFrames) - push!(cb, select_last_frame(p.buffer)) -end - function RLBase.reset!(p::StackFrames{T,N}) where {T,N} fill!(p.buffer, zero(T)) p end + +""" +When pushing a `StackFrames` into a `CircularArrayBuffer` of the same dimension, +only the latest frame is pushed. If the `StackFrames` is one dimension lower, +then it is treated as a general `AbstractArray` and is pushed in as a frame. +""" +function Base.push!(cb::CircularArrayBuffer{T,N}, p::StackFrames{T,N}) where {T,N} + push!(cb, select_last_frame(p.buffer)) +end diff --git a/test/Project.toml b/test/Project.toml index 988160a..837a81c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index a563f68..f752f46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using CircularArrayBuffers using ReinforcementLearningBase using ReinforcementLearningCore using Random diff --git a/test/utils/base.jl b/test/utils/base.jl index 1d8618d..4e893ce 100644 --- a/test/utils/base.jl +++ b/test/utils/base.jl @@ -52,7 +52,7 @@ @testset "sum_tree" begin t = SumTree(8) - @test capacity(t) == 8 + @test RLCore.capacity(t) == 8 for i in 1:4 push!(t, i) @@ -74,7 +74,7 @@ @test all([get(t, v)[1] == i for (i, v) in enumerate(0.5:1.0:8)]) empty!(t) - @test capacity(t) == 8 + @test 
RLCore.capacity(t) == 8 @test length(t) == 0 end diff --git a/test/utils/processors.jl b/test/utils/processors.jl new file mode 100644 index 0000000..105eea4 --- /dev/null +++ b/test/utils/processors.jl @@ -0,0 +1,23 @@ +@testset "processors" begin + @testset "StackFrames" begin + cb = CircularArrayBuffer{Float32}(2,3,4) + s = StackFrames(2,3, 2) + push!(cb, s) + @test size(cb) == (2,3,1) + + s(ones(Float32, 2,3)) + @test s[:, :, 1] == zeros(2,3) + @test s[:, :, 2] == ones(2,3) + + push!(cb, s) + @test size(cb) == (2,3,2) + + s = StackFrames(2,3) # one dimension lower + s(ones(2)) + s(2 * ones(2)) + s(3 * ones(2)) + + push!(cb, s) + @test cb[:, :, end] == [1 2 3; 1 2 3] + end +end \ No newline at end of file diff --git a/test/utils/utils.jl b/test/utils/utils.jl index e122ec8..a1ad204 100644 --- a/test/utils/utils.jl +++ b/test/utils/utils.jl @@ -1,2 +1,3 @@ include("base.jl") include("device.jl") +include("processors.jl")