Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit c70f865

Browse files
authored
add a new stop condition (#220)
which sets the time budget of each experiment.
1 parent 43a010d commit c70f865

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/core/stop_conditions.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ export StopAfterStep,
33
StopWhenDone,
44
ComposedStopCondition,
55
StopSignal,
6-
StopAfterNoImprovement
6+
StopAfterNoImprovement,
7+
StopAfterNSeconds
78

89
using ProgressMeter
910
using CircularArrayBuffers: CircularArrayBuffer, isfull
@@ -185,3 +186,26 @@ Base.getindex(s::StopSignal) = s.is_stop[]
185186
# Store `v` in the signal's underlying flag and return the stored value.
function Base.setindex!(s::StopSignal, v::Bool)
    s.is_stop[] = v
    return v
end
186187

187188
# Used as a stop condition, the signal reports its current flag (read via `s[]`);
# the `agent` and `env` arguments are accepted but not consulted.
function (s::StopSignal)(agent, env)
    return s[]
end
189+
190+
"""
191+
StopAfterNSeconds
192+
193+
parameter:
194+
1. time budget
195+
196+
stop training after N seconds
197+
198+
"""
199+
Base.@kwdef mutable struct StopAfterNSeconds
    # Time budget of the experiment, in seconds.
    budget::Float64
    # Point in time (as returned by `time()`) after which the condition fires.
    # Defaults to 0.0 until it is (re)armed.
    deadline::Float64 = 0.0
end
203+
# Re-arm the stop condition: the new deadline is the current `time()` plus the
# configured budget. Returns `s` itself.
function RLBase.reset!(s::StopAfterNSeconds)
    started_at = time()
    s.deadline = started_at + s.budget
    return s
end
207+
"""
    StopAfterNSeconds(budget::Real)

Create a `StopAfterNSeconds` with a time budget of `budget` seconds and start the
countdown immediately by calling `RLBase.reset!` on it.

`budget` may be any `Real` (for example `StopAfterNSeconds(10)`); it is converted to
`Float64` when stored in the struct.
"""
function StopAfterNSeconds(budget::Real)
    s = StopAfterNSeconds(; budget = budget)
    # `reset!` sets the deadline relative to "now" and returns the condition itself.
    return RLBase.reset!(s)
end
211+
# Calling the condition reports whether the deadline has passed. Any positional
# arguments are accepted and ignored.
function (s::StopAfterNSeconds)(_...)
    return s.deadline < time()
end

test/core/stop_conditions_test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,10 @@
1616

1717
@test argmax(total_reward_per_episode.rewards) != patience
1818
end
19+
20+
@testset "StopAfterNSeconds" begin
    # A generous budget cannot run out during the check, even on a slow or heavily
    # loaded machine (the first call also pays JIT compilation time).
    pending = StopAfterNSeconds(60.0)
    @test !pending()
    @test !pending(nothing, nothing)

    # A tiny budget must be reported as exhausted once it has elapsed.
    elapsed = StopAfterNSeconds(0.01)
    sleep(0.02)
    @test elapsed()
end

0 commit comments

Comments
 (0)