From 52c17289c5187f6578c2b889dd5e7723f4b1b350 Mon Sep 17 00:00:00 2001 From: norci Date: Sun, 21 Feb 2021 22:43:14 +0800 Subject: [PATCH] add a new stop condition which sets the time budget of each experiment. --- src/core/stop_conditions.jl | 26 +++++++++++++++++++++++++- test/core/stop_conditions_test.jl | 7 +++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index 55ca7e7..308c412 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -3,7 +3,8 @@ export StopAfterStep, StopWhenDone, ComposedStopCondition, StopSignal, - StopAfterNoImprovement + StopAfterNoImprovement, + StopAfterNSeconds using ProgressMeter using CircularArrayBuffers: CircularArrayBuffer, isfull @@ -185,3 +186,26 @@ Base.getindex(s::StopSignal) = s.is_stop[] Base.setindex!(s::StopSignal, v::Bool) = s.is_stop[] = v (s::StopSignal)(agent, env) = s[] + +""" +StopAfterNSeconds + +parameter: +1. time budget + +stop training after N seconds + +""" +Base.@kwdef mutable struct StopAfterNSeconds + budget::Float64 + deadline::Float64 = 0.0 +end +function RLBase.reset!(s::StopAfterNSeconds) + s.deadline = time() + s.budget + s +end +function StopAfterNSeconds(budget::Float64) + s = StopAfterNSeconds(; budget) + RLBase.reset!(s) +end +(s::StopAfterNSeconds)(_...) = time() > s.deadline diff --git a/test/core/stop_conditions_test.jl b/test/core/stop_conditions_test.jl index 92efc0a..c186781 100644 --- a/test/core/stop_conditions_test.jl +++ b/test/core/stop_conditions_test.jl @@ -16,3 +16,10 @@ @test argmax(total_reward_per_episode.rewards) != patience end + +@testset "StopAfterNSeconds" begin + s = StopAfterNSeconds(0.01) + @test !s() + sleep(0.02) + @test s() +end