diff --git a/Project.toml b/Project.toml
index b9c4233..9ef979f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
diff --git a/src/components/trajectories/trajectory.jl b/src/components/trajectories/trajectory.jl
index da06ecb..b99a6af 100644
--- a/src/components/trajectories/trajectory.jl
+++ b/src/components/trajectories/trajectory.jl
@@ -3,13 +3,18 @@ export Trajectory,
     EpisodicTrajectory,
     CombinedTrajectory,
     CircularCompactSATrajectory,
+    VectCompactSATrajectory,
+    ElasticCompactSATrajectory,
     CircularCompactSALTrajectory,
     CircularCompactSARTSATrajectory,
+    VectCompactSARTSATrajectory,
+    ElasticCompactSARTSATrajectory,
     CircularCompactPSARTSATrajectory,
     CircularCompactSALRTSALTrajectory,
     CircularCompactPSALRTSALTrajectory
 
 using MacroTools: @forward
+using ElasticArrays
 
 #####
 # Trajectory
@@ -175,6 +180,37 @@ end
 
 isfull(t::CombinedTrajectory) = isfull(t.t1) && isfull(t.t2)
 
+#####
+# VectCompactSATrajectory
+#####
+
+const VectCompactSATrajectory = CombinedTrajectory{
+    <:SharedTrajectory{
+        <:Vector,
+        <:NamedTuple{(:state, :next_state, :full_state)},
+    },
+    <:SharedTrajectory{
+        <:Vector,
+        <:NamedTuple{(:action, :next_action, :full_action)},
+    }
+}
+
+function VectCompactSATrajectory(;
+    state_type = Int,
+    action_type = Int,
+    )
+    CombinedTrajectory(
+        SharedTrajectory(
+            Vector{state_type}(),
+            :state,
+        ),
+        SharedTrajectory(
+            Vector{action_type}(),
+            :action,
+        ),
+    )
+end
+
 #####
 # CircularCompactSATrajectory
 #####
@@ -209,6 +245,40 @@ function CircularCompactSATrajectory(;
     )
 end
 
+#####
+# ElasticCompactSATrajectory
+#####
+
+const ElasticCompactSATrajectory = CombinedTrajectory{
+    <:SharedTrajectory{
+        <:ElasticArray,
+        <:NamedTuple{(:state, :next_state, :full_state)},
+    },
+    <:SharedTrajectory{
+        <:ElasticArray,
+        <:NamedTuple{(:action, :next_action, :full_action)},
+    },
+}
+
+function ElasticCompactSATrajectory(;
+    state_type = Int,
+    state_size = (),
+    action_type = Int,
+    action_size = (),
+)
+    CombinedTrajectory(
+        SharedTrajectory(
+            ElasticArray{state_type}(undef, state_size..., 0),
+            :state,
+        ),
+        SharedTrajectory(
+            ElasticArray{action_type}(undef, action_size..., 0),
+            :action,
+        ),
+    )
+end
+
+
 #####
 # CircularCompactSALTrajectory
 #####
@@ -240,6 +310,35 @@ function CircularCompactSALTrajectory(;
         CircularCompactSATrajectory(; capacity = capacity, kw...),
     )
 end
+
+#####
+# VectCompactSARTSATrajectory
+#####
+
+const VectCompactSARTSATrajectory = CombinedTrajectory{
+    <:Trajectory{
+        <:NamedTuple{
+            (:reward, :terminal),
+            <:Tuple{<:Vector,<:Vector},
+        },
+    },
+    <:VectCompactSATrajectory,
+}
+
+function VectCompactSARTSATrajectory(;
+    reward_type = Float32,
+    terminal_type = Bool,
+    kw...,
+)
+    CombinedTrajectory(
+        Trajectory(
+            reward = Vector{reward_type}(),
+            terminal = Vector{terminal_type}(),
+        ),
+        VectCompactSATrajectory(; kw...),
+    )
+end
+
 #####
 # CircularCompactSARTSATrajectory
 #####
@@ -271,6 +370,37 @@ function CircularCompactSARTSATrajectory(;
     )
 end
 
