From ee62ece9c0be8899c8965c139698a0e66f633d0b Mon Sep 17 00:00:00 2001 From: Ilan Coulon Date: Thu, 1 Apr 2021 15:54:49 +0100 Subject: [PATCH 1/4] Test length on CircularArraySLARTTrajectory --- test/components/trajectories.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl index f22a672..0d8ce9e 100644 --- a/test/components/trajectories.jl +++ b/test/components/trajectories.jl @@ -57,6 +57,12 @@ # test instance type is same as type @test isa(t, CircularArraySLARTTrajectory) + + @test length(t) == 0 + push!(t; state = ones(Int, 4), action = 1, legal_actions_mask = trues(4)) + @test length(t) == 0 + push!(t; reward = 1.0f0, terminal = false,) + @test length(t) == 1 end @testset "ReservoirTrajectory" begin From c7a6152bc0072da9ee0e3cf5d8e49d100bed23dd Mon Sep 17 00:00:00 2001 From: Ilan Coulon Date: Thu, 1 Apr 2021 15:55:42 +0100 Subject: [PATCH 2/4] Add support for CircularArraySLARTTrajectory for length method --- src/policies/agents/trajectories/trajectory.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/policies/agents/trajectories/trajectory.jl b/src/policies/agents/trajectories/trajectory.jl index a6c8795..976d75d 100644 --- a/src/policies/agents/trajectories/trajectory.jl +++ b/src/policies/agents/trajectories/trajectory.jl @@ -254,6 +254,7 @@ CircularArrayPSARTTrajectory(; capacity, kwargs...) = PrioritizedTrajectory( function Base.length( t::Union{ CircularArraySARTTrajectory, + CircularArraySLARTTrajectory, CircularVectorSARTSATrajectory, ElasticSARTTrajectory, }, From e1693939839577526aa38c0726780918c8aa6fdc Mon Sep 17 00:00:00 2001 From: Ilan Coulon Date: Thu, 1 Apr 2021 16:07:05 +0100 Subject: [PATCH 3/4] Add CircularArraySLARTTrajectory for fetch! method --- src/policies/agents/trajectories/trajectory_extension.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/policies/agents/trajectories/trajectory_extension.jl b/src/policies/agents/trajectories/trajectory_extension.jl index 69eacb0..44e80ff 100644 --- a/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/policies/agents/trajectories/trajectory_extension.jl @@ -140,7 +140,7 @@ end function fetch!( sampler::NStepBatchSampler{traces}, - traj::CircularArraySARTTrajectory, + traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, inds::Vector{Int}, ) where {traces} γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size From f9611a164b6c8970835e069c06bc971034a1008d Mon Sep 17 00:00:00 2001 From: Ilan Coulon Date: Thu, 1 Apr 2021 18:00:20 +0100 Subject: [PATCH 4/4] Fix length test --- test/components/trajectories.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl index 0d8ce9e..8fb8eb9 100644 --- a/test/components/trajectories.jl +++ b/test/components/trajectories.jl @@ -51,7 +51,7 @@ @testset "CircularArraySLARTTrajectory" begin t = CircularArraySLARTTrajectory( capacity = 3, - state = Matrix{Float32} => (2,2), + state = Vector{Int} => (4,), legal_actions_mask = Vector{Bool} => (4, ), ) @@ -61,7 +61,7 @@ @test length(t) == 0 push!(t; state = ones(Int, 4), action = 1, legal_actions_mask = trues(4)) @test length(t) == 0 - push!(t; reward = 1.0f0, terminal = false,) + push!(t; reward = 1.0f0, terminal = false) @test length(t) == 1 end