From 7e9f690fdca76f08ae7da3a79707f74754d73330 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 5 Jul 2020 00:00:57 +0800 Subject: [PATCH 01/17] sync --- Project.toml | 1 + src/components/agents/agent.jl | 9 +++- .../policies/chance_player_policy.jl | 18 +++++++ src/components/policies/policies.jl | 2 + src/components/policies/tabular_policy.jl | 30 +++++++++++ .../trajectories/dummy_trajectory.jl | 5 ++ src/components/trajectories/trajectories.jl | 1 + src/core/run.jl | 50 +++++++++++++++++++ src/core/stop_conditions.jl | 2 +- 9 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 src/components/policies/chance_player_policy.jl create mode 100644 src/components/policies/tabular_policy.jl create mode 100644 src/components/trajectories/dummy_trajectory.jl diff --git a/Project.toml b/Project.toml index 67752ed..6dc255b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Jun Tian "] version = "0.3.3" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index a8fe30b..b8da1e8 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -20,7 +20,7 @@ Generally speaking, it does nothing but update the trajectory and policy appropr """ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent policy::P - trajectory::T + trajectory::T = DummyTrajectory() role::R = :DEFAULT_PLAYER is_training::Bool = true end @@ -75,6 +75,13 @@ end (agent::Agent)(::Testing, obs) = nothing (agent::Agent)(::Testing{PreActStage}, obs) = agent.policy(obs) +##### +# DummyTrajectory +##### + +(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::AbstractStage, obs) = nothing +(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::PreActStage, obs) = agent.policy(obs) + ##### # EpisodicCompactSARTSATrajectory ##### diff --git a/src/components/policies/chance_player_policy.jl b/src/components/policies/chance_player_policy.jl new file mode 100644 index 0000000..ae0d701 --- /dev/null +++ b/src/components/policies/chance_player_policy.jl @@ -0,0 +1,18 @@ +export ChancePlayerPolicy + +using Random + +struct ChancePlayerPolicy <: AbstractPolicy + rng::AbstractRNG +end + +ChancePlayerPolicy(;seed=nothing) = ChancePlayerPolicy(MersenneTwister(seed)) + +function (p::ChancePlayerPolicy)(obs) + v = rand(p.rng) + s = 0. 
+ for (action, prob) in get_chance_outcome(obs) + s += prob + s >= v && return action + end +end diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl index ab599f9..4c610ae 100644 --- a/src/components/policies/policies.jl +++ b/src/components/policies/policies.jl @@ -1,3 +1,5 @@ +include("chance_player_policy.jl") +include("tabular_policy.jl") include("V_based_policy.jl") include("Q_based_policy.jl") include("off_policy.jl") diff --git a/src/components/policies/tabular_policy.jl b/src/components/policies/tabular_policy.jl new file mode 100644 index 0000000..a22c742 --- /dev/null +++ b/src/components/policies/tabular_policy.jl @@ -0,0 +1,30 @@ +export TabularPolicy + +using AbstractTrees + +struct TabularPolicy{S,F,E} <: RLBase.AbstractPolicy + probs::Dict{S,Vector{Float64}} + key::F + explorer::E +end + +(p::TabularPolicy)(obs) = p.probs[p.key(obs)] |> p.explorer + +RLBase.get_prob(p::TabularPolicy, obs) = p.probs[p.key(obs)] + +function TabularPolicy(env::AbstractEnv;key=RLBase.get_state, explorer=WeightedExplorer(;is_normalized=true)) + k = key(observe(env)) + probs = Dict{typeof(k),Vector{Float64}}() + for x in PreOrderDFS(env) + if get_current_player(x) != get_chance_player(x) + obs = observe(x) + if !get_terminal(obs) + legal_actions_mask = get_legal_actions_mask(obs) + p = zeros(length(legal_actions_mask)) + p[legal_actions_mask] .= 1 / sum(legal_actions_mask) + probs[key(obs)] = p + end + end + end + TabularPolicy(probs, key, explorer) +end diff --git a/src/components/trajectories/dummy_trajectory.jl b/src/components/trajectories/dummy_trajectory.jl new file mode 100644 index 0000000..8a52e3b --- /dev/null +++ b/src/components/trajectories/dummy_trajectory.jl @@ -0,0 +1,5 @@ +export DummyTrajectory + +struct DummyTrajectory <: AbstractTrajectory{(), Tuple{}} end + +Base.length(t::DummyTrajectory) = 0 \ No newline at end of file diff --git a/src/components/trajectories/trajectories.jl b/src/components/trajectories/trajectories.jl index 050a8f5..a64e21d 100644 --- a/src/components/trajectories/trajectories.jl +++ b/src/components/trajectories/trajectories.jl @@ -1,4 +1,5 @@ include("abstract_trajectory.jl") +include("dummy_trajectory.jl") include("trajectory.jl") include("vectorial_trajectory.jl") include("circular_trajectory.jl") diff --git a/src/core/run.jl b/src/core/run.jl index f11edb0..20d389e 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -142,3 +142,53 @@ function run( end hooks end + +function run( + ::Simultaneous, + agents::Tuple{Vararg{<:AbstractAgent}}, + env::AbstractEnv, + stop_condition, + hook = EmptyHook(), +) + reset!(env) + observations = [observe(env, get_role(agent)) for agent in agents] + hook(PRE_EPISODE_STAGE, agents, env, observations) + actions = [agent(obs) for obs in observations] + hook(PRE_ACT_STAGE, agents, env, observations, actions) + + while true + env(actions) + + for (i, agent) in enumerate(agents) + observations[i] = observe(env, get_role(agent)) + end + hook(POST_ACT_STAGE, agents, env, observations) + + if get_terminal(observations[1]) + for (agent, obs) in zip(agents, observations) + agent(POST_EPISODE_STAGE, obs) + end + hook(POST_EPISODE_STAGE, agents, env, observations) + + stop_condition(agents, env, observations) && break + + reset!(env) + + for (i, agent) in enumerate(agents) + observations[i] = observe(env, get_role(agent)) + end + hook(PRE_EPISODE_STAGE, agents, env, observations) + for (i, agent) in enumerate(agents) + actions[i] = agent(observations[i]) + end + hook(PRE_ACT_STAGE, agents, env, 
observations, actions) + else + stop_condition(agents, env, observations) && break + for (i, agent) in enumerate(agents) + actions[i] = agent(observations[i]) + end + hook(PRE_ACT_STAGE, agents, env, observations, actions) + end + end + hook +end diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index 60e1e46..4671050 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -93,7 +93,7 @@ end function (s::StopAfterEpisode)(agent, env, obs) @debug s.tag EPISODE = s.cur - is_terminal = get_num_players(env) == 1 ? get_terminal(obs) : get_terminal(obs[1]) + is_terminal = length(get_players(env)) == 1 ? get_terminal(obs) : get_terminal(obs[1]) if is_terminal s.cur += 1 From 56794b5469a1c60d3a206e9fe2833970cadcf852 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 19 Jul 2020 00:17:21 +0800 Subject: [PATCH 02/17] update RLBase to newer version --- Project.toml | 12 +- src/components/components.jl | 2 +- .../{preprocessors.jl => processors.jl} | 4 +- src/core/hooks.jl | 111 ++++++------ src/core/run.jl | 162 ++++++------------ src/core/stop_conditions.jl | 31 ++-- src/utils/printing.jl | 37 ++++ src/utils/utils.jl | 1 + 8 files changed, 154 insertions(+), 206 deletions(-) rename src/components/{preprocessors.jl => processors.jl} (93%) create mode 100644 src/utils/printing.jl diff --git a/Project.toml b/Project.toml index 6dc255b..4007c38 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,11 @@ version = "0.3.3" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" -CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -26,14 +23,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Adapt = "1, 2" BSON = "0.2" -CUDAapi = "3, 4" -CuArrays = "1.7, 2" Distributions = "0.22, 0.23" FillArrays = "0.8" -Flux = "0.10" -GPUArrays = "2, 3, 4.0" +Flux = "0.11" ImageTransformations = "0.8" JLD = "0.10" MacroTools = "0.5" @@ -41,7 +34,6 @@ ProgressMeter = "1.2" ReinforcementLearningBase = "0.7" Setfield = "0.6" StatsBase = "0.32, 0.33" -Zygote = "0.4" julia = "1.3" [extras] diff --git a/src/components/components.jl b/src/components/components.jl index dfc1fd5..de09f7b 100644 --- a/src/components/components.jl +++ b/src/components/components.jl @@ -1,4 +1,4 @@ -include("preprocessors.jl") +include("processors.jl") include("trajectories/trajectories.jl") include("approximators/approximators.jl") include("explorers/explorers.jl") diff --git a/src/components/preprocessors.jl b/src/components/processors.jl similarity index 93% rename from src/components/preprocessors.jl rename to src/components/processors.jl index 3b10699..0a9b94b 100644 --- a/src/components/preprocessors.jl +++ b/src/components/processors.jl @@ -9,7 +9,7 @@ using ImageTransformations: imresize! Using BSpline method to resize the `state` field of an observation to size of `img` (or `dims`). 
""" -struct ResizeImage{T,N} <: AbstractPreprocessor +struct ResizeImage{T,N} img::Array{T,N} end @@ -26,7 +26,7 @@ end Use a pre-initialized [`CircularArrayBuffer`](@ref) to store the latest several states specified by `d`. Before processing any observation, the buffer is filled with `zero{T}`. """ -struct StackFrames{T,N} <: AbstractPreprocessor +struct StackFrames{T,N} buffer::CircularArrayBuffer{T,N} end diff --git a/src/core/hooks.jl b/src/core/hooks.jl index 6fb5998..bc6e503 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -15,10 +15,10 @@ export AbstractHook, A hook is called at different stage duiring a [`run`](@ref) to allow users to inject customized runtime logic. By default, a `AbstractHook` will do nothing. One can override the behavior by implementing the following methods: -- `(hook::YourHook)(::PreActStage, agent, env, obs, action)`, note that there's an extra argument of `action`. -- `(hook::YourHook)(::PostActStage, agent, env, obs)` -- `(hook::YourHook)(::PreEpisodeStage, agent, env, obs)` -- `(hook::YourHook)(::PostEpisodeStage, agent, env, obs)` +- `(hook::YourHook)(::PreActStage, agent, env, action)`, note that there's an extra argument of `action`. +- `(hook::YourHook)(::PostActStage, agent, env)` +- `(hook::YourHook)(::PreEpisodeStage, agent, env)` +- `(hook::YourHook)(::PostEpisodeStage, agent, env)` """ abstract type AbstractHook end @@ -65,29 +65,24 @@ const EMPTY_HOOK = EmptyHook() ##### """ - StepsPerEpisode(; steps = Int[], count = 0, tag = "TRAINING") + StepsPerEpisode(; steps = Int[], count = 0) Store steps of each episode in the field of `steps`. """ Base.@kwdef mutable struct StepsPerEpisode <: AbstractHook steps::Vector{Int} = Int[] count::Int = 0 - tag::String = "TRAINING" end -function (hook::StepsPerEpisode)(::PostActStage, args...) - hook.count += 1 -end +(hook::StepsPerEpisode)(::PostActStage, args...) = hook.count += 1 function (hook::StepsPerEpisode)( ::Union{PostEpisodeStage,PostExperimentStage}, agent, - env, - obs, + env ) push!(hook.steps, hook.count) hook.count = 0 - @debug hook.tag STEPS_PER_EPISODE = hook.steps[end] end ##### @@ -95,29 +90,24 @@ end ##### """ - RewardsPerEpisode(; rewards = Vector{Vector{Float64}}(), tag = "TRAINING") + RewardsPerEpisode(; rewards = Vector{Vector{Float64}}()) Store each reward of each step in every episode in the field of `rewards`. 
""" Base.@kwdef mutable struct RewardsPerEpisode <: AbstractHook rewards::Vector{Vector{Float64}} = Vector{Vector{Float64}}() - tag::String = "TRAINING" end -function (hook::RewardsPerEpisode)(::PreEpisodeStage, agent, env, obs) +function (hook::RewardsPerEpisode)(::PreEpisodeStage, agent, env) push!(hook.rewards, []) end -function (hook::RewardsPerEpisode)(::PostActStage, agent, env, obs) - push!(hook.rewards[end], get_reward(obs)) -end - -function (hook::RewardsPerEpisode)(::PostActStage, agent, env, obs::RewardOverriddenObs) - push!(hook.rewards[end], get_reward(obs.obs)) +function (hook::RewardsPerEpisode)(::PostActStage, agent, env) + push!(hook.rewards[end], get_reward(env)) end -function (hook::RewardsPerEpisode)(::PostEpisodeStage, agent, env, obs) - @debug hook.tag REWARDS_PER_EPISODE = hook.rewards[end] +function (hook::RewardsPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv) + push!(hook.rewards[end], get_reward(env.env)) end ##### @@ -125,36 +115,33 @@ end ##### """ - TotalRewardPerEpisode(; rewards = Float64[], reward = 0.0, tag = "TRAINING") + TotalRewardPerEpisode(; rewards = Float64[], reward = 0.0) Store the total rewards of each episode in the field of `rewards`. !!! note - If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. + If the environment is a [`RewardOverriddenenv`](@ref), then the original reward is recorded. """ Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook rewards::Vector{Float64} = Float64[] reward::Float64 = 0.0 - tag::String = "TRAINING" end -function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, obs) - hook.reward += get_reward(obs) +function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env) + hook.reward += get_reward(env) end -function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env, obs::RewardOverriddenObs) - hook.reward += get_reward(obs.obs) +function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv) + hook.reward += get_reward(env.env) end function (hook::TotalRewardPerEpisode)( ::Union{PostEpisodeStage,PostExperimentStage}, agent, env, - obs, ) push!(hook.rewards, hook.reward) hook.reward = 0 - @debug hook.tag REWARD_PER_EPISODE = hook.rewards[end] end ##### @@ -163,29 +150,28 @@ end struct TotalBatchRewardPerEpisode <: AbstractHook rewards::Vector{Vector{Float64}} reward::Vector{Float64} - tag::String end """ - TotalBatchRewardPerEpisode(batch_size::Int;tag="TRAINING") + TotalBatchRewardPerEpisode(batch_size::Int) -Similar to [`TotalRewardPerEpisode`](@ref), but will record total rewards per episode in [`BatchObs`](@ref). +Similar to [`TotalRewardPerEpisode`](@ref), but will record total rewards per episode in [`MultiThreadEnv`](@ref). !!! note - If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. + If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded. 
""" -function TotalBatchRewardPerEpisode(batch_size::Int; tag = "TRAINING") - TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), tag) +function TotalBatchRewardPerEpisode(batch_size::Int) + TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size)) end -function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env, obs::BatchObs{T}) where T - for i in 1:length(obs) - if T <: RewardOverriddenObs - hook.reward[i] += get_reward(obs[i].obs) +function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env::MultiThreadEnv{T}) where T + for i in 1:length(env) + if T <: RewardOverriddenEnv + hook.reward[i] += get_reward(env[i].env) else - hook.reward[i] += get_reward(obs[i]) + hook.reward[i] += get_reward(env[i]) end - if get_terminal(obs[i]) + if get_terminal(env[i]) push!(hook.rewards[i], hook.reward[i]) hook.reward[i] = 0.0 end @@ -200,16 +186,16 @@ end """ BatchStepsPerEpisode(batch_size::Int; tag = "TRAINING") -Similar to [`StepsPerEpisode`](@ref), but only work for [`BatchObs`](@ref) +Similar to [`StepsPerEpisode`](@ref), but only work for [`MultiThreadEnv`](@ref) """ -function BatchStepsPerEpisode(batch_size::Int; tag = "TRAINING") +function BatchStepsPerEpisode(batch_size::Int) BatchStepsPerEpisode([Int[] for _ in 1:batch_size], zeros(Int, batch_size)) end -function (hook::BatchStepsPerEpisode)(::PostActStage, agent, env, obs::BatchObs) - for i in 1:length(obs) +function (hook::BatchStepsPerEpisode)(::PostActStage, agent, env::MultiThreadEnv) + for i in 1:length(env) hook.step[i] += 1 - if get_terminal(obs[i]) + if get_terminal(env[i]) push!(hook.steps[i], hook.step[i]) hook.step[i] = 0 end @@ -221,23 +207,22 @@ end ##### """ - CumulativeReward(rewards::Vector{Float64} = [0.0], tag::String = "TRAINING") + CumulativeReward(rewards::Vector{Float64} = [0.0]) Store cumulative rewards since the beginning to the field of `rewards`. !!! note - If the observation is a [`RewardOverriddenObs`](@ref), then the original reward is recorded. + If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded instead. """ Base.@kwdef struct CumulativeReward <: AbstractHook rewards::Vector{Float64} = [0.0] - tag::String = "TRAINING" end -function (hook::CumulativeReward)(::PostActStage, agent, env, obs::T) where T - if T <: RewardOverriddenObs - r = get_reward(obs.obs) +function (hook::CumulativeReward)(::PostActStage, agent, env::T) where T + if T <: RewardOverriddenEnv + r = get_reward(env.env) else - r = get_reward(obs) + r = get_reward(env) end push!(hook.rewards, r + hook.rewards[end]) @debug hook.tag CUMULATIVE_REWARD = hook.rewards[end] @@ -261,7 +246,7 @@ end TimePerStep(; max_steps = 100) = TimePerStep(CircularArrayBuffer{Float64}(max_steps), time_ns()) -function (hook::TimePerStep)(::PostActStage, agent, env, obs) +function (hook::TimePerStep)(::PostActStage, agent, env) push!(hook.times, (time_ns() - hook.t) / 1e9) hook.t = time_ns() end @@ -269,7 +254,7 @@ end """ DoEveryNStep(f; n=1, t=0) -Execute `f(agent, env, obs)` every `n` step. +Execute `f(agent, env)` every `n` step. `t` is a counter of steps. 
""" Base.@kwdef mutable struct DoEveryNStep{F} <: AbstractHook @@ -280,17 +265,17 @@ end DoEveryNStep(f, n = 1, t = 0) = DoEveryNStep(f, n, t) -function (hook::DoEveryNStep)(::PostActStage, agent, env, obs) +function (hook::DoEveryNStep)(::PostActStage, agent, env) hook.t += 1 if hook.t % hook.n == 0 - hook.f(hook.t, agent, env, obs) + hook.f(hook.t, agent, env) end end """ DoEveryNEpisode(f; n=1, t=0) -Execute `f(agent, env, obs)` every `n` episode. +Execute `f(agent, env)` every `n` episode. `t` is a counter of steps. """ Base.@kwdef mutable struct DoEveryNEpisode{F} <: AbstractHook @@ -301,9 +286,9 @@ end DoEveryNEpisode(f, n = 1, t = 0) = DoEveryNEpisode(f, n, t) -function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env, obs) +function (hook::DoEveryNEpisode)(::PostEpisodeStage, agent, env) hook.t += 1 if hook.t % hook.n == 0 - hook.f(hook.t, agent, env, obs) + hook.f(hook.t, agent, env) end end diff --git a/src/core/run.jl b/src/core/run.jl index 20d389e..5c23202 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -1,9 +1,10 @@ import Base: run -run(agent, env::AbstractEnv, args...) = run(DynamicStyle(env), agent, env, args...) +run(agent, env::AbstractEnv, args...) = run(DynamicStyle(env), NumAgentStyle(env), agent, env, args...) function run( ::Sequential, + ::SingleAgent, agent::AbstractAgent, env::AbstractEnv, stop_condition, @@ -11,34 +12,31 @@ function run( ) reset!(env) - obs = observe(env) - agent(PRE_EPISODE_STAGE, obs) - hook(PRE_EPISODE_STAGE, agent, env, obs) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + agent(PRE_EPISODE_STAGE, env) + hook(PRE_EPISODE_STAGE, agent, env) + action = agent(PRE_ACT_STAGE, env) + hook(PRE_ACT_STAGE, agent, env, action) while true env(action) - obs = observe(env) - agent(POST_ACT_STAGE, obs) - hook(POST_ACT_STAGE, agent, env, obs) + agent(POST_ACT_STAGE, env) + hook(POST_ACT_STAGE, agent, env) - if get_terminal(obs) - agent(POST_EPISODE_STAGE, obs) # let the agent see the last observation - hook(POST_EPISODE_STAGE, agent, env, obs) + if get_terminal(env) + agent(POST_EPISODE_STAGE, env) # let the agent see the last observation + hook(POST_EPISODE_STAGE, agent, env) - stop_condition(agent, env, obs) && break + stop_condition(agent, env) && break reset!(env) - obs = observe(env) - agent(PRE_EPISODE_STAGE, obs) - hook(PRE_EPISODE_STAGE, agent, env, obs) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + agent(PRE_EPISODE_STAGE, env) + hook(PRE_EPISODE_STAGE, agent, env) + action = agent(PRE_ACT_STAGE) + hook(PRE_ACT_STAGE, agent, env, action) else - stop_condition(agent, env, obs) && break - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + stop_condition(agent, env) && break + action = agent(PRE_ACT_STAGE) + hook(PRE_ACT_STAGE, agent, env, action) end end hook @@ -46,6 +44,7 @@ end function run( ::Sequential, + ::SingleAgent, agent::AbstractAgent, env::MultiThreadEnv, stop_condition, @@ -54,17 +53,15 @@ function run( while true reset!(env) - obs = observe(env) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + action = agent(PRE_ACT_STAGE, env) + hook(PRE_ACT_STAGE, agent, env, action) env(action) - obs = observe(env) - agent(POST_ACT_STAGE, obs) - hook(POST_ACT_STAGE, agent, env, obs) + agent(POST_ACT_STAGE, env) + hook(POST_ACT_STAGE, agent, env) - if stop_condition(agent, env, obs) - agent(PRE_ACT_STAGE, obs) # let the agent see the last observation + if stop_condition(agent, env) + 
agent(PRE_ACT_STAGE, env) # let the agent see the last observation break end end @@ -73,22 +70,23 @@ end function run( ::Sequential, + ::MultiAgent, agents::Tuple{Vararg{<:AbstractAgent}}, env::AbstractEnv, stop_condition, hooks = [EmptyHook() for _ in agents], ) - reset!(env) - observations = [observe(env, get_role(agent)) for agent in agents] + @assert length(agents) == get_num_players(env) - valid_action = rand(get_action_space(env)) # init with a dummy value + reset!(env) + valid_action = rand(get_actions(env)) # init with a dummy value # async here? - for (agent, obs, hook) in zip(agents, observations, hooks) - agent(PRE_EPISODE_STAGE, obs) - hook(PRE_EPISODE_STAGE, agent, env, obs) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + for (agent, hook) in zip(agents, hooks) + agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(PRE_EPISODE_STAGE, agent, env) + action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(PRE_ACT_STAGE, agent, env, action) # for Sequential environments, only one action is valid if get_current_player(env) == get_role(agent) valid_action = action @@ -98,41 +96,35 @@ function run( while true env(valid_action) - observations = [observe(env, get_role(agent)) for agent in agents] - - for (agent, obs, hook) in zip(agents, observations, hooks) - agent(POST_ACT_STAGE, obs) - hook(POST_ACT_STAGE, agent, env, obs) + for (agent, hook) in zip(agents, hooks) + agent(POST_ACT_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(POST_ACT_STAGE, agent, env) end - if get_terminal(observations[1]) - for (agent, obs, hook) in zip(agents, observations, hooks) - agent(POST_EPISODE_STAGE, obs) - hook(POST_EPISODE_STAGE, agent, env, obs) + if get_terminal(env) + for (agent, hook) in zip(agents, hooks) + agent(POST_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(POST_EPISODE_STAGE, agent, env) end - stop_condition(agents, env, observations) && break - + stop_condition(agents, env) && break reset!(env) - - observations = [observe(env, get_role(agent)) for agent in agents] - # async here? 
- for (agent, obs, hook) in zip(agents, observations, hooks) - agent(PRE_EPISODE_STAGE, obs) - hook(PRE_EPISODE_STAGE, agent, env, obs) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + for (agent, hook) in zip(agents, hooks) + agent(PRE_EPISODE_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(PRE_EPISODE_STAGE, agent, env) + action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(PRE_ACT_STAGE, agent, env, action) # for Sequential environments, only one action is valid if get_current_player(env) == get_role(agent) valid_action = action end end else - stop_condition(agents, env, observations) && break - for (agent, obs, hook) in zip(agents, observations, hooks) - action = agent(PRE_ACT_STAGE, obs) - hook(PRE_ACT_STAGE, agent, env, obs, action) + stop_condition(agents, env) && break + for (agent, hook) in zip(agents, hooks) + action = agent(PRE_ACT_STAGE, SubjectiveEnv(env, get_role(agent))) + hook(PRE_ACT_STAGE, agent, env, action) # for Sequential environments, only one action is valid if get_current_player(env) == get_role(agent) valid_action = action @@ -142,53 +134,3 @@ function run( end hooks end - -function run( - ::Simultaneous, - agents::Tuple{Vararg{<:AbstractAgent}}, - env::AbstractEnv, - stop_condition, - hook = EmptyHook(), -) - reset!(env) - observations = [observe(env, get_role(agent)) for agent in agents] - hook(PRE_EPISODE_STAGE, agents, env, observations) - actions = [agent(obs) for obs in observations] - hook(PRE_ACT_STAGE, agents, env, observations, actions) - - while true - env(actions) - - for (i, agent) in enumerate(agents) - observations[i] = observe(env, get_role(agent)) - end - hook(POST_ACT_STAGE, agents, env, observations) - - if get_terminal(observations[1]) - for (agent, obs) in zip(agents, observations) - agent(POST_EPISODE_STAGE, obs) - end - hook(POST_EPISODE_STAGE, agents, env, observations) - - stop_condition(agents, env, observations) && break - - reset!(env) - - for (i, agent) in enumerate(agents) - observations[i] = observe(env, get_role(agent)) - end - hook(PRE_EPISODE_STAGE, agents, env, observations) - for (i, agent) in enumerate(agents) - actions[i] = agent(observations[i]) - end - hook(PRE_ACT_STAGE, agents, env, observations, actions) - else - stop_condition(agents, env, observations) && break - for (i, agent) in enumerate(agents) - actions[i] = agent(observations[i]) - end - hook(PRE_ACT_STAGE, agents, env, observations, actions) - end - end - hook -end diff --git a/src/core/stop_conditions.jl b/src/core/stop_conditions.jl index 4671050..0caa56c 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -29,7 +29,7 @@ end # StopAfterStep ##### """ - StopAfterStep(step; cur = 1, is_show_progress = true, tag = "TRAINING") + StopAfterStep(step; cur = 1, is_show_progress = true) Return `true` after being called `step` times. """ @@ -37,17 +37,16 @@ mutable struct StopAfterStep{Tl} step::Int cur::Int progress::Tl - tag::String end -function StopAfterStep(step; cur = 1, is_show_progress = true, tag = "TRAINING") +function StopAfterStep(step; cur = 1, is_show_progress = true) if is_show_progress progress = Progress(step) ProgressMeter.update!(progress, cur) else progress = nothing end - StopAfterStep(step, cur, progress, tag) + StopAfterStep(step, cur, progress) end function (s::StopAfterStep)(args...) 
@@ -69,7 +68,7 @@ end ##### """ - StopAfterEpisode(episode; cur = 0, is_show_progress = true, tag = "TRAINING") + StopAfterEpisode(episode; cur = 0, is_show_progress = true) Return `true` after being called `episode`. If `is_show_progress` is `true`, the `ProgressMeter` will be used to show progress. """ @@ -77,29 +76,22 @@ mutable struct StopAfterEpisode{Tl} episode::Int cur::Int progress::Tl - tag::String end -function StopAfterEpisode(episode; cur = 0, is_show_progress = true, tag = "TRAINING") +function StopAfterEpisode(episode; cur = 0, is_show_progress = true) if is_show_progress progress = Progress(episode) ProgressMeter.update!(progress, cur) else progress = nothing end - StopAfterEpisode(episode, cur, progress, tag) + StopAfterEpisode(episode, cur, progress) end -function (s::StopAfterEpisode)(agent, env, obs) - @debug s.tag EPISODE = s.cur - - is_terminal = length(get_players(env)) == 1 ? get_terminal(obs) : get_terminal(obs[1]) - - if is_terminal +function (s::StopAfterEpisode)(agent, env) + if get_terminal(env) s.cur += 1 if !isnothing(s.progress) - # https://github.com/timholy/ProgressMeter.jl/pull/131 - # next!(s.progress; showvalues = [(Symbol(s.tag, "/", :EPISODE), s.cur)]) next!(s.progress;) end end @@ -107,8 +99,7 @@ function (s::StopAfterEpisode)(agent, env, obs) s.cur >= s.episode end -(s::StopAfterEpisode)(agent, env::MultiThreadEnv, obs::BatchObs) = - @error "MultiThreadEnv is not supported!" +(s::StopAfterEpisode)(agent, env::MultiThreadEnv) = @error "MultiThreadEnv is not supported!" ##### # StopWhenDone @@ -117,8 +108,8 @@ end """ StopWhenDone() -Return `true` if the `terminal` field of an observation is `true`. +Return `true` if the environment is terminated. """ struct StopWhenDone end -(s::StopWhenDone)(agent, env, obs) = get_terminal(obs) +(s::StopWhenDone)(agent, env) = get_terminal(env) diff --git a/src/utils/printing.jl b/src/utils/printing.jl new file mode 100644 index 0000000..cb3713a --- /dev/null +++ b/src/utils/printing.jl @@ -0,0 +1,37 @@ +using AbstractTrees + +const AT = AbstractTrees + +struct StructTree{X} + x::X +end + +AT.children(t::StructTree{X}) where X = Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X)) +AT.children(t::StructTree{<:AbstractArray}) = () +AT.children(t::Pair{Symbol, <:StructTree}) = children(last(t)) +AT.printnode(io::IO, t::StructTree) = summary(io, t.x) + +AT.printnode(io::IO, t::StructTree{<:Number}) = print(io, t.x) + +function AT.printnode(io::IO, t::StructTree{String}) + s = t.x + i = findfirst('\n', s) + if isnothing(i) + if length(s) > 79 + print(io, "\"s[1:79]...\"") + else + print(io, "\"$s\"") + end + else + if i > 79 + print(io, "\"s[1:79]...\"") + else + print(io, "\"$(s[1:i])...\"") + end + end +end + +function AT.printnode(io::IO, t::Pair{Symbol, <:StructTree}) + print(io, first(t), " => ") + AT.printnode(io, last(t)) +end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index c9be436..be1556d 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,3 +1,4 @@ +include("printing.jl") include("base.jl") include("circular_array_buffer.jl") include("device.jl") From c945354aeb23a69ca6a0b099b0f33ec5b82ec4bc Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 20 Jul 2020 17:48:00 +0800 Subject: [PATCH 03/17] switch to CUDA.jl --- .travis.yml | 2 +- Project.toml | 4 +++- src/extensions/{CuArrays.jl => CUDA.jl} | 18 +++++++++--------- src/extensions/Flux.jl | 8 +------- src/extensions/ReinforcementLearningBase.jl | 7 ++----- src/extensions/extensions.jl | 2 +- src/utils/device.jl | 2 +- 
test/runtests.jl | 3 +-- 8 files changed, 19 insertions(+), 27 deletions(-) rename src/extensions/{CuArrays.jl => CUDA.jl} (76%) diff --git a/.travis.yml b/.travis.yml index c898dd9..ed7b5ec 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,4 +19,4 @@ after_success: ## uncomment the following lines to override the default test script script: - - travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()' + - travis_wait 50 julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test(coverage=true)' diff --git a/Project.toml b/Project.toml index 4007c38..83a73ce 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,13 @@ version = "0.3.3" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -31,7 +33,7 @@ ImageTransformations = "0.8" JLD = "0.10" MacroTools = "0.5" ProgressMeter = "1.2" -ReinforcementLearningBase = "0.7" +ReinforcementLearningBase = "0.8" Setfield = "0.6" StatsBase = "0.32, 0.33" julia = "1.3" diff --git a/src/extensions/CuArrays.jl b/src/extensions/CUDA.jl similarity index 76% rename from src/extensions/CuArrays.jl rename to src/extensions/CUDA.jl index 192f4c2..fe1fe3f 100644 --- a/src/extensions/CuArrays.jl +++ b/src/extensions/CUDA.jl @@ -1,5 +1,5 @@ -using GPUArrays, CuArrays, FillArrays -using CuArrays: threadIdx, blockIdx, blockDim +using CUDA, FillArrays +using CUDA: threadIdx, blockIdx, blockDim ##### # Cartesian indexing of CuArray @@ -16,8 +16,8 @@ function Base.getindex(xs::CuArray{T,N}, indices::CuArray{CartesianIndex{N}}) wh num_blocks = ceil(Int, n / num_threads) function kernel( - ys::CuArrays.CuDeviceArray{T}, - xs::CuArrays.CuDeviceArray{T}, + ys::CUDA.CuDeviceArray{T}, + xs::CUDA.CuDeviceArray{T}, indices, ) i = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -30,7 +30,7 @@ function Base.getindex(xs::CuArray{T,N}, indices::CuArray{CartesianIndex{N}}) wh return end - CuArrays.@cuda blocks = num_blocks threads = num_threads kernel(ys, xs, indices) + CUDA.@cuda blocks = num_blocks threads = num_threads kernel(ys, xs, indices) end return ys @@ -48,7 +48,7 @@ function Base.setindex!( num_threads = min(n, 256) num_blocks = ceil(Int, n / num_threads) - function kernel(xs::CuArrays.CuDeviceArray{T}, indices, v) + function kernel(xs::CUDA.CuDeviceArray{T}, indices, v) i = threadIdx().x + (blockIdx().x - 1) * blockDim().x if i <= length(indices) @@ -59,7 +59,7 @@ function Base.setindex!( return end - CuArrays.@cuda blocks = num_blocks threads = num_threads kernel(xs, indices, v) + CUDA.@cuda blocks = num_blocks threads = num_threads kernel(xs, indices, v) end return v end @@ -75,7 +75,7 @@ function Base.setindex!( num_threads = min(n, 256) num_blocks = ceil(Int, n / num_threads) - function kernel(xs::CuArrays.CuDeviceArray{T}, indices, v) + function kernel(xs::CUDA.CuDeviceArray{T}, indices, v) i = threadIdx().x + (blockIdx().x - 1) * blockDim().x if i <= length(indices) @@ -86,7 +86,7 @@ function Base.setindex!( return end - CuArrays.@cuda blocks = num_blocks threads = num_threads kernel(xs, 
indices, v) + CUDA.@cuda blocks = num_blocks threads = num_threads kernel(xs, indices, v) end return v end diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 45abcac..039374d 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -5,16 +5,12 @@ import Flux: glorot_uniform, glorot_normal using Random using LinearAlgebra +# watch https://github.com/FluxML/Flux.jl/issues/1274 glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(Flux.nfan(dims...))) glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(Flux.nfan(dims...))) -seed_glorot_uniform(; seed = nothing) = - (dims...) -> glorot_uniform(MersenneTwister(seed), dims...) -seed_glorot_normal(; seed = nothing) = - (dims...) -> glorot_normal(MersenneTwister(seed), dims...) - # https://github.com/FluxML/Flux.jl/pull/1171/ # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/Orthogonal function orthogonal_matrix(rng::AbstractRNG, nrow, ncol) @@ -31,5 +27,3 @@ function orthogonal(rng::AbstractRNG, d1, rest_dims...) end orthogonal(dims...) = orthogonal(Random.GLOBAL_RNG, dims...) - -seed_orthogonal(; seed = nothing) = (dims...) -> orthogonal(MersenneTwister(seed), dims...) diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index bb0d3ca..b0633e1 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -1,13 +1,10 @@ -using CuArrays +using CUDA using Distributions: pdf using Random using Flux using BSON -RLBase.get_prob(p::AbstractPolicy, obs, ::RLBase.AbstractActionStyle, a) = - pdf(get_prob(p, obs), a) - -Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CuArrays.CURAND.generator(), s) +Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s) # avoid fallback silently Flux.testmode!(p::AbstractPolicy, mode = true) = diff --git a/src/extensions/extensions.jl b/src/extensions/extensions.jl index 8b3afed..3fab88b 100644 --- a/src/extensions/extensions.jl +++ b/src/extensions/extensions.jl @@ -1,4 +1,4 @@ include("Flux.jl") -include("CuArrays.jl") +include("CUDA.jl") include("Zygote.jl") include("ReinforcementLearningBase.jl") diff --git a/src/utils/device.jl b/src/utils/device.jl index 96e9b14..ee121c1 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -1,7 +1,7 @@ export device, send_to_host, send_to_device using Flux -using CuArrays +using CUDA using Adapt send_to_host(x) = send_to_device(Val(:cpu), x) diff --git a/test/runtests.jl b/test/runtests.jl index 7cfd8f6..f95166a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,10 +7,9 @@ using Distributions: probs using ReinforcementLearningEnvironments using Flux using Zygote -using CUDAapi if has_cuda() - using CuArrays + using CUDA end @testset "ReinforcementLearningCore.jl" begin From ef7fb07b4915c05efa5dd51532df9b8390fc9a80 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 22 Jul 2020 12:30:03 +0800 Subject: [PATCH 04/17] fix more --- src/components/approximators/neural_network_approximator.jl | 6 +++--- src/components/policies/tabular_policy.jl | 1 + src/extensions/Flux.jl | 2 +- src/utils/device.jl | 3 +++ src/utils/printing.jl | 2 ++ 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/components/approximators/neural_network_approximator.jl b/src/components/approximators/neural_network_approximator.jl index e026e1b..6f3307f 100644 --- a/src/components/approximators/neural_network_approximator.jl 
+++ b/src/components/approximators/neural_network_approximator.jl @@ -10,11 +10,11 @@ Use a DNN model for value estimation. # Keyword arguments - `model`, a Flux based DNN model. -- `optimizer=Descent()` +- `optimizer=nothing` """ Base.@kwdef struct NeuralNetworkApproximator{M,O} <: AbstractApproximator model::M - optimizer::O = Descent() + optimizer::O = nothing end (app::NeuralNetworkApproximator)(x) = app.model(x) @@ -42,7 +42,7 @@ Flux.testmode!(app::NeuralNetworkApproximator, mode = true) = testmode!(app.mode The `actor` part must return logits (*Do not use softmax in the last layer!*), and the `critic` part must return a state value. """ -Base.@kwdef struct ActorCritic{A,C,O} +Base.@kwdef struct ActorCritic{A,C,O} <: AbstractApproximator actor::A critic::C optimizer::O = ADAM() diff --git a/src/components/policies/tabular_policy.jl b/src/components/policies/tabular_policy.jl index a22c742..2da45d9 100644 --- a/src/components/policies/tabular_policy.jl +++ b/src/components/policies/tabular_policy.jl @@ -2,6 +2,7 @@ export TabularPolicy using AbstractTrees +## TODO: Use TabularApproximator struct TabularPolicy{S,F,E} <: RLBase.AbstractPolicy probs::Dict{S,Vector{Float64}} key::F diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 039374d..69aceb5 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -1,4 +1,4 @@ -export seed_glorot_normal, seed_glorot_uniform, seed_orthogonal +export orthogonal import Flux: glorot_uniform, glorot_normal diff --git a/src/utils/device.jl b/src/utils/device.jl index ee121c1..f075076 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -33,3 +33,6 @@ function device(x::Union{Tuple,NamedTuple}) d1 end end + +# recoganize Torch.jl +# device(x::Tensor) = Val(Symbol(:gpu, x.device)) \ No newline at end of file diff --git a/src/utils/printing.jl b/src/utils/printing.jl index cb3713a..0dedc74 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -1,3 +1,5 @@ +export StructTree + using AbstractTrees const AT = AbstractTrees From de8aec4eff474c40be5ed12563f56dc3b82e381b Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 22 Jul 2020 14:38:19 +0800 Subject: [PATCH 05/17] fix tests --- Project.toml | 1 + src/utils/device.jl | 2 ++ test/components/agents.jl | 17 ----------------- test/components/preprocessors.jl | 11 +++++------ test/core/core.jl | 2 +- test/runtests.jl | 5 +---- test/utils/device.jl | 10 ++++------ 7 files changed, 14 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index 83a73ce..087e06d 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" +ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/utils/device.jl b/src/utils/device.jl index f075076..7882a01 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -4,6 +4,8 @@ using Flux using CUDA using Adapt +import CUDA:device + send_to_host(x) = send_to_device(Val(:cpu), x) send_to_device(::Val{:cpu}, x) = x # cpu(x) is not very efficient! So by default we do nothing here. 
diff --git a/test/components/agents.jl b/test/components/agents.jl index 71cc747..876db55 100644 --- a/test/components/agents.jl +++ b/test/components/agents.jl @@ -5,23 +5,6 @@ trajectory = VectorialCompactSARTSATrajectory(), ) - obs1 = (state = 1,) - agent(PRE_EPISODE_STAGE, obs1) - a1 = agent(PRE_ACT_STAGE, obs1) - @test a1 ∈ action_space - - obs2 = (reward = 1.0, terminal = true, state = 2) - agent(POST_ACT_STAGE, obs2) - dummy_action = agent(POST_EPISODE_STAGE, obs2) - - @test length(agent.trajectory) == 1 - @test get_trace(agent.trajectory, :state) == [1] - @test get_trace(agent.trajectory, :action) == [a1] - @test get_trace(agent.trajectory, :reward) == [get_reward(obs2)] - @test get_trace(agent.trajectory, :terminal) == [get_terminal(obs2)] - @test get_trace(agent.trajectory, :next_state) == [get_state(obs2)] - @test get_trace(agent.trajectory, :next_action) == [dummy_action] - @testset "loading/saving Agent" begin mktempdir() do dir RLCore.save(dir, agent) diff --git a/test/components/preprocessors.jl b/test/components/preprocessors.jl index 291e8d7..8108242 100644 --- a/test/components/preprocessors.jl +++ b/test/components/preprocessors.jl @@ -1,9 +1,9 @@ @testset "preprocessors" begin @testset "ResizeImage" begin - obs = (state = ones(4, 4),) + state = ones(4, 4) p = ResizeImage(2, 2) - @test get_state(p(obs)) == ones(2, 2) + @test p(state) == ones(2, 2) end @testset "StackFrames" begin @@ -11,12 +11,11 @@ p = StackFrames(2, 2, 3) for i in 1:3 - obs = (state = A * i,) - p(obs) + p(A * i) end - obs = (state = A * 4,) - @test get_state(p(obs)) == reshape(repeat([2, 3, 4]; inner = 4), 2, 2, :) + state = A * 4 + @test p(state) == reshape(repeat([2, 3, 4]; inner = 4), 2, 2, :) end diff --git a/test/core/core.jl b/test/core/core.jl index 4606baa..d9f5a29 100644 --- a/test/core/core.jl +++ b/test/core/core.jl @@ -1,5 +1,5 @@ @testset "simple workflow" begin - env = WrappedEnv(CloneStatePreprocessor(), CartPoleEnv{Float32}()) + env = CartPoleEnv{Float32}() |> StateOverriddenEnv(;deep_copy_state=deepcopy) agent = Agent(; policy = RandomPolicy(env), trajectory = VectorialCompactSARTSATrajectory(; state_type = Vector{Float32}), diff --git a/test/runtests.jl b/test/runtests.jl index f95166a..b71ad52 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,10 +7,7 @@ using Distributions: probs using ReinforcementLearningEnvironments using Flux using Zygote - -if has_cuda() - using CUDA -end +using CUDA @testset "ReinforcementLearningCore.jl" begin include("core/core.jl") diff --git a/test/utils/device.jl b/test/utils/device.jl index 8420d6a..8bde3e8 100644 --- a/test/utils/device.jl +++ b/test/utils/device.jl @@ -5,11 +5,9 @@ @test device(Conv((2, 2), 1 => 16, relu)) == Val(:cpu) @test device(Chain(x -> x .^ 2, Dense(2, 3))) == Val(:cpu) - if has_cuda() - @test device(rand(2) |> gpu) == Val(:gpu) - @test device(Dense(2, 3) |> gpu) == Val(:gpu) - @test device(Conv((2, 2), 1 => 16, relu) |> gpu) == Val(:gpu) - @test device(Chain(x -> x .^ 2, Dense(2, 3)) |> gpu) == Val(:gpu) - end + @test device(rand(2) |> gpu) == Val(:gpu) + @test device(Dense(2, 3) |> gpu) == Val(:gpu) + @test device(Conv((2, 2), 1 => 16, relu) |> gpu) == Val(:gpu) + @test device(Chain(x -> x .^ 2, Dense(2, 3)) |> gpu) == Val(:gpu) end From eaf2ae049b5bdf59ce709371dc7de29a52b80fef Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 22 Jul 2020 15:53:25 +0800 Subject: [PATCH 06/17] better printing --- src/ReinforcementLearningCore.jl | 2 +- src/components/agents/abstract_agent.jl | 3 +++ 
src/components/trajectories/abstract_trajectory.jl | 7 +++++++ src/components/trajectories/circular_trajectory.jl | 1 + src/extensions/ReinforcementLearningBase.jl | 4 ++++ src/utils/printing.jl | 3 ++- 6 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/ReinforcementLearningCore.jl b/src/ReinforcementLearningCore.jl index 2106842..fdca2e6 100644 --- a/src/ReinforcementLearningCore.jl +++ b/src/ReinforcementLearningCore.jl @@ -11,8 +11,8 @@ provides some standard and reusable components defined by [**RLBase**](https://g export RLCore -include("extensions/extensions.jl") include("utils/utils.jl") +include("extensions/extensions.jl") include("components/components.jl") include("core/core.jl") diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index 4e23019..807aea3 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -73,3 +73,6 @@ struct Training{T<:AbstractStage} end Training(s::T) where {T<:AbstractStage} = Training{T}() struct Testing{T<:AbstractStage} end Testing(s::T) where {T<:AbstractStage} = Testing{T}() + +Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent)) +Base.summary(io::IO, agent::T) where {T<:AbstractAgent} = print(io, T.name) diff --git a/src/components/trajectories/abstract_trajectory.jl b/src/components/trajectories/abstract_trajectory.jl index f7971b6..ed3150b 100644 --- a/src/components/trajectories/abstract_trajectory.jl +++ b/src/components/trajectories/abstract_trajectory.jl @@ -90,3 +90,10 @@ end function Base.pop!(t::AbstractTrajectory, s::Symbol...) NamedTuple{s}(pop!(t, x) for x in s) end + +function AbstractTrees.children(t::StructTree{<:AbstractTrajectory}) + traces = get_trace(t.x) + Tuple(k => StructTree(v) for (k,v) in pairs(traces)) +end + +Base.summary(io::IO, t::T) where {T<:AbstractTrajectory} = print(io, "$(length(t))-element $(T.name)") diff --git a/src/components/trajectories/circular_trajectory.jl b/src/components/trajectories/circular_trajectory.jl index 9368881..37eb5df 100644 --- a/src/components/trajectories/circular_trajectory.jl +++ b/src/components/trajectories/circular_trajectory.jl @@ -6,6 +6,7 @@ const CircularTrajectory = Trajectory{ NamedTuple{names,trace_types}, } where {names,types,trace_types<:Tuple{Vararg{<:CircularArrayBuffer}}} + """ CircularTrajectory(; capacity, trace_name=eltype=>size...) diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index b0633e1..cabfec2 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -3,6 +3,8 @@ using Distributions: pdf using Random using Flux using BSON +using AbstractTrees + Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s) @@ -10,6 +12,8 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Flux.testmode!(p::AbstractPolicy, mode = true) = @error "someone forgets to implement this method!!!" 
+Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p)) + function save(f::String, p::AbstractPolicy) policy = cpu(p) BSON.@save f policy diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 0dedc74..e4f5397 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -13,7 +13,8 @@ AT.children(t::StructTree{<:AbstractArray}) = () AT.children(t::Pair{Symbol, <:StructTree}) = children(last(t)) AT.printnode(io::IO, t::StructTree) = summary(io, t.x) -AT.printnode(io::IO, t::StructTree{<:Number}) = print(io, t.x) +AT.printnode(io::IO, t::StructTree{<:Union{Number,Symbol}}) = print(io, t.x) + function AT.printnode(io::IO, t::StructTree{String}) s = t.x From 92d41a78aac3b2d72327953de743bb8e49ba705a Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 22 Jul 2020 18:39:24 +0800 Subject: [PATCH 07/17] rename obs -> env --- src/components/agents/abstract_agent.jl | 10 +-- src/components/agents/agent.jl | 68 +++++++++---------- src/components/agents/dyna_agent.jl | 18 ++--- .../approximators/abstract_approximator.jl | 2 +- src/components/learners/learners.jl | 4 +- src/components/policies/Q_based_policy.jl | 20 +++--- src/components/policies/V_based_policy.jl | 36 +++++----- .../policies/chance_player_policy.jl | 4 +- src/components/policies/off_policy.jl | 2 +- 9 files changed, 82 insertions(+), 82 deletions(-) diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index 807aea3..62b69ba 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -16,8 +16,8 @@ export AbstractAgent, Testing """ - (agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -> action - (agent::AbstractAgent)(stage::AbstractStage, obs) + (agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) -> action + (agent::AbstractAgent)(stage::AbstractStage, env) Similar to [`AbstractPolicy`](@ref), an agent is also a functional object which takes in an observation and returns an action. The main difference is that, we divide an experiment into the following stages: @@ -43,7 +43,7 @@ PRE_EXPERIMENT_STAGE | PRE_ACT_STAGE POST_ACT_STAGE | | | | | | v | +-----+ v +-------+ v +-----+ | v --------------------->+ env +------>+ agent +------->+ env +---> ... ------->...... 
- | ^ +-----+ obs +-------+ action +-----+ ^ | + | ^ +-----+ +-------+ action +-----+ ^ | | | | | | +--PRE_EPISODE_STAGE POST_EPISODE_STAGE----+ | | | @@ -66,8 +66,8 @@ const POST_EPISODE_STAGE = PostEpisodeStage() const PRE_ACT_STAGE = PreActStage() const POST_ACT_STAGE = PostActStage() -(agent::AbstractAgent)(obs) = agent(PRE_ACT_STAGE, obs) -function (agent::AbstractAgent)(stage::AbstractStage, obs) end +(agent::AbstractAgent)(env) = agent(PRE_ACT_STAGE, env) +function (agent::AbstractAgent)(stage::AbstractStage, env) end struct Training{T<:AbstractStage} end Training(s::T) where {T<:AbstractStage} = Training{T}() diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index b8da1e8..8c618d2 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -26,7 +26,7 @@ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: A end # avoid polluting trajectory -(agent::Agent)(obs) = agent.policy(obs) +(agent::Agent)(env) = agent.policy(env) Flux.functor(x::Agent) = (policy = x.policy,), y -> @set x.policy = y.policy @@ -69,25 +69,25 @@ function Flux.testmode!(agent::Agent, mode = true) testmode!(agent.policy, mode) end -(agent::Agent)(stage::AbstractStage, obs) = - agent.is_training ? agent(Training(stage), obs) : agent(Testing(stage), obs) +(agent::Agent)(stage::AbstractStage, env) = + agent.is_training ? agent(Training(stage), env) : agent(Testing(stage), env) -(agent::Agent)(::Testing, obs) = nothing -(agent::Agent)(::Testing{PreActStage}, obs) = agent.policy(obs) +(agent::Agent)(::Testing, env) = nothing +(agent::Agent)(::Testing{PreActStage}, env) = agent.policy(env) ##### # DummyTrajectory ##### -(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::AbstractStage, obs) = nothing -(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::PreActStage, obs) = agent.policy(obs) +(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::AbstractStage, env) = nothing +(agent::Agent{<:AbstractPolicy, <:DummyTrajectory})(stage::PreActStage, env) = agent.policy(env) ##### # EpisodicCompactSARTSATrajectory ##### function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::Training{PreEpisodeStage}, - obs, + env, ) empty!(agent.trajectory) nothing @@ -95,28 +95,28 @@ end function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::Training{PreActStage}, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::Training{PostActStage}, - obs, + env, ) - push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs)) + push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) nothing end function (agent::Agent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::Training{PostEpisodeStage}, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end @@ -132,7 +132,7 @@ function ( } )( ::Training{PreEpisodeStage}, - obs, + env, ) if length(agent.trajectory) > 0 pop!(agent.trajectory, :state, :action) @@ -147,10 +147,10 @@ function ( } )( ::Training{PreActStage}, - obs, + env, ) - 
action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end @@ -162,9 +162,9 @@ function ( } )( ::Training{PostActStage}, - obs, + env, ) - push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs)) + push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) nothing end @@ -175,10 +175,10 @@ function ( } )( ::Training{PostEpisodeStage}, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end @@ -189,7 +189,7 @@ end function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( ::Training{PreEpisodeStage}, - obs, + env, ) if length(agent.trajectory) > 0 pop!(agent.trajectory, :state, :action) @@ -199,28 +199,28 @@ end function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( ::Training{PreActStage}, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( ::Training{PostActStage}, - obs, + env, ) - push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs)) + push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) nothing end function (agent::Agent{<:AbstractPolicy,<:VectorialCompactSARTSATrajectory})( ::Training{PostEpisodeStage}, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.policy, agent.trajectory) action end diff --git a/src/components/agents/dyna_agent.jl b/src/components/agents/dyna_agent.jl index 5c60eb4..c0fd2d9 100644 --- a/src/components/agents/dyna_agent.jl +++ b/src/components/agents/dyna_agent.jl @@ -35,7 +35,7 @@ get_role(agent::DynaAgent) = agent.role function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::PreEpisodeStage, - obs, + env, ) empty!(agent.trajectory) nothing @@ -43,10 +43,10 @@ end function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::PreActStage, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.model, agent.trajectory, agent.policy) # model learning update!(agent.policy, agent.trajectory) # direct learning update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning @@ -55,18 +55,18 @@ end function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::PostActStage, - obs, + env, ) - push!(agent.trajectory; reward = get_reward(obs), terminal = get_terminal(obs)) + push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env)) nothing end function (agent::DynaAgent{<:AbstractPolicy,<:EpisodicCompactSARTSATrajectory})( ::PostEpisodeStage, - obs, + env, ) - action = agent.policy(obs) - push!(agent.trajectory; state = get_state(obs), 
action = action) + action = agent.policy(env) + push!(agent.trajectory; state = get_state(env), action = action) update!(agent.model, agent.trajectory, agent.policy) # model learning update!(agent.policy, agent.trajectory) # direct learning update!(agent.policy, agent.model, agent.trajectory, agent.plan_step) # policy learning diff --git a/src/components/approximators/abstract_approximator.jl b/src/components/approximators/abstract_approximator.jl index 1d23f58..4489c30 100644 --- a/src/components/approximators/abstract_approximator.jl +++ b/src/components/approximators/abstract_approximator.jl @@ -2,7 +2,7 @@ export AbstractApproximator, ApproximatorStyle, Q_APPROXIMATOR, QApproximator, V_APPROXIMATOR, VApproximator """ - (app::AbstractApproximator)(obs) + (app::AbstractApproximator)(env) An approximator is a functional object for value estimation. It serves as a black box to provides an abstraction over different diff --git a/src/components/learners/learners.jl b/src/components/learners/learners.jl index 0ec808c..4071182 100644 --- a/src/components/learners/learners.jl +++ b/src/components/learners/learners.jl @@ -3,13 +3,13 @@ export AbstractLearner, extract_experience using Flux """ - (learner::AbstractLearner)(obs) + (learner::AbstractLearner)(env) A learner is usually used to estimate state values, state-action values or distributional values based on experiences. """ abstract type AbstractLearner end -function (learner::AbstractLearner)(obs) end +function (learner::AbstractLearner)(env) end """ get_priority(p::AbstractLearner, experience) diff --git a/src/components/policies/Q_based_policy.jl b/src/components/policies/Q_based_policy.jl index 96dc750..45f6a06 100644 --- a/src/components/policies/Q_based_policy.jl +++ b/src/components/policies/Q_based_policy.jl @@ -18,16 +18,16 @@ end Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner -(π::QBasedPolicy)(obs) = π(obs, ActionStyle(obs)) -(π::QBasedPolicy)(obs, ::MinimalActionSet) = obs |> π.learner |> π.explorer -(π::QBasedPolicy)(obs, ::FullActionSet) = - π.explorer(π.learner(obs), get_legal_actions_mask(obs)) - -RLBase.get_prob(p::QBasedPolicy, obs) = get_prob(p, obs, ActionStyle(obs)) -RLBase.get_prob(p::QBasedPolicy, obs, ::MinimalActionSet) = - get_prob(p.explorer, p.learner(obs)) -RLBase.get_prob(p::QBasedPolicy, obs, ::FullActionSet) = - get_prob(p.explorer, p.learner(obs), get_legal_actions_mask(obs)) +(π::QBasedPolicy)(env) = π(env, ActionStyle(env)) +(π::QBasedPolicy)(env, ::MinimalActionSet) = env |> π.learner |> π.explorer +(π::QBasedPolicy)(env, ::FullActionSet) = + π.explorer(π.learner(env), get_legal_actions_mask(env)) + +RLBase.get_prob(p::QBasedPolicy, env) = get_prob(p, env, ActionStyle(env)) +RLBase.get_prob(p::QBasedPolicy, env, ::MinimalActionSet) = + get_prob(p.explorer, p.learner(env)) +RLBase.get_prob(p::QBasedPolicy, env, ::FullActionSet) = + get_prob(p.explorer, p.learner(env), get_legal_actions_mask(env)) @forward QBasedPolicy.learner RLBase.get_priority, RLBase.update! diff --git a/src/components/policies/V_based_policy.jl b/src/components/policies/V_based_policy.jl index f9ef2a0..2f36c2d 100644 --- a/src/components/policies/V_based_policy.jl +++ b/src/components/policies/V_based_policy.jl @@ -8,7 +8,7 @@ using MacroTools: @forward # Key words & Fields - `learner`::[`AbstractLearner`](@ref), learn how to estimate state values. 
-- `mapping`, a customized function `(obs, learner) -> action_values` +- `mapping`, a customized function `(env, learner) -> action_values` - `explorer`::[`AbstractExplorer`](@ref), decide which action to take based on action values. """ Base.@kwdef struct VBasedPolicy{L<:AbstractLearner,M,E<:AbstractExplorer} <: AbstractPolicy @@ -17,31 +17,31 @@ Base.@kwdef struct VBasedPolicy{L<:AbstractLearner,M,E<:AbstractExplorer} <: Abs explorer::E = GreedyExplorer() end -(p::VBasedPolicy)(obs) = p(obs, ActionStyle(obs)) +(p::VBasedPolicy)(env) = p(env, ActionStyle(env)) -(p::VBasedPolicy)(obs, ::MinimalActionSet) = p.mapping(obs, p.learner) |> p.explorer +(p::VBasedPolicy)(env, ::MinimalActionSet) = p.mapping(env, p.learner) |> p.explorer -function (p::VBasedPolicy)(obs, ::FullActionSet) - action_values = p.mapping(obs, p.learner) - p.explorer(action_values, get_legal_actions_mask(obs)) +function (p::VBasedPolicy)(env, ::FullActionSet) + action_values = p.mapping(env, p.learner) + p.explorer(action_values, get_legal_actions_mask(env)) end -RLBase.get_prob(p::VBasedPolicy, obs, action::Integer) = - get_prob(p, obs, ActionStyle(obs), action) +RLBase.get_prob(p::VBasedPolicy, env, action::Integer) = + get_prob(p, env, ActionStyle(env), action) -RLBase.get_prob(p::VBasedPolicy, obs, ::MinimalActionSet) = - get_prob(p.explorer, p.mapping(obs, p.learner)) -RLBase.get_prob(p::VBasedPolicy, obs, ::MinimalActionSet, action) = - get_prob(p.explorer, p.mapping(obs, p.learner), action) +RLBase.get_prob(p::VBasedPolicy, env, ::MinimalActionSet) = + get_prob(p.explorer, p.mapping(env, p.learner)) +RLBase.get_prob(p::VBasedPolicy, env, ::MinimalActionSet, action) = + get_prob(p.explorer, p.mapping(env, p.learner), action) -function RLBase.get_prob(p::VBasedPolicy, obs, ::FullActionSet) - action_values = p.mapping(obs, p.learner) - get_prob(p.explorer, action_values, get_legal_actions_mask(obs)) +function RLBase.get_prob(p::VBasedPolicy, env, ::FullActionSet) + action_values = p.mapping(env, p.learner) + get_prob(p.explorer, action_values, get_legal_actions_mask(env)) end -function RLBase.get_prob(p::VBasedPolicy, obs, ::FullActionSet, action) - action_values = p.mapping(obs, p.learner) - get_prob(p.explorer, action_values, get_legal_actions_mask(obs), action) +function RLBase.get_prob(p::VBasedPolicy, env, ::FullActionSet, action) + action_values = p.mapping(env, p.learner) + get_prob(p.explorer, action_values, get_legal_actions_mask(env), action) end @forward VBasedPolicy.learner RLBase.get_priority, RLBase.update! diff --git a/src/components/policies/chance_player_policy.jl b/src/components/policies/chance_player_policy.jl index ae0d701..d756653 100644 --- a/src/components/policies/chance_player_policy.jl +++ b/src/components/policies/chance_player_policy.jl @@ -8,10 +8,10 @@ end ChancePlayerPolicy(;seed=nothing) = ChancePlayerPolicy(MersenneTwister(seed)) -function (p::ChancePlayerPolicy)(obs) +function (p::ChancePlayerPolicy)(env) v = rand(p.rng) s = 0. 
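# The loop below does inverse-CDF sampling: it accumulates each outcome's
# probability into `s` and returns the first action whose cumulative
# probability reaches the uniform draw `v`.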
- for (action, prob) in get_chance_outcome(obs) + for (action, prob) in get_chance_outcome(env) s += prob s >= v && return action end diff --git a/src/components/policies/off_policy.jl b/src/components/policies/off_policy.jl index 0f75bbf..c37c50b 100644 --- a/src/components/policies/off_policy.jl +++ b/src/components/policies/off_policy.jl @@ -10,6 +10,6 @@ Base.@kwdef struct OffPolicy{P,B} <: AbstractPolicy π_behavior::B end -(π::OffPolicy)(obs) = π.π_behavior(obs) +(π::OffPolicy)(env) = π.π_behavior(env) @forward OffPolicy.π_behavior RLBase.get_priority, RLBase.get_prob From fda7ebc2c07605e2bb1ef7ee0d47383ae53650e3 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 24 Jul 2020 01:14:43 +0800 Subject: [PATCH 08/17] sync changes --- Project.toml | 2 +- src/components/agents/abstract_agent.jl | 2 +- src/components/approximators/abstract_approximator.jl | 2 ++ src/components/explorers/UCB_explorer.jl | 4 ++-- src/components/explorers/epsilon_greedy_explorer.jl | 5 ++--- src/components/explorers/gumbel_softmax_explorer.jl | 2 +- src/components/explorers/weighted_explorer.jl | 3 +-- src/components/learners/learners.jl | 2 ++ src/components/policies/chance_player_policy.jl | 2 +- src/core/experiment.jl | 6 +++++- src/core/hooks.jl | 2 ++ src/core/run.jl | 4 ++-- src/extensions/Flux.jl | 8 +++++++- src/extensions/ReinforcementLearningBase.jl | 4 +++- src/utils/device.jl | 9 +++++++++ src/utils/printing.jl | 10 ++++++++-- 16 files changed, 49 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 087e06d..d0df0e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ReinforcementLearningCore" uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6" authors = ["Jun Tian "] -version = "0.3.3" +version = "0.4.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index 62b69ba..3b4b378 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -74,5 +74,5 @@ Training(s::T) where {T<:AbstractStage} = Training{T}() struct Testing{T<:AbstractStage} end Testing(s::T) where {T<:AbstractStage} = Testing{T}() -Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent)) +Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),15) Base.summary(io::IO, agent::T) where {T<:AbstractAgent} = print(io, T.name) diff --git a/src/components/approximators/abstract_approximator.jl b/src/components/approximators/abstract_approximator.jl index 4489c30..2aabba6 100644 --- a/src/components/approximators/abstract_approximator.jl +++ b/src/components/approximators/abstract_approximator.jl @@ -10,6 +10,8 @@ kinds of approximate methods (for example DNN provided by Flux or Knet). """ abstract type AbstractApproximator end +Base.summary(io::IO, t::T) where T<:AbstractApproximator = print(io, T.name) + """ update!(a::AbstractApproximator, correction) diff --git a/src/components/explorers/UCB_explorer.jl b/src/components/explorers/UCB_explorer.jl index b9529ad..3110754 100644 --- a/src/components/explorers/UCB_explorer.jl +++ b/src/components/explorers/UCB_explorer.jl @@ -23,8 +23,8 @@ Flux.testmode!(p::UCBExplorer, mode = true) = p.is_training = !mode - `seed`, set the seed of inner RNG. - `is_training=true`, in training mode, time step and counter will not be updated. 
""" -UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, seed = nothing, is_training = true) = - UCBExplorer(c, fill(ϵ, na), 1, MersenneTwister(seed), is_training) +UCBExplorer(na; c = 2.0, ϵ = 1e-10, step = 1, rng = Random.GLOBAL_RNG, is_training = true) = + UCBExplorer(c, fill(ϵ, na), 1, rng, is_training) @doc raw""" (ucb::UCBExplorer)(values::AbstractArray) diff --git a/src/components/explorers/epsilon_greedy_explorer.jl b/src/components/explorers/epsilon_greedy_explorer.jl index bf1c5fb..85f04da 100644 --- a/src/components/explorers/epsilon_greedy_explorer.jl +++ b/src/components/explorers/epsilon_greedy_explorer.jl @@ -23,7 +23,7 @@ Two kinds of epsilon-decreasing strategy are implmented here (`linear` and `exp` - `decay_steps::Int=0`: the number of steps for epsilon to decay from `ϵ_init` to `ϵ_stable`. - `ϵ_stable::Float64`: the epsilon after `warmup_steps + decay_steps`. - `is_break_tie=false`: randomly select an action of the same maximum values if set to `true`. -- `seed=nothing`: set the seed of internal RNG. +- `rng=Random.GLOBAL_RNG`: set the internal RNG. - `is_training=true`, in training mode, `step` will not be updated. And the `ϵ` will be set to 0. # Example @@ -59,9 +59,8 @@ function EpsilonGreedyExplorer(; step = 1, is_break_tie = false, is_training = true, - seed = nothing, + rng = Random.GLOBAL_RNG ) - rng = MersenneTwister(seed) EpsilonGreedyExplorer{kind,is_break_tie,typeof(rng)}( ϵ_stable, ϵ_init, diff --git a/src/components/explorers/gumbel_softmax_explorer.jl b/src/components/explorers/gumbel_softmax_explorer.jl index 1efea3a..da8a738 100644 --- a/src/components/explorers/gumbel_softmax_explorer.jl +++ b/src/components/explorers/gumbel_softmax_explorer.jl @@ -7,7 +7,7 @@ struct GumbelSoftmaxExplorer <: AbstractExplorer rng::AbstractRNG end -GumbelSoftmaxExplorer(; seed = nothing) = GumbelSoftmaxExplorer(MersenneTwister(seed)) +GumbelSoftmaxExplorer(; rng = Random.GLOBAL_RNG) = GumbelSoftmaxExplorer(rng) function (p::GumbelSoftmaxExplorer)(v::AbstractVector{T}) where {T} logits = logsoftmax(v) diff --git a/src/components/explorers/weighted_explorer.jl b/src/components/explorers/weighted_explorer.jl index c367b88..77ddc5c 100644 --- a/src/components/explorers/weighted_explorer.jl +++ b/src/components/explorers/weighted_explorer.jl @@ -14,8 +14,7 @@ struct WeightedExplorer{T,R<:AbstractRNG} <: AbstractExplorer rng::R end -function WeightedExplorer(; is_normalized::Bool = false, seed = nothing) - rng = MersenneTwister(seed) +function WeightedExplorer(; is_normalized::Bool = false, rng = Random.GLOBAL_RNG) WeightedExplorer{is_normalized,typeof(rng)}(rng) end diff --git a/src/components/learners/learners.jl b/src/components/learners/learners.jl index 4071182..328b858 100644 --- a/src/components/learners/learners.jl +++ b/src/components/learners/learners.jl @@ -9,6 +9,8 @@ A learner is usually used to estimate state values, state-action values or distr """ abstract type AbstractLearner end +Base.summary(io::IO, t::T) where T<:AbstractLearner = print(io, T.name) + function (learner::AbstractLearner)(env) end """ diff --git a/src/components/policies/chance_player_policy.jl b/src/components/policies/chance_player_policy.jl index d756653..c5fd259 100644 --- a/src/components/policies/chance_player_policy.jl +++ b/src/components/policies/chance_player_policy.jl @@ -6,7 +6,7 @@ struct ChancePlayerPolicy <: AbstractPolicy rng::AbstractRNG end -ChancePlayerPolicy(;seed=nothing) = ChancePlayerPolicy(MersenneTwister(seed)) +ChancePlayerPolicy(;rng=Random.GLOBAL_RNG) = 
ChancePlayerPolicy(rng) function (p::ChancePlayerPolicy)(env) v = rand(p.rng) diff --git a/src/core/experiment.jl b/src/core/experiment.jl index 39796ff..7a2fb55 100644 --- a/src/core/experiment.jl +++ b/src/core/experiment.jl @@ -11,6 +11,11 @@ Base.@kwdef mutable struct Experiment description::String end +function Base.show(io::IO, x::Experiment) + display(Markdown.parse(x.description)) + AbstractTrees.print_tree(io, StructTree(x),15) +end + macro experiment_cmd(s) Experiment(s) end @@ -33,7 +38,6 @@ function Experiment(s::String) end function Base.run(x::Experiment) - display(Markdown.parse(x.description)) run(x.agent, x.env, x.stop_condition, x.hook) x end diff --git a/src/core/hooks.jl b/src/core/hooks.jl index bc6e503..a56de76 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -41,6 +41,8 @@ struct ComposedHook{T<:Tuple} <: AbstractHook ComposedHook(hooks...) = new{typeof(hooks)}(hooks) end +Base.summary(io::IO, hook::ComposedHook) = print(io, "ComposedHook") + function (hook::ComposedHook)(stage::AbstractStage, args...; kw...) for h in hook.hooks h(stage, args...; kw...) diff --git a/src/core/run.jl b/src/core/run.jl index 5c23202..c370647 100644 --- a/src/core/run.jl +++ b/src/core/run.jl @@ -31,11 +31,11 @@ function run( reset!(env) agent(PRE_EPISODE_STAGE, env) hook(PRE_EPISODE_STAGE, agent, env) - action = agent(PRE_ACT_STAGE) + action = agent(PRE_ACT_STAGE, env) hook(PRE_ACT_STAGE, agent, env, action) else stop_condition(agent, env) && break - action = agent(PRE_ACT_STAGE) + action = agent(PRE_ACT_STAGE, env) hook(PRE_ACT_STAGE, agent, env, action) end end diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 69aceb5..3f2651d 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -1,4 +1,4 @@ -export orthogonal +export glorot_uniform, glorot_normal, orthogonal import Flux: glorot_uniform, glorot_normal @@ -11,6 +11,9 @@ glorot_uniform(rng::AbstractRNG, dims...) = glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(Flux.nfan(dims...))) +glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...) +glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...) + # https://github.com/FluxML/Flux.jl/pull/1171/ # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/Orthogonal function orthogonal_matrix(rng::AbstractRNG, nrow, ncol) @@ -27,3 +30,6 @@ function orthogonal(rng::AbstractRNG, d1, rest_dims...) end orthogonal(dims...) = orthogonal(Random.GLOBAL_RNG, dims...) +orthogonal(rng::AbstractRNG) = (dims...) -> orthogonal(rng, dims...) + +Base.summary(io::IO, t::T) where T<:Flux.Chain = print(io, T.name) \ No newline at end of file diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index cabfec2..c7a2cad 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -12,7 +12,9 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Flux.testmode!(p::AbstractPolicy, mode = true) = @error "someone forgets to implement this method!!!" 
-Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p)) +Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),15) + +Base.summary(io::IO, t::T) where T<:AbstractPolicy = print(io, T.name) function save(f::String, p::AbstractPolicy) policy = cpu(p) diff --git a/src/utils/device.jl b/src/utils/device.jl index 7882a01..92bbc29 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -3,6 +3,7 @@ export device, send_to_host, send_to_device using Flux using CUDA using Adapt +using Random import CUDA:device @@ -27,6 +28,14 @@ device(::Array) = Val(:cpu) device(x::Tuple{}) = nothing device(x::NamedTuple{(),Tuple{}}) = nothing +function device(x::Random.AbstractRNG) + if x isa CUDA.CURAND.RNG + Val(:gpu) + else + Val(:cpu) + end +end + function device(x::Union{Tuple,NamedTuple}) d1 = device(x[1]) if isnothing(d1) diff --git a/src/utils/printing.jl b/src/utils/printing.jl index e4f5397..8396ca1 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -1,6 +1,8 @@ export StructTree using AbstractTrees +using Random +using ProgressMeter const AT = AbstractTrees @@ -9,7 +11,7 @@ struct StructTree{X} end AT.children(t::StructTree{X}) where X = Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X)) -AT.children(t::StructTree{<:AbstractArray}) = () +AT.children(t::StructTree{T}) where T<:Union{AbstractArray, MersenneTwister, ProgressMeter.Progress} = () AT.children(t::Pair{Symbol, <:StructTree}) = children(last(t)) AT.printnode(io::IO, t::StructTree) = summary(io, t.x) @@ -29,7 +31,7 @@ function AT.printnode(io::IO, t::StructTree{String}) if i > 79 print(io, "\"s[1:79]...\"") else - print(io, "\"$(s[1:i])...\"") + print(io, "\"$(s[1:i-1])...\"") end end end @@ -38,3 +40,7 @@ function AT.printnode(io::IO, t::Pair{Symbol, <:StructTree}) print(io, first(t), " => ") AT.printnode(io, last(t)) end + +function AT.printnode(io::IO, t::Pair{Symbol, <:StructTree{<:Tuple}}) + print(io, first(t)) +end From 34b649cb47bb6f73bcf38d61e49bedb3e760d013 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 26 Jul 2020 21:14:33 +0800 Subject: [PATCH 09/17] allow setting max_depth when printing struct --- src/components/agents/abstract_agent.jl | 3 ++- src/core/experiment.jl | 2 +- src/core/stop_conditions.jl | 4 ++-- src/extensions/ReinforcementLearningBase.jl | 2 +- src/utils/printing.jl | 4 ++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index 3b4b378..a056bb7 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -74,5 +74,6 @@ Training(s::T) where {T<:AbstractStage} = Training{T}() struct Testing{T<:AbstractStage} end Testing(s::T) where {T<:AbstractStage} = Testing{T}() -Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),15) Base.summary(io::IO, agent::T) where {T<:AbstractAgent} = print(io, T.name) + +Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),get(io, :max_depth, 15)) diff --git a/src/core/experiment.jl b/src/core/experiment.jl index 7a2fb55..bfe2518 100644 --- a/src/core/experiment.jl +++ b/src/core/experiment.jl @@ -13,7 +13,7 @@ end function Base.show(io::IO, x::Experiment) display(Markdown.parse(x.description)) - AbstractTrees.print_tree(io, StructTree(x),15) + AbstractTrees.print_tree(io, StructTree(x),get(io, :max_depth, 15)) end macro experiment_cmd(s) diff --git a/src/core/stop_conditions.jl 
b/src/core/stop_conditions.jl index 0caa56c..d16699d 100644 --- a/src/core/stop_conditions.jl +++ b/src/core/stop_conditions.jl @@ -41,7 +41,7 @@ end function StopAfterStep(step; cur = 1, is_show_progress = true) if is_show_progress - progress = Progress(step) + progress = Progress(step, 1) ProgressMeter.update!(progress, cur) else progress = nothing @@ -80,7 +80,7 @@ end function StopAfterEpisode(episode; cur = 0, is_show_progress = true) if is_show_progress - progress = Progress(episode) + progress = Progress(episode, 1) ProgressMeter.update!(progress, cur) else progress = nothing diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index c7a2cad..fe3a625 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -12,7 +12,7 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Flux.testmode!(p::AbstractPolicy, mode = true) = @error "someone forgets to implement this method!!!" -Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),15) +Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),get(io, :max_depth, 15)) Base.summary(io::IO, t::T) where T<:AbstractPolicy = print(io, T.name) diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 8396ca1..8849a65 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -23,13 +23,13 @@ function AT.printnode(io::IO, t::StructTree{String}) i = findfirst('\n', s) if isnothing(i) if length(s) > 79 - print(io, "\"s[1:79]...\"") + print(io, "\"$(s[1:79])...\"") else print(io, "\"$s\"") end else if i > 79 - print(io, "\"s[1:79]...\"") + print(io, "\"$(s[1:79])...\"") else print(io, "\"$(s[1:i-1])...\"") end From 724c3dcd3b0f604897cab6f846d144f7262c3532 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 27 Jul 2020 22:20:41 +0800 Subject: [PATCH 10/17] minor fix --- src/components/explorers/epsilon_greedy_explorer.jl | 2 -- src/core/experiment.jl | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/components/explorers/epsilon_greedy_explorer.jl b/src/components/explorers/epsilon_greedy_explorer.jl index 85f04da..8bfe7a2 100644 --- a/src/components/explorers/epsilon_greedy_explorer.jl +++ b/src/components/explorers/epsilon_greedy_explorer.jl @@ -197,8 +197,6 @@ function RLBase.get_prob(s::EpsilonGreedyExplorer{<:Any,false}, values, mask) Categorical(probs) end -RLBase.reset!(s::EpsilonGreedyExplorer) = s.step = 1 - # Though we can achieve the same goal by setting the ϵ of [`EpsilonGreedyExplorer`](@ref) to 0, # the GreedyExplorer is much faster. 
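# For example (a minimal sketch, not taken from this patch): given a vector of
# action values the greedy explorer just returns the index of the maximum,
#
#     GreedyExplorer()([0.1, 0.5, 0.3])  # -> 2
#
# while an `EpsilonGreedyExplorer` with `ϵ = 0` reaches the same answer only
# after computing `ϵ` for the current step and drawing from its RNG.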
struct GreedyExplorer <: AbstractExplorer end diff --git a/src/core/experiment.jl b/src/core/experiment.jl index bfe2518..acb51e1 100644 --- a/src/core/experiment.jl +++ b/src/core/experiment.jl @@ -38,6 +38,7 @@ function Experiment(s::String) end function Base.run(x::Experiment) + display(Markdown.parse(x.description)) run(x.agent, x.env, x.stop_condition, x.hook) x end From 690ac9062d9898649500dd059320744fb7fc5ed3 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 28 Jul 2020 17:22:26 +0800 Subject: [PATCH 11/17] decrease max_depth --- src/components/agents/abstract_agent.jl | 2 +- src/core/experiment.jl | 2 +- src/extensions/ReinforcementLearningBase.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index a056bb7..74ed323 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -76,4 +76,4 @@ Testing(s::T) where {T<:AbstractStage} = Testing{T}() Base.summary(io::IO, agent::T) where {T<:AbstractAgent} = print(io, T.name) -Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),get(io, :max_depth, 15)) +Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),get(io, :max_depth, 10)) diff --git a/src/core/experiment.jl b/src/core/experiment.jl index acb51e1..febebbe 100644 --- a/src/core/experiment.jl +++ b/src/core/experiment.jl @@ -13,7 +13,7 @@ end function Base.show(io::IO, x::Experiment) display(Markdown.parse(x.description)) - AbstractTrees.print_tree(io, StructTree(x),get(io, :max_depth, 15)) + AbstractTrees.print_tree(io, StructTree(x),get(io, :max_depth, 10)) end macro experiment_cmd(s) diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index fe3a625..4255973 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -12,7 +12,7 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Flux.testmode!(p::AbstractPolicy, mode = true) = @error "someone forgets to implement this method!!!" 
-Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),get(io, :max_depth, 15)) +Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),get(io, :max_depth, 10)) Base.summary(io::IO, t::T) where T<:AbstractPolicy = print(io, T.name) From 204df5164497c381890f6d860a7141942df38d34 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 28 Jul 2020 23:14:04 +0800 Subject: [PATCH 12/17] simplify printing --- src/components/agents/abstract_agent.jl | 2 -- src/components/agents/agent.jl | 4 ++-- src/components/approximators/abstract_approximator.jl | 2 -- src/components/learners/learners.jl | 2 -- src/core/hooks.jl | 2 -- src/extensions/Flux.jl | 2 -- src/extensions/ReinforcementLearningBase.jl | 3 ++- src/utils/printing.jl | 6 +++--- 8 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/components/agents/abstract_agent.jl b/src/components/agents/abstract_agent.jl index 74ed323..b65da82 100644 --- a/src/components/agents/abstract_agent.jl +++ b/src/components/agents/abstract_agent.jl @@ -74,6 +74,4 @@ Training(s::T) where {T<:AbstractStage} = Training{T}() struct Testing{T<:AbstractStage} end Testing(s::T) where {T<:AbstractStage} = Testing{T}() -Base.summary(io::IO, agent::T) where {T<:AbstractAgent} = print(io, T.name) - Base.show(io::IO, agent::AbstractAgent) = AbstractTrees.print_tree(io, StructTree(agent),get(io, :max_depth, 10)) diff --git a/src/components/agents/agent.jl b/src/components/agents/agent.jl index 8c618d2..b8fc530 100644 --- a/src/components/agents/agent.jl +++ b/src/components/agents/agent.jl @@ -16,12 +16,12 @@ Generally speaking, it does nothing but update the trajectory and policy appropr - `policy`::[`AbstractPolicy`](@ref): the policy to use - `trajectory`::[`AbstractTrajectory`](@ref): used to store transitions between an agent and an environment -- `role=:DEFAULT_PLAYER`: used to distinguish different agents +- `role=RLBase.DEFAULT_PLAYER`: used to distinguish different agents """ Base.@kwdef mutable struct Agent{P<:AbstractPolicy,T<:AbstractTrajectory,R} <: AbstractAgent policy::P trajectory::T = DummyTrajectory() - role::R = :DEFAULT_PLAYER + role::R = RLBase.DEFAULT_PLAYER is_training::Bool = true end diff --git a/src/components/approximators/abstract_approximator.jl b/src/components/approximators/abstract_approximator.jl index 2aabba6..4489c30 100644 --- a/src/components/approximators/abstract_approximator.jl +++ b/src/components/approximators/abstract_approximator.jl @@ -10,8 +10,6 @@ kinds of approximate methods (for example DNN provided by Flux or Knet). """ abstract type AbstractApproximator end -Base.summary(io::IO, t::T) where T<:AbstractApproximator = print(io, T.name) - """ update!(a::AbstractApproximator, correction) diff --git a/src/components/learners/learners.jl b/src/components/learners/learners.jl index 328b858..4071182 100644 --- a/src/components/learners/learners.jl +++ b/src/components/learners/learners.jl @@ -9,8 +9,6 @@ A learner is usually used to estimate state values, state-action values or distr """ abstract type AbstractLearner end -Base.summary(io::IO, t::T) where T<:AbstractLearner = print(io, T.name) - function (learner::AbstractLearner)(env) end """ diff --git a/src/core/hooks.jl b/src/core/hooks.jl index a56de76..bc6e503 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -41,8 +41,6 @@ struct ComposedHook{T<:Tuple} <: AbstractHook ComposedHook(hooks...) 
= new{typeof(hooks)}(hooks) end -Base.summary(io::IO, hook::ComposedHook) = print(io, "ComposedHook") - function (hook::ComposedHook)(stage::AbstractStage, args...; kw...) for h in hook.hooks h(stage, args...; kw...) diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl index 3f2651d..2fe992b 100644 --- a/src/extensions/Flux.jl +++ b/src/extensions/Flux.jl @@ -31,5 +31,3 @@ end orthogonal(dims...) = orthogonal(Random.GLOBAL_RNG, dims...) orthogonal(rng::AbstractRNG) = (dims...) -> orthogonal(rng, dims...) - -Base.summary(io::IO, t::T) where T<:Flux.Chain = print(io, T.name) \ No newline at end of file diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index 4255973..2fce3bf 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -14,7 +14,8 @@ Flux.testmode!(p::AbstractPolicy, mode = true) = Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),get(io, :max_depth, 10)) -Base.summary(io::IO, t::T) where T<:AbstractPolicy = print(io, T.name) +AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = + print(io, "$(RLBase.get_name(t.x)): $(join([f(t.x) for f in RLBase.get_env_traits()], ","))") function save(f::String, p::AbstractPolicy) policy = cpu(p) diff --git a/src/utils/printing.jl b/src/utils/printing.jl index 8849a65..8dc5c43 100644 --- a/src/utils/printing.jl +++ b/src/utils/printing.jl @@ -11,12 +11,12 @@ struct StructTree{X} end AT.children(t::StructTree{X}) where X = Tuple(f => StructTree(getfield(t.x, f)) for f in fieldnames(X)) -AT.children(t::StructTree{T}) where T<:Union{AbstractArray, MersenneTwister, ProgressMeter.Progress} = () +AT.children(t::StructTree{T}) where T<:Union{AbstractArray, MersenneTwister, ProgressMeter.Progress, Function} = () AT.children(t::Pair{Symbol, <:StructTree}) = children(last(t)) -AT.printnode(io::IO, t::StructTree) = summary(io, t.x) - AT.printnode(io::IO, t::StructTree{<:Union{Number,Symbol}}) = print(io, t.x) +AT.printnode(io::IO, t::StructTree{T}) where T = print(io, T.name) +AT.printnode(io::IO, t::StructTree{<:AbstractArray}) where T = summary(io, t.x) function AT.printnode(io::IO, t::StructTree{String}) s = t.x From d17ee39684529d01e967dfd182259d2f62297804 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 28 Jul 2020 23:19:14 +0800 Subject: [PATCH 13/17] resolve comments --- Project.toml | 1 - .../policies/chance_player_policy.jl | 18 ----------- src/components/policies/policies.jl | 2 -- src/components/policies/tabular_policy.jl | 31 ------------------- test/core/core.jl | 2 +- 5 files changed, 1 insertion(+), 53 deletions(-) delete mode 100644 src/components/policies/chance_player_policy.jl delete mode 100644 src/components/policies/tabular_policy.jl diff --git a/Project.toml b/Project.toml index d0df0e5..4213b92 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" -ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/components/policies/chance_player_policy.jl b/src/components/policies/chance_player_policy.jl deleted file mode 100644 index c5fd259..0000000 --- 
a/src/components/policies/chance_player_policy.jl +++ /dev/null @@ -1,18 +0,0 @@ -export ChancePlayerPolicy - -using Random - -struct ChancePlayerPolicy <: AbstractPolicy - rng::AbstractRNG -end - -ChancePlayerPolicy(;rng=Random.GLOBAL_RNG) = ChancePlayerPolicy(rng) - -function (p::ChancePlayerPolicy)(env) - v = rand(p.rng) - s = 0. - for (action, prob) in get_chance_outcome(env) - s += prob - s >= v && return action - end -end diff --git a/src/components/policies/policies.jl b/src/components/policies/policies.jl index 4c610ae..ab599f9 100644 --- a/src/components/policies/policies.jl +++ b/src/components/policies/policies.jl @@ -1,5 +1,3 @@ -include("chance_player_policy.jl") -include("tabular_policy.jl") include("V_based_policy.jl") include("Q_based_policy.jl") include("off_policy.jl") diff --git a/src/components/policies/tabular_policy.jl b/src/components/policies/tabular_policy.jl deleted file mode 100644 index 2da45d9..0000000 --- a/src/components/policies/tabular_policy.jl +++ /dev/null @@ -1,31 +0,0 @@ -export TabularPolicy - -using AbstractTrees - -## TODO: Use TabularApproximator -struct TabularPolicy{S,F,E} <: RLBase.AbstractPolicy - probs::Dict{S,Vector{Float64}} - key::F - explorer::E -end - -(p::TabularPolicy)(obs) = p.probs[p.key(obs)] |> p.explorer - -RLBase.get_prob(p::TabularPolicy, obs) = p.probs[p.key(obs)] - -function TabularPolicy(env::AbstractEnv;key=RLBase.get_state, explorer=WeightedExplorer(;is_normalized=true)) - k = key(observe(env)) - probs = Dict{typeof(k),Vector{Float64}}() - for x in PreOrderDFS(env) - if get_current_player(x) != get_chance_player(x) - obs = observe(x) - if !get_terminal(obs) - legal_actions_mask = get_legal_actions_mask(obs) - p = zeros(length(legal_actions_mask)) - p[legal_actions_mask] .= 1 / sum(legal_actions_mask) - probs[key(obs)] = p - end - end - end - TabularPolicy(probs, key, explorer) -end diff --git a/test/core/core.jl b/test/core/core.jl index d9f5a29..91e8d5c 100644 --- a/test/core/core.jl +++ b/test/core/core.jl @@ -1,5 +1,5 @@ @testset "simple workflow" begin - env = CartPoleEnv{Float32}() |> StateOverriddenEnv(;deep_copy_state=deepcopy) + env = CartPoleEnv{Float32}() |> StateOverriddenEnv(deepcopy) agent = Agent(; policy = RandomPolicy(env), trajectory = VectorialCompactSARTSATrajectory(; state_type = Vector{Float32}), From 2249bc7b96c21d05082fb99415c811c625af94b0 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 30 Jul 2020 10:13:20 +0800 Subject: [PATCH 14/17] support RandomStartPolicy --- src/extensions/ReinforcementLearningBase.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/extensions/ReinforcementLearningBase.jl b/src/extensions/ReinforcementLearningBase.jl index 2fce3bf..041d0c5 100644 --- a/src/extensions/ReinforcementLearningBase.jl +++ b/src/extensions/ReinforcementLearningBase.jl @@ -12,6 +12,10 @@ Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), Flux.testmode!(p::AbstractPolicy, mode = true) = @error "someone forgets to implement this method!!!" 
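# `RandomStartPolicy` keeps the wrapped policy in its `policy` field, so the
# method added below simply forwards `testmode!` to it instead of hitting the
# fallback error above.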
+function Flux.testmode!(p::RandomStartPolicy, mode = true) + testmode!(p.policy, mode) +end + Base.show(io::IO, p::AbstractPolicy) = AbstractTrees.print_tree(io, StructTree(p),get(io, :max_depth, 10)) AbstractTrees.printnode(io::IO, t::StructTree{<:AbstractEnv}) = From 24f8a403d51f97a6013f02e6a8f5026a413494a7 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 30 Jul 2020 18:08:18 +0800 Subject: [PATCH 15/17] fix #73 --- src/core/hooks.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/core/hooks.jl b/src/core/hooks.jl index bc6e503..9535144 100644 --- a/src/core/hooks.jl +++ b/src/core/hooks.jl @@ -60,6 +60,12 @@ struct EmptyHook <: AbstractHook end const EMPTY_HOOK = EmptyHook() +##### +# display +##### + +Base.display(::AbstractStage, agent, env, args...; kwargs...) = display(env) + ##### # StepsPerEpisode ##### From f1708a03600c737e9194216a50604b5a1a18698e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 4 Aug 2020 10:06:20 +0800 Subject: [PATCH 16/17] ignore gpu related tests when CUDA is not available --- test/utils/device.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/utils/device.jl b/test/utils/device.jl index 8bde3e8..b809b43 100644 --- a/test/utils/device.jl +++ b/test/utils/device.jl @@ -5,9 +5,11 @@ @test device(Conv((2, 2), 1 => 16, relu)) == Val(:cpu) @test device(Chain(x -> x .^ 2, Dense(2, 3))) == Val(:cpu) - @test device(rand(2) |> gpu) == Val(:gpu) - @test device(Dense(2, 3) |> gpu) == Val(:gpu) - @test device(Conv((2, 2), 1 => 16, relu) |> gpu) == Val(:gpu) - @test device(Chain(x -> x .^ 2, Dense(2, 3)) |> gpu) == Val(:gpu) + if CUDA.functional() + @test device(rand(2) |> gpu) == Val(:gpu) + @test device(Dense(2, 3) |> gpu) == Val(:gpu) + @test device(Conv((2, 2), 1 => 16, relu) |> gpu) == Val(:gpu) + @test device(Chain(x -> x .^ 2, Dense(2, 3)) |> gpu) == Val(:gpu) + end end From 41e39593e670d2f0443ee7ddaa5d3110e9621556 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 4 Aug 2020 10:48:37 +0800 Subject: [PATCH 17/17] update compat --- Project.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4213b92..2cf117b 100644 --- a/Project.toml +++ b/Project.toml @@ -25,17 +25,22 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AbstractTrees = "0.3" +Adapt = "2" BSON = "0.2" +CUDA = "1" Distributions = "0.22, 0.23" FillArrays = "0.8" Flux = "0.11" +GPUArrays = "5" ImageTransformations = "0.8" JLD = "0.10" MacroTools = "0.5" ProgressMeter = "1.2" ReinforcementLearningBase = "0.8" -Setfield = "0.6" +Setfield = "0.6, 0.7" StatsBase = "0.32, 0.33" +Zygote = "0.5" julia = "1.3" [extras]
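Taken together, the changes above replace the explorers' `seed` keyword with an `rng` keyword. A minimal migration sketch (the variable names and the seed value are illustrative, not taken from the patches):

    using ReinforcementLearningCore, Random

    rng = MersenneTwister(123)                       # was: seed = 123
    ucb = UCBExplorer(4; rng = rng)                  # was: UCBExplorer(4; seed = 123)
    w = WeightedExplorer(; is_normalized = true, rng = rng)
    g = GumbelSoftmaxExplorer(; rng = rng)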