diff --git a/src/policies/agents/trajectories/trajectory.jl b/src/policies/agents/trajectories/trajectory.jl index a6c8795..976d75d 100644 --- a/src/policies/agents/trajectories/trajectory.jl +++ b/src/policies/agents/trajectories/trajectory.jl @@ -254,6 +254,7 @@ CircularArrayPSARTTrajectory(; capacity, kwargs...) = PrioritizedTrajectory( function Base.length( t::Union{ CircularArraySARTTrajectory, + CircularArraySLARTTrajectory, CircularVectorSARTSATrajectory, ElasticSARTTrajectory, }, diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 69eacb0..44e80ff 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -140,7 +140,7 @@ end function fetch!( sampler::NStepBatchSampler{traces}, - traj::CircularArraySARTTrajectory, + traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, inds::Vector{Int}, ) where {traces} γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl index f22a672..8fb8eb9 100644 --- a/test/components/trajectories.jl +++ b/test/components/trajectories.jl @@ -51,12 +51,18 @@ @testset "CircularArraySLARTTrajectory" begin t = CircularArraySLARTTrajectory( capacity = 3, - state = Matrix{Float32} => (2,2), + state = Vector{Int} => (4,), legal_actions_mask = Vector{Bool} => (4, ), ) # test instance type is same as type @test isa(t, CircularArraySLARTTrajectory) + + @test length(t) == 0 + push!(t; state = ones(Int, 4), action = 1, legal_actions_mask = trues(4)) + @test length(t) == 0 + push!(t; reward = 1.0f0, terminal = false) + @test length(t) == 1 end @testset "ReservoirTrajectory" begin