This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -30,7 +30,7 @@ CUDA = "1"
Flux = "0.11"
MacroTools = "0.5"
ReinforcementLearningBase = "0.8"
ReinforcementLearningCore = "0.4"
ReinforcementLearningCore = "0.4.1"
Requires = "1"
Setfield = "0.6, 0.7"
StatsBase = "0.32, 0.33"
29 changes: 16 additions & 13 deletions src/algorithms/dqns/basic_dqn.jl
@@ -62,21 +62,24 @@ function BasicDQNLearner(;
)
end

function RLBase.update!(learner::BasicDQNLearner, t::AbstractTrajectory)
length(t) < learner.min_replay_history && return
function RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)
length(T[:terminal]) < learner.min_replay_history && return

inds = rand(learner.rng, 1:length(t), learner.batch_size)
batch = map(get_trace(t, :state, :action, :reward, :terminal, :next_state)) do x
consecutive_view(x, inds)
end
Q = learner.approximator
D = device(Q)
γ = learner.γ
loss_func = learner.loss_func
batch_size = learner.batch_size

Q, γ, loss_func, batch_size =
learner.approximator, learner.γ, learner.loss_func, learner.batch_size
s, r, t, s′ = map(
x -> send_to_device(device(Q), x),
(batch.state, batch.reward, batch.terminal, batch.next_state),
)
a = CartesianIndex.(batch.action, 1:batch_size)
inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)

s = send_to_device(D, consecutive_view(T[:state], inds))
a = consecutive_view(T[:action], inds)
r = send_to_device(D, consecutive_view(T[:reward], inds))
t = send_to_device(D, consecutive_view(T[:terminal], inds))
s′ = send_to_device(D, consecutive_view(T[:next_state], inds))

a = CartesianIndex.(a, 1:batch_size)

gs = gradient(params(Q)) do
q = Q(s)[a]
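Note on the `CartesianIndex` indexing used just above: `Q(s)` is expected to return an `(n_actions, batch_size)` matrix, and broadcasting `CartesianIndex` over the sampled actions picks the Q-value of the action actually taken in each column. A minimal, self-contained sketch in plain Julia (numbers and names are illustrative only):

```julia
# Illustrative only: q_values stands in for Q(s), actions for the sampled batch actions.
q_values = [1.0 2.0 3.0;   # Q(s, a = 1) for 3 sampled transitions
            4.0 5.0 6.0]   # Q(s, a = 2) for 3 sampled transitions
actions = [2, 1, 2]        # action taken in each transition
inds = CartesianIndex.(actions, 1:3)
q_values[inds]             # => [4.0, 2.0, 6.0]
```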
20 changes: 10 additions & 10 deletions src/algorithms/dqns/common.jl
@@ -11,14 +11,14 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
γ = learner.γ

# 1. sample indices based on priority
valid_ind_range = isnothing(s) ? (1:(length(t)-h)) : (s:(length(t)-h))
valid_ind_range = isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
if t isa CircularCompactPSARTSATrajectory
inds = Vector{Int}(undef, n)
priorities = Vector{Float32}(undef, n)
for i in 1:n
ind, p = sample(learner.rng, get_trace(t, :priority))
ind, p = sample(learner.rng, t[:priority])
while ind ∉ valid_ind_range
ind, p = sample(learner.rng, get_trace(t, :priority))
ind, p = sample(learner.rng, t[:priority])
end
inds[i] = ind
priorities[i] = p
@@ -29,11 +29,11 @@
end

# 2. extract SARTS
states = consecutive_view(get_trace(t, :state), inds; n_stack = s)
actions = consecutive_view(get_trace(t, :action), inds)
next_states = consecutive_view(get_trace(t, :state), inds .+ h; n_stack = s)
consecutive_rewards = consecutive_view(get_trace(t, :reward), inds; n_horizon = h)
consecutive_terminals = consecutive_view(get_trace(t, :terminal), inds; n_horizon = h)
states = consecutive_view(t[:state], inds; n_stack = s)
actions = consecutive_view(t[:action], inds)
next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
rewards, terminals = zeros(Float32, n), fill(false, n)

rewards = discount_rewards_reduced(
@@ -57,7 +57,7 @@ end

function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
learner = p.learner
length(t) < learner.min_replay_history && return
length(t[:terminal]) < learner.min_replay_history && return

learner.update_step += 1

@@ -71,7 +71,7 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)

if t isa CircularCompactPSARTSATrajectory
priorities = update!(p.learner, experience)
get_trace(t, :priority)[inds] .= priorities
t[:priority][inds] .= priorities
else
update!(p.learner, experience)
end
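The prioritized sampling loop earlier in this file draws indices with probability proportional to the stored priorities (the `sample(learner.rng, t[:priority])` call returns an index/priority pair from RLCore's priority trace). As a rough conceptual sketch only, not the package's implementation, the same idea can be expressed with StatsBase weights:

```julia
using Random, StatsBase

# Conceptual sketch of priority-proportional sampling; the priority trace in
# RLCore does this more efficiently and also hands back the drawn priority.
rng = MersenneTwister(123)
priorities = Float32[0.5, 2.0, 0.1, 1.4]
ind = sample(rng, Weights(priorities))  # index i drawn with prob. priorities[i] / sum(priorities)
p = priorities[ind]
```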
14 changes: 7 additions & 7 deletions src/algorithms/dqns/dqn.jl
@@ -99,7 +99,7 @@ end
Flux.squeezebatch

function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
length(t) < learner.min_replay_history && return
length(t[:terminal]) < learner.min_replay_history && return

learner.update_step += 1

@@ -144,13 +144,13 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner)
n = learner.batch_size
γ = learner.γ

valid_ind_range = isnothing(s) ? (1:(length(t)-h)) : (s:(length(t)-h))
valid_ind_range = isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
inds = rand(learner.rng, valid_ind_range, n)
states = consecutive_view(get_trace(t, :state), inds; n_stack = s)
actions = consecutive_view(get_trace(t, :action), inds)
next_states = consecutive_view(get_trace(t, :state), inds .+ h; n_stack = s)
consecutive_rewards = consecutive_view(get_trace(t, :reward), inds; n_horizon = h)
consecutive_terminals = consecutive_view(get_trace(t, :terminal), inds; n_horizon = h)
states = consecutive_view(t[:state], inds; n_stack = s)
actions = consecutive_view(t[:action], inds)
next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
rewards, terminals = zeros(Float32, n), fill(false, n)

# make sure that we only consider experiences in current episode
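The truncated code after the comment above folds the `n_horizon` consecutive rewards into a single n-step return and stops accumulating at the first terminal, so a sample never mixes two episodes. A minimal sketch of that reduction, assuming the conventional definition rather than the exact `discount_rewards_reduced` signature:

```julia
# Sketch of an n-step return truncated at the first terminal inside the horizon.
function nstep_return(rewards::AbstractVector, terminals::AbstractVector{Bool}, γ)
    R = zero(eltype(rewards))
    for k in eachindex(rewards)
        R += γ^(k - 1) * rewards[k]
        terminals[k] && break   # do not look past the episode boundary
    end
    R
end

nstep_return(Float32[1, 1, 1], [false, true, false], 0.99f0)  # == 1.0f0 + 0.99f0
```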
10 changes: 5 additions & 5 deletions src/algorithms/policy_gradient/A2C.jl
@@ -43,11 +43,11 @@ end
function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory)
isfull(t) || return

states = get_trace(t, :state)
actions = get_trace(t, :action)
rewards = get_trace(t, :reward)
terminals = get_trace(t, :terminal)
next_state = select_last_frame(get_trace(t, :next_state))
states = t[:state]
actions = t[:action]
rewards = t[:reward]
terminals = t[:terminal]
next_state = select_last_frame(t[:next_state])

AC = learner.approximator
γ = learner.γ
10 changes: 5 additions & 5 deletions src/algorithms/policy_gradient/A2CGAE.jl
@@ -36,11 +36,11 @@ end
function RLBase.update!(learner::A2CGAELearner, t::AbstractTrajectory)
isfull(t) || return

states = get_trace(t, :state)
actions = get_trace(t, :action)
rewards = get_trace(t, :reward)
terminals = get_trace(t, :terminal)
rollout = t[:state]
states = t[:state]
actions = t[:action]
rewards = t[:reward]
terminals = t[:terminal]
rollout = t[:full_state]

AC = learner.approximator
γ = learner.γ
13 changes: 8 additions & 5 deletions src/algorithms/policy_gradient/ddpg.jl
@@ -103,13 +103,16 @@ function (p::DDPGPolicy)(env)
end
end

function RLBase.update!(p::DDPGPolicy, t::CircularCompactSARTSATrajectory)
length(t) > p.update_after || return
function RLBase.update!(p::DDPGPolicy, traj::CircularCompactSARTSATrajectory)
length(traj[:terminal]) > p.update_after || return
p.step % p.update_every == 0 || return

inds = rand(p.rng, 1:(length(t)-1), p.batch_size)
SARTS = (:state, :action, :reward, :terminal, :next_state)
s, a, r, t, s′ = map(x -> select_last_dim(get_trace(t, x), inds), SARTS)
inds = rand(p.rng, 1:(length(traj[:terminal])-1), p.batch_size)
s = select_last_dim(traj[:state], inds)
a = select_last_dim(traj[:action], inds)
r = select_last_dim(traj[:reward], inds)
t = select_last_dim(traj[:terminal], inds)
s′ = select_last_dim(traj[:next_state], inds)

A = p.behavior_actor
C = p.behavior_critic
14 changes: 7 additions & 7 deletions src/algorithms/policy_gradient/ppo.jl
@@ -90,12 +90,12 @@ end
function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
isfull(t) || return

states = get_trace(t, :state)
actions = get_trace(t, :action)
action_log_probs = get_trace(t, :action_log_prob)
rewards = get_trace(t, :reward)
terminals = get_trace(t, :terminal)
states_plus = t[:state]
states = t[:state]
actions = t[:action]
action_log_probs = t[:action_log_prob]
rewards = t[:reward]
terminals = t[:terminal]
states_plus = t[:full_state]

rng = learner.rng
AC = learner.approximator
@@ -126,7 +126,7 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
rand_inds = shuffle!(rng, Vector(1:n_envs*n_rollout))
for i in 1:n_microbatches
inds = rand_inds[(i-1)*microbatch_size+1:i*microbatch_size]
s = send_to_device(D, select_last_dim(states_flatten, inds) |> copy) # !!! must copy here
s = send_to_device(D, select_last_dim(states_flatten, inds))
a = vec(actions)[inds]
r = send_to_device(D, vec(returns)[inds])
log_p = send_to_device(D, vec(action_log_probs)[inds])
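For the microbatch loop above: the flattened `n_envs * n_rollout` sample indices are shuffled once and then cut into `n_microbatches` contiguous slices. A tiny illustration with made-up sizes (8 samples, 2 microbatches):

```julia
using Random

rand_inds = shuffle!(MersenneTwister(0), collect(1:8))  # n_envs * n_rollout == 8, illustrative
microbatch_size = 4
microbatches = [rand_inds[(i-1)*microbatch_size+1:i*microbatch_size] for i in 1:2]
# every index appears in exactly one microbatch
```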
110 changes: 22 additions & 88 deletions src/algorithms/policy_gradient/ppo_trajectory.jl
@@ -2,102 +2,36 @@ export PPOTrajectory

using MacroTools

struct PPOTrajectory{T<:CircularCompactSARTSATrajectory,P,names,types} <:
AbstractTrajectory{names,types}
trajectory::T
action_log_prob::P
end
const PPOTrajectory = CombinedTrajectory{
<:SharedTrajectory{<:CircularArrayBuffer, <:NamedTuple{(:action_log_prob, :next_action_log_prob, :full_action_log_prob)}},
<:CircularCompactSARTSATrajectory,
}

function PPOTrajectory(;
capacity,
action_log_prob_size = (),
action_log_prob_type = Float32,
kw...,
kw...
)
t = CircularCompactSARTSATrajectory(; capacity = capacity, kw...)
p = CircularArrayBuffer{action_log_prob_type}(action_log_prob_size..., capacity + 1)
names = typeof(t).parameters[1]
types = typeof(t).parameters[2]
PPOTrajectory{
typeof(t),
typeof(p),
(
:state,
:action,
:action_log_prob,
:reward,
:terminal,
:next_state,
:next_action,
:next_action_log_prob,
),
Tuple{
types.parameters[1:2]...,
frame_type(p),
types.parameters[3:end]...,
frame_type(p),
},
}(
t,
p,
CombinedTrajectory(
SharedTrajectory(CircularArrayBuffer{action_log_prob_type}(action_log_prob_size..., capacity + 1), :action_log_prob),
CircularCompactSARTSATrajectory(;capacity=capacity,kw...),
)
end

MacroTools.@forward PPOTrajectory.trajectory Base.length, Base.isempty, RLCore.isfull

function RLCore.get_trace(t::PPOTrajectory, s::Symbol)
if s == :action_log_prob
select_last_dim(
t.action_log_prob,
1:(nframes(t.action_log_prob) > 1 ? nframes(t.action_log_prob) - 1 :
nframes(t.action_log_prob)),
)
elseif s == :next_action_log_prob
select_last_dim(t.action_log_prob, 2:nframes(t.action_log_prob))
else
get_trace(t.trajectory, s)
end
end
const PPOActionMaskTrajectory = CombinedTrajectory{
<:SharedTrajectory{<:CircularArrayBuffer, <:NamedTuple{(:action_log_prob, :next_action_log_prob, :full_action_log_prob)}},
<:CircularCompactSALRTSALTrajectory,
}

Base.getindex(t::PPOTrajectory, s::Symbol) =
s == :action_log_prob ? t.action_log_prob : t.trajectory[s]

function Base.getindex(p::PPOTrajectory, i::Int)
s, a, r, t, s′, a′ = p.trajectory[i]
(
state = s,
action = a,
action_log_prob = select_last_dim(p.action_log_prob, i),
reward = r,
terminal = t,
next_state = s′,
next_action = a′,
next_action_log_prob = select_last_dim(p.action_log_prob, i + 1),
function PPOActionMaskTrajectory(;
capacity,
action_log_prob_size = (),
action_log_prob_type = Float32,
kw...
)
CombinedTrajectory(
SharedTrajectory(CircularArrayBuffer{action_log_prob_type}(action_log_prob_size..., capacity + 1), :action_log_prob),
CircularCompactSALRTSALTrajectory(;capacity=capacity,kw...),
)
end

function Base.empty!(b::PPOTrajectory)
empty!(b.action_log_prob)
empty!(b.trajectory)
end

function Base.push!(b::PPOTrajectory, kv::Pair{Symbol})
k, v = kv
if k == :action_log_prob || k == :next_action_log_prob
push!(b.action_log_prob, v)
else
push!(b.trajectory, kv)
end
end

function Base.pop!(t::PPOTrajectory, s::Symbol)
if s == :action_log_prob || s == :next_action_log_prob
pop!(t.action_log_prob)
else
pop!(t.trajectory, s)
end
end

function Base.pop!(t::PPOTrajectory)
(pop!(t.trajectory)..., action_log_prob = pop!(t.action_log_prob))
end
end
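The net effect of this file's rewrite: `PPOTrajectory` is no longer a hand-written wrapper struct that forwards `length`, `get_trace`, `push!`, `pop!`, and friends, but a `const` alias for a `CombinedTrajectory` of a `SharedTrajectory` (holding the action log-probs) and a `CircularCompactSARTSATrajectory`, so all of that forwarding now comes from RLCore. A generic sketch of the pattern with purely illustrative types, not the real RLCore API:

```julia
# Illustrative stand-ins, not RLCore types: a `const` alias of a concrete
# composition can replace a forwarding wrapper struct while keeping dispatch
# on the alias (e.g. `update!(learner, t::PPOTrajectory)`) working.
struct SimpleTrace{T}
    data::Vector{T}
end

struct Composed{A,B}
    extra::A
    base::B
end

Base.getindex(t::Composed, s::Symbol) = s === :extra ? t.extra.data : t.base.data
Base.length(t::Composed) = length(t.base.data)

const LogProbTrajectory = Composed{SimpleTrace{Float32},SimpleTrace{Int}}

traj = LogProbTrajectory(SimpleTrace(Float32[]), SimpleTrace(Int[]))
push!(traj[:extra], 0.5f0)   # goes to the shared/extra trace
push!(traj[:action], 3)      # anything else goes to the base trace
length(traj)                 # == 1
```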
4 changes: 2 additions & 2 deletions src/experiments/atari.jl
@@ -549,7 +549,7 @@ function RLCore.Experiment(
rng = MersenneTwister(seed)
if isnothing(save_dir)
t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_A2C_Atari_$(name)_$(t)")
save_dir = joinpath(pwd(), "checkpoints", "rlpyt_A2C_Atari_$(name)_$(t)")
end

lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)
@@ -716,7 +716,7 @@ function RLCore.Experiment(
rng = MersenneTwister(seed)
if isnothing(save_dir)
t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_PPO_Atari_$(name)_$(t)")
save_dir = joinpath(pwd(), "checkpoints", "rlpyt_PPO_Atari_$(name)_$(t)")
end

lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)