+#####
+# ElasticCompactSARTSATrajectory
+#####
+
+const ElasticCompactSARTSATrajectory = CombinedTrajectory{
+    <:Trajectory{
+        <:NamedTuple{
+            (:reward, :terminal),
+            <:Tuple{<:ElasticArray,<:ElasticArray},
+        },
+    },
+    <:ElasticCompactSATrajectory,
+}
+
+function ElasticCompactSARTSATrajectory(;
+    reward_type = Float32,
+    reward_size = (),
+    terminal_type = Bool,
+    terminal_size = (),
+    kw...,
+)
+    CombinedTrajectory(
+        Trajectory(
+            reward = ElasticArray{reward_type}(undef, reward_size..., 0),
+            terminal = ElasticArray{terminal_type}(undef, terminal_size..., 0),
+        ),
+        ElasticCompactSATrajectory(; kw...),
+    )
+end
+
+
 #####
 # CircularCompactSALRTSALTrajectory
 #####
diff --git a/src/extensions/ElasticArrays.jl b/src/extensions/ElasticArrays.jl
new file mode 100644
index 0000000..5820b2b
--- /dev/null
+++ b/src/extensions/ElasticArrays.jl
@@ -0,0 +1,17 @@
+using ElasticArrays
+
+# `push!` on an `ElasticArray` simply forwards to `append!`.
+Base.push!(a::ElasticArray, x) = append!(a, x)
+
+# Empty `a` by resizing its last dimension to zero.
+Base.empty!(a::ElasticArray) = ElasticArrays.resize_lastdim!(a, 0)
+
+# Remove the last slice along the last dimension of `a` and return it.
+function Base.pop!(a::ElasticArray)
+    n = size(a, ndims(a))
+    # Copy before shrinking: `selectdim` returns a view into `a`, which would
+    # otherwise refer to storage past the end of the resized array.
+    last_frame = copy(selectdim(a, ndims(a), n))
+    ElasticArrays.resize_lastdim!(a, n - 1)
+    return last_frame
+end
diff --git a/src/extensions/extensions.jl b/src/extensions/extensions.jl
index 3fab88b..9375b69 100644
--- a/src/extensions/extensions.jl
+++ b/src/extensions/extensions.jl
@@ -2,3 +2,4 @@ include("Flux.jl")
 include("CUDA.jl")
 include("Zygote.jl")
 include("ReinforcementLearningBase.jl")
+include("ElasticArrays.jl")
diff --git a/test/components/trajectories.jl b/test/components/trajectories.jl
index 5793dc2..844a2b2 100644
--- a/test/components/trajectories.jl
+++ b/test/components/trajectories.jl
@@ -155,4 +155,32 @@
         @test t[:full_state] == []
         @test t[:full_action] == []
     end
+
+    @testset "VectCompactSARTSATrajectory" begin
+        t = VectCompactSARTSATrajectory(;state_type=Vector{Float32}, action_type=Int, reward_type=Float32, terminal_type=Bool)
+        push!(t; state=Float32[1,1], action=1)
+        push!(t; reward=1f0, terminal=false, state=Float32[2,2], action=2)
+        push!(t; reward=2f0, terminal=true, state=Float32[3,3], action=3)
+
+        @test t[:state] == [Float32[1,1], Float32[2,2]]
+        @test t[:action] == [1,2]
+        @test t[:reward] == [1f0,2f0]
+        @test t[:terminal] == [false,true]
+        @test t[:next_state] == [Float32[2,2], Float32[3,3]]
+        @test t[:next_action] == [2,3]
+    end
+
+    @testset "ElasticCompactSARTSATrajectory" begin
+        t = ElasticCompactSARTSATrajectory(;state_type=Float32, state_size=(2,), action_type=Int, reward_type=Float32, terminal_type=Bool)
+        push!(t; state=Float32[1,1], action=1)
+        push!(t; reward=1f0, terminal=false, state=Float32[2,2], action=2)
+        push!(t; reward=2f0, terminal=true, state=Float32[3,3], action=3)
+
+        @test t[:state] == Float32[1 2; 1 2]
+        @test t[:action] == [1, 2]
+        @test t[:reward] == [1f0, 2f0]
+        @test t[:terminal] == [false, true]
+        @test t[:next_state] == Float32[2 3; 2 3]
+        @test t[:next_action] == [2, 3]
+    end
 end