From baf25d56b3d1eb80c89c020fb009130f02fa3532 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 26 Jun 2020 13:48:34 +0800 Subject: [PATCH] extend existing hook to recoganize RewardOverriddenObs --- src/core/hooks.jl | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/core/hooks.jl b/src/core/hooks.jl index 26c42df..6fb5998 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -112,6 +112,10 @@ function (hook::RewardsPerEpisode)(::PostActStage, agent, env, obs) push!(hook.rewards[end], get_reward(obs)) end +function (hook::RewardsPerEpisode)(::PostActStage, agent, env, obs::RewardOverriddenObs) + push!(hook.rewards[end], get_reward(obs.obs)) +end + function (hook::RewardsPerEpisode)(::PostEpisodeStage, agent, env, obs) @debug hook.tag REWARDS_PER_EPISODE = hook.rewards[end] end @@ -124,6 +128,9 @@ end TotalRewardPerEpisode(; rewards = Float64[], reward = 0.0, tag = "TRAINING") Store the total rewards of each episode in the field of `rewards`. + +!!! note + If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. """ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook rewards::Vector{Float64} = Float64[] @@ -135,6 +142,10 @@ function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, obs) hook.reward += get_reward(obs) end +function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, obs::RewardOverriddenObs) + hook.reward += get_reward(obs.obs) +end + function (hook::TotalRewardPerEpisode)( ::Union{PostEpisodeStage,PostExperimentStage}, agent, @@ -159,14 +170,21 @@ end TotalBatchRewardPerEpisode(batch_size::Int;tag="TRAINING") Similar to [`TotalRewardPerEpisode`](@ref), but will record total rewards per episode in [`BatchObs`](@ref). + +!!! note + If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. """ function TotalBatchRewardPerEpisode(batch_size::Int; tag = "TRAINING") TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), tag) end -function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env, obs::BatchObs) +function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env, obs::BatchObs{T}) where T for i in 1:length(obs) - hook.reward[i] += get_reward(obs[i]) + if T <: RewardOverriddenObs + hook.reward[i] += get_reward(obs[i].obs) + else + hook.reward[i] += get_reward(obs[i]) + end if get_terminal(obs[i]) push!(hook.rewards[i], hook.reward[i]) hook.reward[i] = 0.0 @@ -206,14 +224,22 @@ end CumulativeReward(rewards::Vector{Float64} = [0.0], tag::String = "TRAINING") Store cumulative rewards since the beginning to the field of `rewards`. + +!!! note + If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. """ Base.@kwdef struct CumulativeReward <: AbstractHook rewards::Vector{Float64} = [0.0] tag::String = "TRAINING" end -function (hook::CumulativeReward)(::PostActStage, agent, env, obs) - push!(hook.rewards, get_reward(obs) + hook.rewards[end]) +function (hook::CumulativeReward)(::PostActStage, agent, env, obs::T) where T + if T <: RewardOverriddenObs + r = get_reward(obs.obs) + else + r = get_reward(obs) + end + push!(hook.rewards, r + hook.rewards[end]) @debug hook.tag CUMULATIVE_REWARD = hook.rewards[end] end