Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions src/core/stop_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,6 @@ function (s::StopAfterEpisode)(agent, env)
s.cur >= s.episode
end

#####
# StopAfterNoImprovement
#####

# Mutable state for a stop condition that watches a metric through a circular buffer.
mutable struct StopAfterNoImprovement{F,B<:CircularArrayBuffer,T<:Number}
# Zero-argument callable; its return value is the metric sample.
fn::F
# Circular buffer holding metric samples; the outer constructor builds it as
# `CircularArrayBuffer{T}(1, patience)`.
buffer::B
# Tolerance subtracted from `peak` when samples are compared against it.
δ::T
# Best (largest) metric value recorded so far.
peak::T
end

"""
StopAfterNoImprovement()

Expand All @@ -134,16 +123,34 @@ patience: Number of epochs with no improvement after which training will be stop

Return `true` after the monitored metric has stopped improving.
"""
function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number}
StopAfterNoImprovement(fn, CircularArrayBuffer{T}(1, patience), δ, T(0))
# Mutable state for a stop condition that counts evaluations without improvement.
mutable struct StopAfterNoImprovement{T<:Number,F}
# Zero-argument callable; its return value is the monitored metric.
fn::F
# Threshold the `counter` field is compared against to decide when to stop.
patience::Int
# Tolerance subtracted from `peak` before the metric is compared with it.
δ::T
# Largest metric value accepted so far; the outer constructor starts it at `typemin(T)`.
peak::T
# Evaluation counter; the outer constructor starts it at 1.
counter::Int
end

# Convenience constructor: build a fresh stop condition from the metric callable `fn`,
# the `patience` threshold and the tolerance `δ` (defaults to `0.0f0`).
# The peak starts at the smallest value of `T` and the counter starts at 1.
function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number}
    initial_peak = typemin(T)
    initial_counter = 1
    return StopAfterNoImprovement(fn, patience, δ, initial_peak, initial_counter)
end

# Stop-condition check, invoked as `s(agent, env)`.
#
# Returns `false` immediately (leaving `s` untouched) unless `is_terminated(env)` holds.
# Otherwise it samples the metric with `s.fn()` and compares it with `s.peak - s.δ`:
#   * strictly above that threshold: reset `s.counter` to 1, raise `s.peak` if the sample
#     exceeds it, and return `false`;
#   * otherwise: increment `s.counter` and return `true` once it exceeds `s.patience`.
#
# The comparison must use the peak from *before* this sample is folded in. Updating the
# peak first would make a new best value fail `val > s.peak - s.δ` whenever `δ == 0`.
function (s::StopAfterNoImprovement)(agent, env)::Bool
    is_terminated(env) || return false # post episode stage

    val = s.fn()
    if val > s.peak - s.δ
        # Sample is above the tolerance band below the peak: restart the count.
        s.counter = 1
        s.peak = max(val, s.peak)
        return false
    end

    # One more evaluation without improvement.
    s.counter += 1
    return s.counter > s.patience
end

#####
Expand Down