From 1a0144277bd9e95b4ade1bf4aef9383810eb760f Mon Sep 17 00:00:00 2001
From: norci
Date: Mon, 1 Feb 2021 19:43:28 +0800
Subject: [PATCH] In DoEveryNEpisode, added a keyword argument stage. Then it
 can be used for PreEpisodeStage & PostEpisodeStage.

---
 src/core/hooks.jl  | 10 ++++++----
 test/core/hooks.jl | 10 ++++++++++
 test/runtests.jl   |  1 +
 3 files changed, 17 insertions(+), 4 deletions(-)
 create mode 100644 test/core/hooks.jl

diff --git a/src/core/hooks.jl b/src/core/hooks.jl
index 100f918..a802f81 100644
--- a/src/core/hooks.jl
+++ b/src/core/hooks.jl
@@ -147,7 +147,7 @@ function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env)
 end
 
 #####
-# TotalBatchRewardPerEpisode 
+# TotalBatchRewardPerEpisode
 #####
 struct TotalBatchRewardPerEpisode <: AbstractHook
     rewards::Vector{Vector{Float64}}
@@ -260,15 +260,17 @@ end
 Execute `f(agent, env)` every `n` episode.
 `t` is a counter of steps.
 """
-Base.@kwdef mutable struct DoEveryNEpisode{F} <: AbstractHook
+Base.@kwdef mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <:
+                           AbstractHook
     f::F
     n::Int = 1
     t::Int = 0
 end
 
-DoEveryNEpisode(f, n = 1, t = 0) = DoEveryNEpisode(f, n, t)
+DoEveryNEpisode(f::F, n = 1, t = 0; stage::S = POST_EPISODE_STAGE) where {S,F} =
+    DoEveryNEpisode{S,F}(f, n, t)
 
-function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env)
+function (hook::DoEveryNEpisode{S})(::S, agent, env) where {S}
     hook.t += 1
     if hook.t % hook.n == 0
         hook.f(hook.t, agent, env)
diff --git a/test/core/hooks.jl b/test/core/hooks.jl
new file mode 100644
index 0000000..02f6319
--- /dev/null
+++ b/test/core/hooks.jl
@@ -0,0 +1,10 @@
+let stages = (POST_EPISODE_STAGE, PRE_EPISODE_STAGE)
+    @testset "DoEveryNEpisode stage=$(stage),s2=$(s2)" for stage in stages, s2 in stages
+        hook = DoEveryNEpisode((x...) -> true; stage)
+        if stage === s2
+            @test hook(stage, nothing, nothing)
+        else
+            @test hook(s2, nothing, nothing) === nothing
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ff4fc19..ce41694 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -12,6 +12,7 @@ using CUDA
 
 @testset "ReinforcementLearningCore.jl" begin
     include("core/core.jl")
+    include("core/hooks.jl")
     include("core/stop_conditions_test.jl")
     include("components/components.jl")
     include("utils/utils.jl")