src/algorithms/dqns/common.jl (25 changes: 20 additions & 5 deletions)
@@ -13,7 +13,7 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
# 1. sample indices based on priority
valid_ind_range =
isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
if t isa CircularCompactPSARTSATrajectory
if haskey(t, :priority)
inds = Vector{Int}(undef, n)
priorities = Vector{Float32}(undef, n)
for i in 1:n
@@ -29,10 +29,21 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
priorities = nothing
end

next_inds = inds .+ h

# 2. extract SARTS
states = consecutive_view(t[:state], inds; n_stack = s)
actions = consecutive_view(t[:action], inds)
next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
next_states = consecutive_view(t[:state], next_inds; n_stack = s)

if haskey(t, :legal_actions_mask)
legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], inds)
else
legal_actions_mask = nothing
next_legal_actions_mask = nothing
end

consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
rewards, terminals = zeros(Float32, n), fill(false, n)
@@ -48,10 +59,12 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners)
inds,
(
states = states,
legal_actions_mask = legal_actions_mask,
actions = actions,
rewards = rewards,
terminals = terminals,
next_states = next_states,
next_legal_actions_mask = next_legal_actions_mask,
priorities = priorities,
)
end
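The mask fields extracted here are optional, so both branches return something that downstream code can check against `nothing`. A minimal sketch of the `haskey` pattern, with a plain NamedTuple standing in for the `AbstractTrajectory` (the trajectory literal below is illustrative):

    traj = (state = rand(Float32, 4, 10), action = rand(1:2, 10))   # no mask recorded
    mask = haskey(traj, :legal_actions_mask) ? traj[:legal_actions_mask] : nothing
    isnothing(mask)   # true, so the learner falls back to the unmasked path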
@@ -70,23 +83,25 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)

inds, experience = extract_experience(t, p.learner)

if t isa CircularCompactPSARTSATrajectory
if haskey(t, :priority)
priorities = update!(p.learner, experience)
t[:priority][inds] .= priorities
else
update!(p.learner, experience)
end
end

function (agent::Agent{<:QBasedPolicy{<:PERLearners},<:CircularCompactPSARTSATrajectory})(
function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(
::RLCore.Training{PostActStage},
env,
)
push!(
agent.trajectory;
reward = get_reward(env),
terminal = get_terminal(env),
priority = agent.policy.learner.default_priority,
)
if haskey(agent.trajectory, :priority)
push!(agent.trajectory; priority = agent.policy.learner.default_priority)
end
nothing
end
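For context, the `haskey(t, :priority)` branch above draws indices in proportion to the stored priorities. A minimal sketch of that sampling step over a plain priority vector (this is not the package's sum-tree implementation; `sample_prioritized` is an illustrative name):

    using Random

    function sample_prioritized(rng::AbstractRNG, priorities::Vector{Float32}, n::Int)
        cum = cumsum(priorities)                  # cumulative priority mass
        inds = Vector{Int}(undef, n)
        ps = Vector{Float32}(undef, n)
        for i in 1:n
            r = rand(rng, Float32) * cum[end]     # uniform point in the total mass
            inds[i] = searchsortedfirst(cum, r)   # segment that contains r
            ps[i] = priorities[inds[i]]
        end
        inds, ps
    end

    inds, ps = sample_prioritized(MersenneTwister(1), Float32[0.1, 0.5, 0.2, 1.0], 3)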
src/algorithms/dqns/dqn.jl (47 changes: 35 additions & 12 deletions)
@@ -87,16 +87,20 @@ end
The state of the observation is assumed to have been stacked,
if `!isnothing(stack_size)`.
"""
(learner::DQNLearner)(env) =
env |>
function (learner::DQNLearner)(env)
probs = env |>
get_state |>
x ->
Flux.unsqueeze(x, ndims(x) + 1) |>
x ->
send_to_device(device(learner.approximator), x) |>
learner.approximator |>
send_to_host |>
Flux.squeezebatch
x -> Flux.unsqueeze(x, ndims(x) + 1) |>
x -> send_to_device(device(learner.approximator), x) |>
learner.approximator |>
vec |>
send_to_host

if ActionStyle(env) === FULL_ACTION_SET
probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
end
probs
end
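The `FULL_ACTION_SET` branch exists so that greedy action selection can never pick an illegal action. A worked sketch of the masking idea on plain vectors (all values made up):

    q = Float32[1.2, 0.4, 3.1, 0.7]                  # raw action values
    legal = Bool[true, false, true, false]           # legal_actions_mask
    masked_q = ifelse.(legal, q, typemin(Float32))   # illegal entries become -Inf32
    argmax(masked_q)                                 # 3, the best legal action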

function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
length(t[:terminal]) < learner.min_replay_history && return
@@ -124,10 +128,16 @@ function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory)
terminals = send_to_device(D, experience.terminals)
next_states = send_to_device(D, experience.next_states)

target_q = Qₜ(next_states)
if haskey(t, :next_legal_actions_mask)
target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, t[:next_legal_actions_mask]))
end

q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′

gs = gradient(params(Q)) do
q = Q(states)[actions]
q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1)
G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
loss = loss_func(G, q)
ignore() do
learner.loss = loss
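Note that `target_q`, `q′`, and `G` now live outside the `gradient` block: the TD target is a constant with respect to the online network's parameters, so nothing needs to be differentiated through the target network. A self-contained sketch of that structure (network sizes, names, and the `mse` loss are illustrative):

    using Flux

    Q  = Chain(Dense(4, 32, relu), Dense(32, 2))        # online network
    Qₜ = deepcopy(Q)                                    # target network
    s, s′ = rand(Float32, 4, 8), rand(Float32, 4, 8)
    a = CartesianIndex.(rand(1:2, 8), 1:8)
    r, d, γ = rand(Float32, 8), falses(8), 0.99f0

    q′ = dropdims(maximum(Qₜ(s′); dims = 1), dims = 1)  # computed outside the tape
    G = r .+ γ .* (1 .- d) .* q′

    gs = gradient(Flux.params(Q)) do
        q = Q(s)[a]                                     # chosen-action values
        Flux.mse(q, G)                                  # only this part is differentiated
    end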
@@ -147,9 +157,20 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner)
valid_ind_range =
isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
inds = rand(learner.rng, valid_ind_range, n)
next_inds = inds .+ h

states = consecutive_view(t[:state], inds; n_stack = s)
actions = consecutive_view(t[:action], inds)
next_states = consecutive_view(t[:state], inds .+ h; n_stack = s)
next_states = consecutive_view(t[:state], next_inds; n_stack = s)

if haskey(t, :legal_actions_mask)
legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], next_inds)
else
legal_actions_mask = nothing
next_legal_actions_mask = nothing
end

consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
rewards, terminals = zeros(Float32, n), fill(false, n)
@@ -167,9 +188,11 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner)
end
(
states = states,
legal_actions_mask = legal_actions_mask,
actions = actions,
rewards = rewards,
terminals = terminals,
next_states = next_states,
next_legal_actions_mask = next_legal_actions_mask,
)
end
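The loop that fills `rewards` and `terminals` (hidden between the two hunks above) folds each sampled window of `consecutive_rewards` and `consecutive_terminals` into an n-step return, with a terminal inside the horizon cutting off the remaining steps. A minimal sketch of that fold for a single window (the helper name is mine):

    function n_step_return(rewards::AbstractVector{<:Real}, terminals::AbstractVector{Bool}, γ::Real)
        G = 0.0
        for k in length(rewards):-1:1
            G = rewards[k] + γ * (1 - terminals[k]) * G   # a terminal discards later steps
        end
        Float32(G), any(terminals)
    end

    n_step_return([1.0, 1.0, 1.0], [false, true, false], 0.99)   # (1.99f0, true)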
src/algorithms/dqns/iqn.jl (17 changes: 13 additions & 4 deletions)
@@ -156,7 +156,11 @@ function (learner::IQNLearner)(env)
τ = rand(learner.device_rng, Float32, learner.K, 1)
τₑₘ = embed(τ, learner.Nₑₘ)
quantiles = learner.approximator(state, τₑₘ)
vec(mean(quantiles; dims = 2)) |> send_to_host
probs = vec(mean(quantiles; dims = 2)) |> send_to_host
if ActionStyle(env) === FULL_ACTION_SET
probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
end
probs
end

embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))
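`embed` maps each sampled quantile fraction τ ∈ (0, 1) to Nₑₘ cosine features cos(π·i·τ), i = 1, …, Nₑₘ, which `learner.approximator` consumes together with the state. A small worked example (values chosen for readability):

    embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :))

    τ = Float32[0.25, 0.5]
    embed(τ, 4)   # 4×2 matrix; column j is the embedding of τ[j]
    # the τ = 0.5 column is ≈ [0, -1, 0, 1]  (cos(π/2), cos(π), cos(3π/2), cos(2π))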
@@ -180,7 +184,13 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
τ′ = rand(learner.device_rng, Float32, N′, batch_size) # TODO: support β distribution
τₑₘ′ = embed(τ′, Nₑₘ)
zₜ = Zₜ(s′, τₑₘ′)
aₜ = argmax(mean(zₜ, dims = 2), dims = 1)
avg_zₜ = mean(zₜ, dims = 2)

if !isnothing(batch.next_legal_actions_mask)
avg_zₜ .+= typemin(eltype(avg_zₜ)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
end

aₜ = argmax(avg_zₜ, dims = 1)
aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0, 0:N′-1, 0)))
qₜ = reshape(zₜ[aₜ], :, batch_size)
target = reshape(r, 1, batch_size) .+ learner.γ * reshape(1 .- t, 1, batch_size) .* qₜ # reshape to allow broadcast
@@ -214,8 +224,7 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
huber_loss ./ κ
loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size)
loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities
loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size :
mean(loss_per_element)
loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size : mean(loss_per_element)
ignore() do
# @assert all(loss_per_element .>= 0)
is_use_PER && (
src/algorithms/dqns/prioritized_dqn.jl (33 changes: 21 additions & 12 deletions)
@@ -101,16 +101,20 @@ end
The state of the observation is assumed to have been stacked,
if `!isnothing(stack_size)`.
"""
(learner::PrioritizedDQNLearner)(env) =
env |>
function (learner::PrioritizedDQNLearner)(env)
probs = env |>
get_state |>
x ->
Flux.unsqueeze(x, ndims(x) + 1) |>
x ->
send_to_device(device(learner.approximator), x) |>
learner.approximator |>
send_to_host |>
Flux.squeezebatch
x -> Flux.unsqueeze(x, ndims(x) + 1) |>
x -> send_to_device(device(learner.approximator), x) |>
learner.approximator |>
vec |>
send_to_host

if ActionStyle(env) === FULL_ACTION_SET
probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
end
probs
end

function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
Q, Qₜ, γ, β, loss_func, update_horizon, batch_size = learner.approximator,
Expand All @@ -132,11 +136,16 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
weights ./= maximum(weights)
weights = send_to_device(D, weights)

target_q = Qₜ(next_states)
if !isnothing(batch.next_legal_actions_mask)
target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
end

q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′

gs = gradient(params(Q)) do
q = Q(states)[actions]
q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1)
G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′

batch_losses = loss_func(G, q)
loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
ignore() do
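The importance-sampling weights above follow directly from the stored priorities: transitions that are sampled more often (higher priority) get a smaller weight, and dividing by the maximum keeps every weight in (0, 1]. A worked example with made-up priorities and β = 0.5:

    priorities = Float32[0.1, 1.0, 10.0]
    β = 0.5f0
    weights = 1f0 ./ ((priorities .+ 1f-10) .^ β)   # ≈ [3.162, 1.0, 0.316]
    weights ./= maximum(weights)                    # ≈ [1.0, 0.316, 0.1]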
src/algorithms/dqns/rainbow.jl (71 changes: 36 additions & 35 deletions)
@@ -126,47 +126,43 @@ function (learner::RainbowLearner)(env)
state = Flux.unsqueeze(state, ndims(state) + 1)
logits = learner.approximator(state)
q = learner.support .* softmax(reshape(logits, :, learner.n_actions))
# probs = vec(sum(q, dims=1)) .+ legal_action
vec(sum(q, dims = 1)) |> send_to_host
probs = vec(sum(q, dims = 1)) |> send_to_host
if ActionStyle(env) === FULL_ACTION_SET
probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env))
end
probs
end
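In the forward pass above, each action's logits parameterise a categorical distribution over `learner.support`, and the action value is the expectation under that distribution. A small sketch with illustrative sizes (3 atoms, 2 actions):

    using Flux: softmax

    support = Float32[-1, 0, 1]                               # n_atoms = 3
    logits = randn(Float32, 3 * 2)                            # flat (n_atoms * n_actions) output
    probs_per_atom = softmax(reshape(logits, :, 2))           # each column sums to 1
    q_values = vec(sum(support .* probs_per_atom, dims = 1))  # expected value per action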

function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
Q,
Qₜ,
γ,
β,
loss_func,
n_atoms,
n_actions,
support,
delta_z,
update_horizon,
batch_size = learner.approximator,
learner.target_approximator,
learner.γ,
learner.β_priority,
learner.loss_func,
learner.n_atoms,
learner.n_actions,
learner.support,
learner.delta_z,
learner.update_horizon,
learner.batch_size

Q = learner.approximator
Qₜ = learner.target_approximator
γ = learner.γ
β = learner.β_priority
loss_func = learner.loss_func
n_atoms = learner.n_atoms
n_actions = learner.n_actions
support = learner.support
delta_z = learner.delta_z
update_horizon = learner.update_horizon
batch_size = learner.batch_size
D = device(Q)
states, rewards, terminals, next_states = map(
x -> send_to_device(D, x),
(batch.states, batch.rewards, batch.terminals, batch.next_states),
)
states = send_to_device(D, batch.states)
rewards = send_to_device(D, batch.rewards)
terminals = send_to_device(D, batch.terminals)
next_states = send_to_device(D, batch.next_states)

actions = CartesianIndex.(batch.actions, 1:batch_size)

target_support =
reshape(rewards, 1, :) .+
(reshape(support, :, 1) * reshape((γ^update_horizon) .* (1 .- terminals), 1, :))

next_logits = Qₜ(next_states)
next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)
next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
# next_q_argmax = argmax(cpu(next_q .+ next_legal_actions), dims=1)
if !isnothing(batch.next_legal_actions_mask)
next_q .+= typemin(eltype(next_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask))
end
next_prob_select = select_best_probs(next_probs, next_q)

target_distribution = project_distribution(
@@ -178,18 +174,23 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
learner.Vₘₐₓ,
)

updated_priorities = Vector{Float32}(undef, batch_size)
weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
weights ./= maximum(weights)
weights = send_to_device(D, weights)
is_use_PER = !isnothing(batch.priorities) # is use Prioritized Experience Replay
if is_use_PER
updated_priorities = Vector{Float32}(undef, batch_size)
weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
weights ./= maximum(weights)
weights = send_to_device(D, weights)
end

gs = gradient(Flux.params(Q)) do
logits = reshape(Q(states), n_atoms, n_actions, :)
select_logits = logits[:, actions]
batch_losses = loss_func(select_logits, target_distribution)
loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses)
ignore() do
updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
if is_use_PER
updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
end
learner.loss = loss
end
loss
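One detail worth noting in the update: `actions = CartesianIndex.(batch.actions, 1:batch_size)` builds one index per batch column, so a single indexing call such as `logits[:, actions]` gathers every element's chosen action. A tiny worked example:

    q = Float32[1 2 3; 4 5 6]              # (n_actions = 2, batch_size = 3)
    actions = [2, 1, 2]                    # chosen action per batch element
    q[CartesianIndex.(actions, 1:3)]       # Float32[4.0, 2.0, 6.0]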