From 92a6103676653f9297a596be6edfa342701881c5 Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 9 Aug 2021 16:28:56 +0800 Subject: [PATCH 1/4] add maddpg --- .../src/algorithms/policy_gradient/maddpg.jl | 171 ++++++++++++++++++ .../policy_gradient/policy_gradient.jl | 1 + 2 files changed, 172 insertions(+) create mode 100644 src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl new file mode 100644 index 000000000..3d4060dcc --- /dev/null +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl @@ -0,0 +1,171 @@ +export MADDPGManager + +""" + MADDPGManager(; agents::Dict{<:Any, <:Agent}, args...) +Multi-agent Deep Deterministic Policy Gradient(MADDPG) implemented in Julia. Here only works for simultaneous games whose action space is discrete. +See the paper https://arxiv.org/abs/1706.02275 for more details. + +# Keyword arguments +- `agents::Dict{<:Any, <:Agent{<:DDPGPolicy, <:AbstractTrajectory}}`, here each agent collects its own information. While updating the policy, each `critic` will assemble all agents' trajectory to update its own network. +- `batch_size::Int` +- `update_freq::Int` +- `update_step::Int`, count the step. +- `rng::AbstractRNG`. +""" +mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory} <: AbstractPolicy + agents::Dict{<:Any, <:Agent{<:P, <:T}} + batch_size::Int + update_freq::Int + update_step::Int + rng::AbstractRNG +end + +# for simultaneous game with a discrete action space. +function (π::MADDPGManager)(env::AbstractEnv) + while current_player(env) == chance_player(env) + env |> legal_action_space |> rand |> env + end + Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents) +end + +function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv) + for (_, agent) in π.agents + if length(agent.trajectory) > 0 + pop!(agent.trajectory[:state]) + pop!(agent.trajectory[:action]) + if haskey(agent.trajectory, :legal_actions_mask) + pop!(agent.trajectory[:legal_actions_mask]) + end + end + end +end + +function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions) + # update each agent's trajectory + for (player, agent) in π.agents + push!(agent.trajectory[:state], state(env, player)) + push!(agent.trajectory[:action], actions[player]) + if haskey(agent.trajectory, :legal_actions_mask) + lasm = legal_action_space_mask(env, player) + push!(agent.trajectory[:legal_actions_mask], lasm) + end + end + + # update policy + update!(π) +end + +function (π::MADDPGManager)(::PostActStage, env::AbstractEnv) + for (player, agent) in π.agents + push!(agent.trajectory[:reward], reward(env, player)) + push!(agent.trajectory[:terminal], is_terminated(env)) + end +end + +function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv) + # collect state and dummy action to each agent's trajectory + for (player, agent) in π.agents + push!(agent.trajectory[:state], state(env, player)) + push!(agent.trajectory[:action], rand(action_space(env))) + if haskey(agent.trajectory, :legal_actions_mask) + lasm = legal_action_space_mask(env, player) + push!(agent.trajectory[:legal_actions_mask], lasm) + end + end + + # update policy + update!(π) +end + +# update policy +function RLBase.update!(π::MADDPGManager) + π.update_step += 1 + π.update_step % π.update_freq == 0 || return + + for (_, agent) in π.agents + length(agent.trajectory) > agent.policy.update_after || return + 
length(agent.trajectory) > π.batch_size || return + end + + # get trainning data + temp_player = rand(keys(π.agents)) + t = π.agents[temp_player].trajectory + inds = rand(π.rng, 1:length(t), π.batch_size) + batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds)) + for (player, agent) in π.agents) + + # get s, a, s′ for critic + s = vcat((batches[player][1] for (player, _) in π.agents)...) + a = vcat((batches[player][2] for (player, _) in π.agents)...) + s′ = vcat((batches[player][5] for (player, _) in π.agents)...) + + # for training behavior_actor + mu_actions = vcat( + (( + batches[player][1] |> # get personal state information + x -> send_to_device(device(agent.policy.behavior_actor), x) |> + agent.policy.behavior_actor |> send_to_host + ) for (player, agent) in π.agents)... + ) + # for training behavior_critic + new_actions = vcat( + (( + batches[player][5] |> # batch[5] get new_state information + x -> send_to_device(device(agent.policy.target_actor), x) |> + agent.policy.target_actor |> send_to_host + ) for (player, agent) in π.agents)... + ) + + for (player, agent) in π.agents + p = agent.policy + A = p.behavior_actor + C = p.behavior_critic + Aₜ = p.target_actor + Cₜ = p.target_critic + + γ = p.γ + ρ = p.ρ + + _device(x) = send_to_device(device(A), x) + + # Note that here default A, C, Aₜ, Cₜ on the same device. + s, a, s′ = _device((s, a, s′)) + mu_actions = _device(mu_actions) + new_actions = _device(new_actions) + r = _device(batches[player][:reward]) + t = _device(batches[player][:terminal]) + + qₜ = Cₜ(vcat(s′, new_actions)) |> vec + y = r .+ γ .* (1 .- t) .* qₜ + + gs1 = gradient(Flux.params(C)) do + q = C(vcat(s, a)) |> vec + loss = mean((y .- q) .^ 2) + ignore() do + p.critic_loss = loss + end + loss + end + + update!(C, gs1) + + gs2 = gradient(Flux.params(A)) do + loss = -mean(C(vcat(s, mu_actions))) + ignore() do + p.actor_loss = loss + end + loss + end + + update!(A, gs2) + + # polyak averaging + for (dest, src) in zip(Flux.params([Aₜ, Cₜ]), Flux.params([A, C])) + dest .= ρ .* dest .+ (1 - ρ) .* src + end + + s, a, s′ = send_to_host((s, a, s′)) + mu_actions = send_to_host(mu_actions) + new_actions = send_to_host(new_actions) + end +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl index c7b988c87..c0201cb47 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/policy_gradient.jl @@ -7,3 +7,4 @@ include("MAC.jl") include("ddpg.jl") include("td3.jl") include("sac.jl") +include("maddpg.jl") From 0b724ea85657e94ea823116ea305282d9e5dbf1e Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 9 Aug 2021 16:29:13 +0800 Subject: [PATCH 2/4] add experiment --- .../JuliaRL_MADDPG_KuhnPoker.jl | 122 ++++++++++++++++++ .../experiments/Policy Gradient/config.json | 1 + 2 files changed, 123 insertions(+) create mode 100644 docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl new file mode 100644 index 000000000..00b0df196 --- /dev/null +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl @@ -0,0 +1,122 @@ +# --- +# title: JuliaRL\_MADDPG\_KuhnPoker +# cover: assets/JuliaRL_MADDPG_KuhnPoker.png +# description: MADDPG 
applied to KuhnPoker +# date: 2021-08-09 +# author: "[Peter Chen](https://github.com/peterchen96)" +# --- + +#+ tangle=true +using ReinforcementLearning +using StableRNGs +using Flux +using IntervalSets + +mutable struct ResultNEpisode <: AbstractHook + eval_freq::Int + episode_counter::Int + episode::Vector{Int} + results::Vector{Float64} +end + +function (hook::ResultNEpisode)(::PostEpisodeStage, policy, env) + hook.episode_counter += 1 + if hook.episode_counter % hook.eval_freq == 0 + push!(hook.episode, hook.episode_counter) + push!(hook.results, reward(env, 1)) + end +end + +function RL.Experiment( + ::Val{:JuliaRL}, + ::Val{:MADDPG}, + ::Val{:KuhnPoker}, + ::Nothing; + seed=123, +) + rng = StableRNG(seed) + env = KuhnPokerEnv() + wrapped_env = ActionTransformedEnv( + StateTransformedEnv( + env; + state_mapping = s -> [findfirst(==(s), state_space(env))], + state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)] + ), + ## add a dummy action for the other agent. + action_mapping = x -> length(x) == 1 ? x : Int(x[current_player(env)] + 1), + ) + ns, na = 1, 1 + n_players = 2 + + init = glorot_uniform(rng) + + create_actor() = Chain( + Dense(ns, 64, relu; init = init), + Dense(64, 64, relu; init = init), + Dense(64, na, tanh; init = init), + ) + + create_critic() = Chain( + Dense(n_players * ns + n_players * na, 64, relu; init = init), + Dense(64, 64, relu; init = init), + Dense(64, 1; init = init), + ) + + agent = Agent( + policy = DDPGPolicy( + behavior_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + behavior_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + target_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + target_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + γ = 0.99f0, + ρ = 0.995f0, + na = na, + start_steps = 1000, + start_policy = RandomPolicy(-0.9..0.9; rng = rng), + update_after = 1000, + act_limit = 0.9, + act_noise = 0.1, + rng = rng, + ), + trajectory = CircularArraySARTTrajectory( + capacity = 10000, # replay buffer capacity + state = Vector{Int} => (ns, ), + action = Float32 => (na, ), + ), + ) + + agents = MADDPGManager( + Dict((player, deepcopy(agent)) + for player in players(env) if player != chance_player(env)), + 128, # batch_size + 128, # update_freq + 0, # step_counter + rng + ) + + stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI")) + hook = ResultNEpisode(1000, 0, [], []) + Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv") +end + +#+ tangle=false +using Plots +ex = E`JuliaRL_MADDPG_KuhnPoker` +run(ex) +scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1") + +savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide + +# ![](assets/JuliaRL_MADDPG_KuhnPoker.png) \ No newline at end of file diff --git a/docs/experiments/experiments/Policy Gradient/config.json b/docs/experiments/experiments/Policy Gradient/config.json index db4b2a4fc..03b2c5be5 100644 --- a/docs/experiments/experiments/Policy Gradient/config.json +++ b/docs/experiments/experiments/Policy Gradient/config.json @@ -4,6 +4,7 @@ "JuliaRL_A2C_CartPole.jl", "JuliaRL_A2CGAE_CartPole.jl", "JuliaRL_DDPG_Pendulum.jl", + "JuliaRL_MADDPG_KuhnPoker.jl", "JuliaRL_MAC_CartPole.jl", "JuliaRL_PPO_CartPole.jl", "JuliaRL_PPO_Pendulum.jl", From d5ed9baf74b5c97f0c75e4b24761567c938580ac Mon Sep 17 00:00:00 2001 
From: peter Date: Mon, 9 Aug 2021 19:35:54 +0800 Subject: [PATCH 3/4] update cspell.json --- .cspell/cspell.json | 5 +++-- .../experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl | 4 ++-- .../src/algorithms/policy_gradient/maddpg.jl | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.cspell/cspell.json b/.cspell/cspell.json index b1606e34d..fa2572634 100644 --- a/.cspell/cspell.json +++ b/.cspell/cspell.json @@ -120,7 +120,8 @@ "Norouzi", "gzopen", "turbulences", - "Decompressor" + "Decompressor", + "MADDPG" ], "ignoreWords": [], "minWordLength": 5, @@ -143,4 +144,4 @@ "\\{%.*%\\}", // liquid syntax "/^\\s*```[\\s\\S]*?^\\s*```/gm" // Another attempt at markdown code blocks. https://github.com/streetsidesoftware/vscode-spell-checker/issues/202#issuecomment-377477473 ] -} \ No newline at end of file +} diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl index 00b0df196..171dd500b 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl @@ -106,7 +106,7 @@ function RL.Experiment( rng ) - stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(10_000, is_show_progress=!haskey(ENV, "CI")) hook = ResultNEpisode(1000, 0, [], []) Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv") end @@ -119,4 +119,4 @@ scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel=" savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide -# ![](assets/JuliaRL_MADDPG_KuhnPoker.png) \ No newline at end of file +# ![](assets/JuliaRL_MADDPG_KuhnPoker.png) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl index 3d4060dcc..21776b46a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl @@ -87,7 +87,7 @@ function RLBase.update!(π::MADDPGManager) length(agent.trajectory) > π.batch_size || return end - # get trainning data + # get training data temp_player = rand(keys(π.agents)) t = π.agents[temp_player].trajectory inds = rand(π.rng, 1:length(t), π.batch_size) From 2fc2ee061bf78bd9ed2e9a23ec7e1270abb2ba84 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Aug 2021 01:24:59 +0800 Subject: [PATCH 4/4] update the algo --- .../JuliaRL_MADDPG_KuhnPoker.jl | 69 +++++++-------- .../src/algorithms/policy_gradient/maddpg.jl | 84 +++++++------------ 2 files changed, 64 insertions(+), 89 deletions(-) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl index 171dd500b..9700c5057 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl @@ -62,51 +62,52 @@ function RL.Experiment( Dense(64, 1; init = init), ) - agent = Agent( - policy = DDPGPolicy( - behavior_actor = NeuralNetworkApproximator( - model = create_actor(), - optimizer = ADAM(), - ), - behavior_critic = NeuralNetworkApproximator( - model = create_critic(), - optimizer = ADAM(), - ), - target_actor = NeuralNetworkApproximator( - model = create_actor(), - optimizer = ADAM(), - ), - target_critic = 
NeuralNetworkApproximator( - model = create_critic(), - optimizer = ADAM(), - ), - γ = 0.99f0, - ρ = 0.995f0, - na = na, - start_steps = 1000, - start_policy = RandomPolicy(-0.9..0.9; rng = rng), - update_after = 1000, - act_limit = 0.9, - act_noise = 0.1, - rng = rng, + + policy = DDPGPolicy( + behavior_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), ), - trajectory = CircularArraySARTTrajectory( - capacity = 10000, # replay buffer capacity - state = Vector{Int} => (ns, ), - action = Float32 => (na, ), + behavior_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), ), + target_actor = NeuralNetworkApproximator( + model = create_actor(), + optimizer = ADAM(), + ), + target_critic = NeuralNetworkApproximator( + model = create_critic(), + optimizer = ADAM(), + ), + γ = 0.99f0, + ρ = 0.995f0, + na = na, + start_steps = 1000, + start_policy = RandomPolicy(-0.9..0.9; rng = rng), + update_after = 1000, + act_limit = 0.9, + act_noise = 0.1, + rng = rng, + ) + trajectory = CircularArraySARTTrajectory( + capacity = 10000, # replay buffer capacity + state = Vector{Int} => (ns, ), + action = Float32 => (na, ), ) agents = MADDPGManager( - Dict((player, deepcopy(agent)) - for player in players(env) if player != chance_player(env)), + Dict((player, Agent( + policy = NamedPolicy(player, deepcopy(policy)), + trajectory = deepcopy(trajectory), + )) for player in players(env) if player != chance_player(env)), 128, # batch_size 128, # update_freq 0, # step_counter rng ) - stop_condition = StopAfterEpisode(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI")) hook = ResultNEpisode(1000, 0, [], []) Experiment(agents, wrapped_env, stop_condition, hook, "# run MADDPG on KuhnPokerEnv") end diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl index 21776b46a..84a823a14 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl @@ -6,14 +6,14 @@ Multi-agent Deep Deterministic Policy Gradient(MADDPG) implemented in Julia. Her See the paper https://arxiv.org/abs/1706.02275 for more details. # Keyword arguments -- `agents::Dict{<:Any, <:Agent{<:DDPGPolicy, <:AbstractTrajectory}}`, here each agent collects its own information. While updating the policy, each `critic` will assemble all agents' trajectory to update its own network. +- `agents::Dict{<:Any, <:NamedPolicy{<:Agent{<:DDPGPolicy, <:AbstractTrajectory}, <:Any}}`, here each agent collects its own information. While updating the policy, each `critic` will assemble all agents' trajectory to update its own network. - `batch_size::Int` - `update_freq::Int` - `update_step::Int`, count the step. - `rng::AbstractRNG`. 
""" -mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory} <: AbstractPolicy - agents::Dict{<:Any, <:Agent{<:P, <:T}} +mutable struct MADDPGManager{P<:DDPGPolicy, T<:AbstractTrajectory, N<:Any} <: AbstractPolicy + agents::Dict{<:N, <:Agent{<:NamedPolicy{<:P, <:N}, <:T}} batch_size::Int update_freq::Int update_step::Int @@ -28,49 +28,27 @@ function (π::MADDPGManager)(env::AbstractEnv) Dict((player, ceil(agent.policy(env))) for (player, agent) in π.agents) end -function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv) +function (π::MADDPGManager)(stage::Union{PreEpisodeStage, PostActStage}, env::AbstractEnv) + # only need to update trajectory. for (_, agent) in π.agents - if length(agent.trajectory) > 0 - pop!(agent.trajectory[:state]) - pop!(agent.trajectory[:action]) - if haskey(agent.trajectory, :legal_actions_mask) - pop!(agent.trajectory[:legal_actions_mask]) - end - end + update!(agent.trajectory, agent.policy, env, stage) end end -function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions) - # update each agent's trajectory +function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions) + # update each agent's trajectory. for (player, agent) in π.agents - push!(agent.trajectory[:state], state(env, player)) - push!(agent.trajectory[:action], actions[player]) - if haskey(agent.trajectory, :legal_actions_mask) - lasm = legal_action_space_mask(env, player) - push!(agent.trajectory[:legal_actions_mask], lasm) - end + update!(agent.trajectory, agent.policy, env, stage, actions[player]) end # update policy update!(π) end -function (π::MADDPGManager)(::PostActStage, env::AbstractEnv) - for (player, agent) in π.agents - push!(agent.trajectory[:reward], reward(env, player)) - push!(agent.trajectory[:terminal], is_terminated(env)) - end -end - -function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv) - # collect state and dummy action to each agent's trajectory - for (player, agent) in π.agents - push!(agent.trajectory[:state], state(env, player)) - push!(agent.trajectory[:action], rand(action_space(env))) - if haskey(agent.trajectory, :legal_actions_mask) - lasm = legal_action_space_mask(env, player) - push!(agent.trajectory[:legal_actions_mask], lasm) - end +function (π::MADDPGManager)(stage::PostEpisodeStage, env::AbstractEnv) + # collect state and a dummy action to each agent's trajectory here. + for (_, agent) in π.agents + update!(agent.trajectory, agent.policy, env, stage) end # update policy @@ -83,41 +61,41 @@ function RLBase.update!(π::MADDPGManager) π.update_step % π.update_freq == 0 || return for (_, agent) in π.agents - length(agent.trajectory) > agent.policy.update_after || return + length(agent.trajectory) > agent.policy.policy.update_after || return length(agent.trajectory) > π.batch_size || return end # get training data - temp_player = rand(keys(π.agents)) + temp_player = collect(keys(π.agents))[1] t = π.agents[temp_player].trajectory inds = rand(π.rng, 1:length(t), π.batch_size) batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds)) for (player, agent) in π.agents) # get s, a, s′ for critic - s = vcat((batches[player][1] for (player, _) in π.agents)...) - a = vcat((batches[player][2] for (player, _) in π.agents)...) - s′ = vcat((batches[player][5] for (player, _) in π.agents)...) 
+ s = Flux.stack((batches[player][:state] for (player, _) in π.agents), 1) + a = Flux.stack((batches[player][:action] for (player, _) in π.agents), 1) + s′ = Flux.stack((batches[player][:next_state] for (player, _) in π.agents), 1) # for training behavior_actor - mu_actions = vcat( + mu_actions = Flux.stack( (( - batches[player][1] |> # get personal state information - x -> send_to_device(device(agent.policy.behavior_actor), x) |> - agent.policy.behavior_actor |> send_to_host - ) for (player, agent) in π.agents)... + batches[player][:state] |> # get personal state information + x -> send_to_device(device(agent.policy.policy.behavior_actor), x) |> + agent.policy.policy.behavior_actor |> send_to_host + ) for (player, agent) in π.agents), 1 ) # for training behavior_critic - new_actions = vcat( + new_actions = Flux.stack( (( - batches[player][5] |> # batch[5] get new_state information - x -> send_to_device(device(agent.policy.target_actor), x) |> - agent.policy.target_actor |> send_to_host - ) for (player, agent) in π.agents)... + batches[player][:next_state] |> # get personal next_state information + x -> send_to_device(device(agent.policy.policy.target_actor), x) |> + agent.policy.policy.target_actor |> send_to_host + ) for (player, agent) in π.agents), 1 ) for (player, agent) in π.agents - p = agent.policy + p = agent.policy.policy # get DDPGPolicy struct A = p.behavior_actor C = p.behavior_critic Aₜ = p.target_actor @@ -163,9 +141,5 @@ function RLBase.update!(π::MADDPGManager) for (dest, src) in zip(Flux.params([Aₜ, Cₜ]), Flux.params([A, C])) dest .= ρ .* dest .+ (1 - ρ) .* src end - - s, a, s′ = send_to_host((s, a, s′)) - mu_actions = send_to_host(mu_actions) - new_actions = send_to_host(new_actions) end end
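
Taken together, the patches implement the update described in the docstring: each player keeps its own actor, which sees only that player's observation, while each player's critic is trained on the joint state and joint action of all players. For orientation, here is a minimal, self-contained sketch of that update for one player, written in plain Flux with hypothetical toy sizes and two hard-coded players. The names below (actors, critic, target_critic, opt_critic, ...) are illustrative only and are not part of the ReinforcementLearning.jl API; unlike the patch, the sketch skips device transfer, trajectories, and the NamedPolicy wrapper, and it recomputes the actor's action inside the gradient call so the gradient reaches the actor.

using Flux, Statistics

n_players, ns, na, batch = 2, 1, 1, 32      # hypothetical toy sizes
γ, ρ = 0.99f0, 0.995f0

# decentralised actors: one per player, each sees only its own observation
actors = [Chain(Dense(ns, 64, relu), Dense(64, na, tanh)) for _ in 1:n_players]
# centralised critic (player 1's): scores every player's state and action jointly
critic = Chain(Dense(n_players * (ns + na), 64, relu), Dense(64, 1))
target_actors, target_critic = deepcopy(actors), deepcopy(critic)
opt_actor, opt_critic = ADAM(), ADAM()

# a fake sampled batch standing in for the assembled per-agent trajectories
states  = [rand(Float32, ns, batch) for _ in 1:n_players]
states′ = [rand(Float32, ns, batch) for _ in 1:n_players]
actions = [rand(Float32, na, batch) for _ in 1:n_players]
r = rand(Float32, batch)          # player 1's rewards
t = zeros(Float32, batch)         # terminal flags

s, a, s′ = vcat(states...), vcat(actions...), vcat(states′...)
a′ = vcat(target_actors[1](states′[1]), target_actors[2](states′[2]))

# critic regression target built from the joint next state and joint target action
y = r .+ γ .* (1 .- t) .* vec(target_critic(vcat(s′, a′)))

gs_critic = gradient(Flux.params(critic)) do
    mean((y .- vec(critic(vcat(s, a)))) .^ 2)
end
Flux.Optimise.update!(opt_critic, Flux.params(critic), gs_critic)

# player 1's actor ascends the centralised critic's value
gs_actor = gradient(Flux.params(actors[1])) do
    μ = vcat(actors[1](states[1]), actors[2](states[2]))
    -mean(critic(vcat(s, μ)))
end
Flux.Optimise.update!(opt_actor, Flux.params(actors[1]), gs_actor)

# polyak averaging toward the target networks, as in the patch
for (dest, src) in zip(Flux.params([target_actors[1], target_critic]),
                       Flux.params([actors[1], critic]))
    dest .= ρ .* dest .+ (1 - ρ) .* src
end

The point mirrored from RLBase.update!(π::MADDPGManager) above is that the critic input is the concatenation of all players' states and actions while each actor only ever consumes its own state; the final loop matches the patch's polyak target-network update.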