From 189e78026472a72c522ca83a61f659ae438f372f Mon Sep 17 00:00:00 2001
From: norci
Date: Fri, 29 Jan 2021 17:37:15 +0800
Subject: [PATCH] refactor StopAfterNoImprovement

---
 src/core/stop_conditions.jl | 39 ++++++++++++++++++++++----------------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl
index 6e9fd28..55ca7e7 100644
--- a/src/core/stop_conditions.jl
+++ b/src/core/stop_conditions.jl
@@ -105,17 +105,6 @@ function (s::StopAfterEpisode)(agent, env)
     s.cur >= s.episode
 end
 
-#####
-# StopAfterNoImprovement
-#####
-
-mutable struct StopAfterNoImprovement{F,B<:CircularArrayBuffer,T<:Number}
-    fn::F
-    buffer::B
-    δ::T
-    peak::T
-end
-
 """
     StopAfterNoImprovement()
 
@@ -134,16 +123,34 @@ patience: Number of epochs with no improvement after which training will be stop
 
 Return `true` after the monitored metric has stopped improving.
 """
-function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number}
-    StopAfterNoImprovement(fn, CircularArrayBuffer{T}(1, patience), δ, T(0))
+mutable struct StopAfterNoImprovement{T<:Number,F}
+    fn::F
+    patience::Int
+    δ::T
+    peak::T
+    counter::Int
+end
+
+function StopAfterNoImprovement(
+    fn,
+    patience::Int,
+    δ::T = 0.0f0,
+) where {T<:Number}
+    StopAfterNoImprovement(fn, patience, δ, typemin(T), 1)
 end
 
 function (s::StopAfterNoImprovement)(agent, env)::Bool
     is_terminated(env) || return false # post episode stage
     val = s.fn()
-    s.peak = max(val, s.peak)
-    push!(s.buffer, val)
-    isfull(s.buffer) ? all(s.buffer .< (s.peak - s.δ)) : false
+    if val > s.peak - s.δ
+        s.counter = 1
+        s.peak = max(val, s.peak)
+        return false
+    else
+        s.counter += 1
+        return s.counter > s.patience
+    end
+    return false
 end
 
 #####