Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 6aa3b6b

Browse files
authored
In DoEveryNEpisode, added a keyword argument stage. (#214)
Then it can be used for PreEpisodeStage & PostEpisodeStage.
1 parent 513d577 commit 6aa3b6b

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

src/core/hooks.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env)
147147
end
148148

149149
#####
150-
# TotalBatchRewardPerEpisode
150+
# TotalBatchRewardPerEpisode
151151
#####
152152
struct TotalBatchRewardPerEpisode <: AbstractHook
153153
rewards::Vector{Vector{Float64}}
@@ -260,15 +260,17 @@ end
260260
Execute `f(agent, env)` every `n` episode.
261261
`t` is a counter of steps.
262262
"""
263-
Base.@kwdef mutable struct DoEveryNEpisode{F} <: AbstractHook
263+
Base.@kwdef mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <:
264+
AbstractHook
264265
f::F
265266
n::Int = 1
266267
t::Int = 0
267268
end
268269

269-
DoEveryNEpisode(f, n = 1, t = 0) = DoEveryNEpisode(f, n, t)
270+
DoEveryNEpisode(f::F, n = 1, t = 0; stage::S = POST_EPISODE_STAGE) where {S,F} =
271+
DoEveryNEpisode{S,F}(f, n, t)
270272

271-
function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env)
273+
function (hook::DoEveryNEpisode{S})(::S, agent, env) where {S}
272274
hook.t += 1
273275
if hook.t % hook.n == 0
274276
hook.f(hook.t, agent, env)

test/core/hooks.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
let stages = (POST_EPISODE_STAGE, PRE_EPISODE_STAGE)
2+
@testset "DoEveryNEpisode stage=$(stage),s2=$(s2)" for stage in stages, s2 in stages
3+
hook = DoEveryNEpisode((x...) -> true; stage)
4+
if stage === s2
5+
@test hook(stage, nothing, nothing)
6+
else
7+
@test hook(s2, nothing, nothing) === nothing
8+
end
9+
end
10+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using CUDA
1212

1313
@testset "ReinforcementLearningCore.jl" begin
1414
include("core/core.jl")
15+
include("core/hooks.jl")
1516
include("core/stop_conditions_test.jl")
1617
include("components/components.jl")
1718
include("utils/utils.jl")

0 commit comments

Comments
 (0)