From e437baef38f8f83e34dc51ce1a370b14681b9dcb Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 16 Aug 2020 18:22:05 +0800 Subject: [PATCH] support legal_actions_mask --- src/algorithms/dqns/common.jl | 25 +++++++-- src/algorithms/dqns/dqn.jl | 47 ++++++++++++----- src/algorithms/dqns/iqn.jl | 17 ++++-- src/algorithms/dqns/prioritized_dqn.jl | 33 +++++++----- src/algorithms/dqns/rainbow.jl | 71 +++++++++++++------------- src/algorithms/policy_gradient/A2C.jl | 28 +++++++--- src/algorithms/policy_gradient/ppo.jl | 23 +++++++-- 7 files changed, 166 insertions(+), 78 deletions(-) diff --git a/src/algorithms/dqns/common.jl b/src/algorithms/dqns/common.jl index f2d735b..75106ca 100644 --- a/src/algorithms/dqns/common.jl +++ b/src/algorithms/dqns/common.jl @@ -13,7 +13,7 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners) # 1. sample indices based on priority valid_ind_range = isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h)) - if t isa CircularCompactPSARTSATrajectory + if haskey(t, :priority) inds = Vector{Int}(undef, n) priorities = Vector{Float32}(undef, n) for i in 1:n @@ -29,10 +29,21 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners) priorities = nothing end + next_inds = inds .+ h + # 2. extract SARTS states = consecutive_view(t[:state], inds; n_stack = s) actions = consecutive_view(t[:action], inds) - next_states = consecutive_view(t[:state], inds .+ h; n_stack = s) + next_states = consecutive_view(t[:state], next_inds; n_stack = s) + + if haskey(t, :legal_actions_mask) + legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds) + next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], inds) + else + legal_actions_mask = nothing + next_legal_actions_mask = nothing + end + consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h) consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h) rewards, terminals = zeros(Float32, n), fill(false, n) @@ -48,10 +59,12 @@ function extract_experience(t::AbstractTrajectory, learner::PERLearners) inds, ( states = states, + legal_actions_mask = legal_actions_mask, actions = actions, rewards = rewards, terminals = terminals, next_states = next_states, + next_legal_actions_mask = next_legal_actions_mask, priorities = priorities, ) end @@ -70,7 +83,7 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory) inds, experience = extract_experience(t, p.learner) - if t isa CircularCompactPSARTSATrajectory + if haskey(t, :priority) priorities = update!(p.learner, experience) t[:priority][inds] .= priorities else @@ -78,7 +91,7 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory) end end -function (agent::Agent{<:QBasedPolicy{<:PERLearners},<:CircularCompactPSARTSATrajectory})( +function (agent::Agent{<:QBasedPolicy{<:PERLearners}})( ::RLCore.Training{PostActStage}, env, ) @@ -86,7 +99,9 @@ function (agent::Agent{<:QBasedPolicy{<:PERLearners},<:CircularCompactPSARTSATra agent.trajectory; reward = get_reward(env), terminal = get_terminal(env), - priority = agent.policy.learner.default_priority, ) + if haskey(agent.trajectory, :priority) + push!(agent.trajectory; priority = agent.policy.learner.default_priority) + end nothing end diff --git a/src/algorithms/dqns/dqn.jl b/src/algorithms/dqns/dqn.jl index eb5689c..1a24c44 100644 --- a/src/algorithms/dqns/dqn.jl +++ b/src/algorithms/dqns/dqn.jl @@ -87,16 +87,20 @@ end The state of the observation is assumed to have been stacked, if 
`!isnothing(stack_size)`. """ -(learner::DQNLearner)(env) = - env |> +function (learner::DQNLearner)(env) + probs = env |> get_state |> - x -> - Flux.unsqueeze(x, ndims(x) + 1) |> - x -> - send_to_device(device(learner.approximator), x) |> - learner.approximator |> - send_to_host |> - Flux.squeezebatch + x -> Flux.unsqueeze(x, ndims(x) + 1) |> + x -> send_to_device(device(learner.approximator), x) |> + learner.approximator |> + vec |> + send_to_host + + if ActionStyle(env) === FULL_ACTION_SET + probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env)) + end + probs +end function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory) length(t[:terminal]) < learner.min_replay_history && return @@ -124,10 +128,16 @@ function RLBase.update!(learner::DQNLearner, t::AbstractTrajectory) terminals = send_to_device(D, experience.terminals) next_states = send_to_device(D, experience.next_states) + target_q = Qₜ(next_states) + if haskey(t, :next_legal_actions_mask) + target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, t[:next_legal_actions_mask])) + end + + q′ = dropdims(maximum(target_q; dims = 1), dims = 1) + G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′ + gs = gradient(params(Q)) do q = Q(states)[actions] - q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1) - G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′ loss = loss_func(G, q) ignore() do learner.loss = loss @@ -147,9 +157,20 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner) valid_ind_range = isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h)) inds = rand(learner.rng, valid_ind_range, n) + next_inds = inds .+ h + states = consecutive_view(t[:state], inds; n_stack = s) actions = consecutive_view(t[:action], inds) - next_states = consecutive_view(t[:state], inds .+ h; n_stack = s) + next_states = consecutive_view(t[:state], next_inds; n_stack = s) + + if haskey(t, :legal_actions_mask) + legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds) + next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], next_inds) + else + legal_actions_mask = nothing + next_legal_actions_mask = nothing + end + consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h) consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h) rewards, terminals = zeros(Float32, n), fill(false, n) @@ -167,9 +188,11 @@ function extract_experience(t::AbstractTrajectory, learner::DQNLearner) end ( states = states, + legal_actions_mask = legal_actions_mask, actions = actions, rewards = rewards, terminals = terminals, next_states = next_states, + next_legal_actions_mask = next_legal_actions_mask, ) end diff --git a/src/algorithms/dqns/iqn.jl b/src/algorithms/dqns/iqn.jl index a0816a0..7098bf1 100644 --- a/src/algorithms/dqns/iqn.jl +++ b/src/algorithms/dqns/iqn.jl @@ -156,7 +156,11 @@ function (learner::IQNLearner)(env) τ = rand(learner.device_rng, Float32, learner.K, 1) τₑₘ = embed(τ, learner.Nₑₘ) quantiles = learner.approximator(state, τₑₘ) - vec(mean(quantiles; dims = 2)) |> send_to_host + probs = vec(mean(quantiles; dims = 2)) |> send_to_host + if ActionStyle(env) === FULL_ACTION_SET + probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env)) + end + probs end embed(x, Nₑₘ) = cos.(Float32(π) .* (1:Nₑₘ) .* reshape(x, 1, :)) @@ -180,7 +184,13 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple) τ′ = rand(learner.device_rng, Float32, N′, batch_size) # TODO: support β distribution τₑₘ′ = embed(τ′, 
Nₑₘ) zₜ = Zₜ(s′, τₑₘ′) - aₜ = argmax(mean(zₜ, dims = 2), dims = 1) + avg_zₜ = mean(zₜ, dims = 2) + + if !isnothing(batch.next_legal_actions_mask) + avg_zₜ .+= typemin(eltype(avg_zₜ)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask)) + end + + aₜ = argmax(avg_zₜ, dims = 1) aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0, 0:N′-1, 0))) qₜ = reshape(zₜ[aₜ], :, batch_size) target = reshape(r, 1, batch_size) .+ learner.γ * reshape(1 .- t, 1, batch_size) .* qₜ # reshape to allow broadcast @@ -214,8 +224,7 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple) huber_loss ./ κ loss_per_quantile = reshape(sum(raw_loss; dims = 1), N, batch_size) loss_per_element = mean(loss_per_quantile; dims = 1) # use as priorities - loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size : - mean(loss_per_element) + loss = is_use_PER ? dot(vec(weights), vec(loss_per_element)) * 1 // batch_size : mean(loss_per_element) ignore() do # @assert all(loss_per_element .>= 0) is_use_PER && ( diff --git a/src/algorithms/dqns/prioritized_dqn.jl b/src/algorithms/dqns/prioritized_dqn.jl index 546f210..36f04aa 100644 --- a/src/algorithms/dqns/prioritized_dqn.jl +++ b/src/algorithms/dqns/prioritized_dqn.jl @@ -101,16 +101,20 @@ end The state of the observation is assumed to have been stacked, if `!isnothing(stack_size)`. """ -(learner::PrioritizedDQNLearner)(env) = - env |> +function (learner::PrioritizedDQNLearner)(env) + probs = env |> get_state |> - x -> - Flux.unsqueeze(x, ndims(x) + 1) |> - x -> - send_to_device(device(learner.approximator), x) |> - learner.approximator |> - send_to_host |> - Flux.squeezebatch + x -> Flux.unsqueeze(x, ndims(x) + 1) |> + x -> send_to_device(device(learner.approximator), x) |> + learner.approximator |> + vec |> + send_to_host + + if ActionStyle(env) === FULL_ACTION_SET + probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env)) + end + probs +end function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple) Q, Qₜ, γ, β, loss_func, update_horizon, batch_size = learner.approximator, @@ -132,11 +136,16 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple) weights ./= maximum(weights) weights = send_to_device(D, weights) + target_q = Qₜ(next_states) + if !isnothing(batch.next_legal_actions_mask) + target_q .+= typemin(eltype(target_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask)) + end + + q′ = dropdims(maximum(target_q; dims = 1), dims = 1) + G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′ + gs = gradient(params(Q)) do q = Q(states)[actions] - q′ = dropdims(maximum(Qₜ(next_states); dims = 1), dims = 1) - G = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′ - batch_losses = loss_func(G, q) loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size ignore() do diff --git a/src/algorithms/dqns/rainbow.jl b/src/algorithms/dqns/rainbow.jl index d009d41..72b5385 100644 --- a/src/algorithms/dqns/rainbow.jl +++ b/src/algorithms/dqns/rainbow.jl @@ -126,39 +126,33 @@ function (learner::RainbowLearner)(env) state = Flux.unsqueeze(state, ndims(state) + 1) logits = learner.approximator(state) q = learner.support .* softmax(reshape(logits, :, learner.n_actions)) - # probs = vec(sum(q, dims=1)) .+ legal_action - vec(sum(q, dims = 1)) |> send_to_host + probs = vec(sum(q, dims = 1)) |> send_to_host + if ActionStyle(env) === FULL_ACTION_SET + probs .+= typemin(eltype(probs)) .* (1 .- get_legal_actions_mask(env)) + end + probs end function RLBase.update!(learner::RainbowLearner, 
batch::NamedTuple) - Q, - Qₜ, - γ, - β, - loss_func, - n_atoms, - n_actions, - support, - delta_z, - update_horizon, - batch_size = learner.approximator, - learner.target_approximator, - learner.γ, - learner.β_priority, - learner.loss_func, - learner.n_atoms, - learner.n_actions, - learner.support, - learner.delta_z, - learner.update_horizon, - learner.batch_size - + Q = learner.approximator + Qₜ = learner.target_approximator + γ = learner.γ + β = learner.β_priority + loss_func = learner.loss_func + n_atoms = learner.n_atoms + n_actions = learner.n_actions + support = learner.support + delta_z = learner.delta_z + update_horizon = learner.update_horizon + batch_size = learner.batch_size D = device(Q) - states, rewards, terminals, next_states = map( - x -> send_to_device(D, x), - (batch.states, batch.rewards, batch.terminals, batch.next_states), - ) + states = send_to_device(D, batch.states) + rewards = send_to_device(D, batch.rewards) + terminals = send_to_device(D, batch.terminals) + next_states = send_to_device(D, batch.next_states) + actions = CartesianIndex.(batch.actions, 1:batch_size) + target_support = reshape(rewards, 1, :) .+ (reshape(support, :, 1) * reshape((γ^update_horizon) .* (1 .- terminals), 1, :)) @@ -166,7 +160,9 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple) next_logits = Qₜ(next_states) next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :) next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :) - # next_q_argmax = argmax(cpu(next_q .+ next_legal_actions), dims=1) + if !isnothing(batch.next_legal_actions_mask) + next_q .+= typemin(eltype(next_q)) .* (1 .- send_to_device(D, batch.next_legal_actions_mask)) + end next_prob_select = select_best_probs(next_probs, next_q) target_distribution = project_distribution( @@ -178,18 +174,23 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple) learner.Vₘₐₓ, ) - updated_priorities = Vector{Float32}(undef, batch_size) - weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β) - weights ./= maximum(weights) - weights = send_to_device(D, weights) + is_use_PER = !isnothing(batch.priorities) # is use Prioritized Experience Replay + if is_use_PER + updated_priorities = Vector{Float32}(undef, batch_size) + weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β) + weights ./= maximum(weights) + weights = send_to_device(D, weights) + end gs = gradient(Flux.params(Q)) do logits = reshape(Q(states), n_atoms, n_actions, :) select_logits = logits[:, actions] batch_losses = loss_func(select_logits, target_distribution) - loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size + loss = is_use_PER ? 
dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses) ignore() do - updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β)) + if is_use_PER + updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β)) + end learner.loss = loss end loss diff --git a/src/algorithms/policy_gradient/A2C.jl b/src/algorithms/policy_gradient/A2C.jl index 8405a0c..8fa1d8b 100644 --- a/src/algorithms/policy_gradient/A2C.jl +++ b/src/algorithms/policy_gradient/A2C.jl @@ -27,17 +27,28 @@ Base.@kwdef mutable struct A2CLearner{A<:ActorCritic} <: AbstractLearner loss::Float32 = 0.f0 end -(learner::A2CLearner)(env::MultiThreadEnv) = - learner.approximator.actor(send_to_device( +function (learner::A2CLearner)(env::MultiThreadEnv) + logits = learner.approximator.actor(send_to_device( device(learner.approximator), get_state(env), )) |> send_to_host + if ActionStyle(env[1]) === FULL_ACTION_SET + logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env)) + end + logits +end + function (learner::A2CLearner)(env) s = get_state(env) s = Flux.unsqueeze(s, ndims(s) + 1) s = send_to_device(device(learner.approximator), s) - learner.approximator.actor(s) |> vec |> send_to_host + logits = learner.approximator.actor(s) |> vec |> send_to_host + + if ActionStyle(env) === FULL_ACTION_SET + logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env)) + end + logits end function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory) @@ -54,9 +65,9 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory) w₁ = learner.actor_loss_weight w₂ = learner.critic_loss_weight w₃ = learner.entropy_loss_weight - - states = send_to_device(device(AC), states) - next_state = send_to_device(device(AC), next_state) + D = device(AC) + states = send_to_device(D, states) + next_state = send_to_device(D, next_state) states_flattened = flatten_batch(states) # (state_size..., n_thread * update_step) actions = flatten_batch(actions) @@ -70,11 +81,14 @@ function RLBase.update!(learner::A2CLearner, t::AbstractTrajectory) init = send_to_host(next_state_values), terminal = terminals, ) - gains = send_to_device(device(AC), gains) + gains = send_to_device(D, gains) ps = Flux.params(AC) gs = gradient(ps) do logits = AC.actor(states_flattened) + if haskey(t, :legal_actions_mask) + logits .+= typemin(eltype(logits)) .* (1 .- flatten_batch(send_to_device(D, t[:legal_actions_mask]))) + end probs = softmax(logits) log_probs = logsoftmax(logits) log_probs_select = log_probs[actions] diff --git a/src/algorithms/policy_gradient/ppo.jl b/src/algorithms/policy_gradient/ppo.jl index 40db89f..980498a 100644 --- a/src/algorithms/policy_gradient/ppo.jl +++ b/src/algorithms/policy_gradient/ppo.jl @@ -74,17 +74,28 @@ function PPOLearner(; ) end -(learner::PPOLearner)(env::MultiThreadEnv) = - learner.approximator.actor(send_to_device( +function (learner::PPOLearner)(env::MultiThreadEnv) + logits = learner.approximator.actor(send_to_device( device(learner.approximator), get_state(env), )) |> send_to_host + if ActionStyle(env[1]) === FULL_ACTION_SET + logits .+= typemin(eltype(logits)) .* (1 .- get_legal_actions_mask(env)) + end + logits +end + function (learner::PPOLearner)(env) s = get_state(env) s = Flux.unsqueeze(s, ndims(s) + 1) s = send_to_device(device(learner.approximator), s) - learner.approximator.actor(s) |> vec |> send_to_host + logits = learner.approximator.actor(s) |> vec |> send_to_host + + if ActionStyle(env) === FULL_ACTION_SET + logits .+= typemin(eltype(logits)) .* (1 .- 
get_legal_actions_mask(env)) + end + logits end function RLBase.update!(learner::PPOLearner, t::PPOTrajectory) @@ -127,6 +138,9 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory) for i in 1:n_microbatches inds = rand_inds[(i-1)*microbatch_size+1:i*microbatch_size] s = send_to_device(D, select_last_dim(states_flatten, inds)) + if haskey(t, :legal_actions_mask) + lam = send_to_device(D, select_last_dim(flatten_batch(t[:legal_actions_mask]), inds)) + end a = vec(actions)[inds] r = send_to_device(D, vec(returns)[inds]) log_p = send_to_device(D, vec(action_log_probs)[inds]) @@ -136,6 +150,9 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory) gs = gradient(ps) do v′ = AC.critic(s) |> vec logit′ = AC.actor(s) + if haskey(t, :legal_actions_mask) + logit′ .+= typemin(eltype(logit′)) .* (1 .- lam) + end p′ = softmax(logit′) log_p′ = logsoftmax(logit′) log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
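
# A minimal standalone sketch of the masking idea applied in every learner in
# this patch, on toy values. The names below (`q`, `mask`, `masked_q`,
# `best_action`) are illustrative assumptions, not identifiers from the
# repository. Illegal actions receive a -Inf offset so they can never win an
# argmax and get zero probability under softmax.
using Flux: softmax

q    = Float32[0.3, 1.2, -0.5, 0.8]   # hypothetical Q-values / logits for 4 actions
mask = Bool[true, false, true, true]  # legal_actions_mask: action 2 is illegal

# `ifelse.` on a Bool mask leaves legal entries untouched and avoids any
# `-Inf * 0` arithmetic (which would evaluate to NaN).
masked_q = q .+ ifelse.(mask, 0.0f0, typemin(Float32))

best_action = argmax(masked_q)        # == 4; an illegal action is never picked
probs = softmax(masked_q)             # probs[2] == 0, since exp(-Inf) == 0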
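
# A sketch of the masked bootstrap target used in the DQN-style updates above,
# on standalone toy arrays; in the patch, `target_q` comes from the target
# network `Qₜ` and the other arrays from `batch`. The mask is applied before
# `maximum`, so the target can only bootstrap from a legal next action.
γ = 0.99f0
update_horizon = 1

target_q  = Float32[1.0 0.2; 3.0 0.7]     # n_actions × batch_size
next_mask = Bool[true false; false true]  # legal next actions per sample
rewards   = Float32[0.0, 1.0]
terminals = Float32[0.0, 1.0]

target_q .+= ifelse.(next_mask, 0.0f0, typemin(Float32))
q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
G  = rewards .+ γ^update_horizon .* (1 .- terminals) .* q′
# G == Float32[0.99, 1.0]: sample 1 bootstraps from its only legal action,
# sample 2 is terminal, so the bootstrap term vanishes.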
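
# A sketch of the optional prioritized-replay branch this patch introduces in
# the Rainbow/IQN updates: when `priorities` is `nothing` the loss falls back
# to a plain mean, otherwise importance-sampling weights are applied. Toy,
# standalone values; in the patch `priorities` comes from `batch` and `β`
# from the learner.
using LinearAlgebra: dot
using Statistics: mean

batch_losses = Float32[0.4, 0.1, 0.9, 0.2]  # per-sample losses
priorities   = Float32[0.5, 0.2, 1.0, 0.3]  # set to `nothing` to disable PER
β            = 0.5f0
batch_size   = length(batch_losses)

is_use_PER = !isnothing(priorities)
loss = if is_use_PER
    weights = 1f0 ./ ((priorities .+ 1f-10) .^ β)
    weights ./= maximum(weights)                 # normalise as in the patch
    dot(weights, batch_losses) * 1 // batch_size
else
    mean(batch_losses)
end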