Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/core/stop_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ end
# StopAfterNoImprovement
#####

# Stop condition state for "no improvement" early stopping.
#
# The struct is mutable because `peak` is reassigned every time the stop
# condition is evaluated at the end of an episode.
#
# Fields:
# - `fn`:     zero-argument callable returning the monitored scalar metric.
# - `buffer`: circular buffer holding the most recent metric values.
# - `δ`:      margin below `peak` that a value must stay under to count as
#             "not improved".
# - `peak`:   best (largest) metric value observed so far.
mutable struct StopAfterNoImprovement{F,B<:CircularArrayBuffer,T<:Number}
    fn::F
    buffer::B
    δ::T
    peak::T
end

"""
Expand All @@ -123,7 +124,7 @@ Stop training when a monitored metric has stopped improving.
Parameters:

fn: a closure that returns a scalar value indicating the performance of the policy (the higher the better)
e.g.
e.g.
1. () -> reward(env)
1. () -> total_reward_per_episode.reward

Expand All @@ -134,15 +135,15 @@ patience: Number of epochs with no improvement after which training will be stop
Return `true` after the monitored metric has stopped improving.
"""
function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number}
    # The buffer keeps the last `patience` metric values (one value per column).
    buffer = CircularArrayBuffer{T}(1, patience)
    # Seed `peak` with the smallest representable value rather than zero, so
    # that metrics which are always negative (e.g. cost-style rewards) still
    # update the peak on the first evaluation instead of being compared
    # against an artificial peak of 0.
    return StopAfterNoImprovement(fn, buffer, δ, typemin(T))
end

# Evaluate the stop condition.
#
# Returns `false` unless the environment has terminated (post episode stage).
# Otherwise the metric `s.fn()` is sampled, folded into the running peak and
# pushed into the buffer; the result is `true` only when the buffer is full
# and every buffered value lies strictly below `s.peak - s.δ`.
function (s::StopAfterNoImprovement)(agent, env)::Bool
    is_terminated(env) || return false # post episode stage
    val = s.fn()
    s.peak = max(val, s.peak)
    push!(s.buffer, val)
    # Until the buffer has filled up there is not enough history to decide.
    threshold = s.peak - s.δ
    return isfull(s.buffer) && all(v -> v < threshold, s.buffer)
end

#####
Expand Down