diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index 649a7a3..6e9fd28 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -109,10 +109,11 @@ end # StopAfterNoImprovement ##### -Base.@kwdef struct StopAfterNoImprovement{F,B<:CircularArrayBuffer,T<:Number} +mutable struct StopAfterNoImprovement{F,B<:CircularArrayBuffer,T<:Number} fn::F buffer::B - δ::T = 0.0 + δ::T + peak::T end """ @@ -123,7 +124,7 @@ Stop training when a monitored metric has stopped improving. Parameters: fn: a closure, return a scalar value, which indicates the performance of the policy (the higher the better) -e.g. +e.g. 1. () -> reward(env) 1. () -> total_reward_per_episode.reward @@ -134,15 +135,15 @@ patience: Number of epochs with no improvement after which training will be stop Return `true` after the monitored metric has stopped improving. """ function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number} - StopAfterNoImprovement(fn = fn, buffer = CircularArrayBuffer{T}(1, patience), δ = δ) + StopAfterNoImprovement(fn, CircularArrayBuffer{T}(1, patience), δ, T(0)) end function (s::StopAfterNoImprovement)(agent, env)::Bool is_terminated(env) || return false # post episode stage val = s.fn() - improved = isfull(s.buffer) ? all(s.buffer .< (val - s.δ)) : true + s.peak = max(val, s.peak) push!(s.buffer, val) - !improved + isfull(s.buffer) ? all(s.buffer .< (s.peak - s.δ)) : false end #####