
Commit 4149579

matteobettini authored and vmoens committed
[Example] Multiagent examples: MAPPO-IPPO-MADDPG-IDDPG-IQL-QMIX-VDN (pytorch#1027)
Signed-off-by: Matteo Bettini <[email protected]>
1 parent 2fa358a commit 4149579

File tree

13 files changed (+1414, −0 lines)


.circleci/unittest/linux_examples/scripts/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -27,3 +27,4 @@ dependencies:
   - mlflow
   - av
   - coverage
+  - vmas

.circleci/unittest/linux_examples/scripts/run_test.sh

Lines changed: 25 additions & 0 deletions

@@ -246,6 +246,31 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
   collector.collector_device=cuda:0 \
   env.name=Pendulum-v1 \
   logger.backend=
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/mappo_ippo.py \
+  collector.n_iters=2 \
+  collector.frames_per_batch=200 \
+  train.num_epochs=3 \
+  train.minibatch_size=100 \
+  logger.backend=
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/maddpg_iddpg.py \
+  collector.n_iters=2 \
+  collector.frames_per_batch=200 \
+  train.num_epochs=3 \
+  train.minibatch_size=100 \
+  logger.backend=
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/iql.py \
+  collector.n_iters=2 \
+  collector.frames_per_batch=200 \
+  train.num_epochs=3 \
+  train.minibatch_size=100 \
+  logger.backend=
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/multiagent/qmix_vdn.py \
+  collector.n_iters=2 \
+  collector.frames_per_batch=200 \
+  train.num_epochs=3 \
+  train.minibatch_size=100 \
+  logger.backend=
+

 python .circleci/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100
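For reference, any of these smoke tests can be reproduced locally from the repository root without the coverage wrapper; a sketch for the MAPPO/IPPO script (assuming the example dependencies are installed):

```bash
python examples/multiagent/mappo_ippo.py \
    collector.n_iters=2 \
    collector.frames_per_batch=200 \
    train.num_epochs=3 \
    train.minibatch_size=100 \
    logger.backend=
```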

examples/multiagent/README.md

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# Multi-agent examples

In this folder we provide a set of multi-agent example scripts using the [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) simulator.

<p align="center">
<img src="https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/img/marl_vmas.png" width="600px">
</p>

<center><i>The MARL algorithms contained in the scripts of this folder run on three multi-robot tasks in VMAS.</i></center>

For more details on the experiment setup and the environments, please refer to the corresponding section of the appendix in the [TorchRL paper](https://arxiv.org/abs/2306.00577).

## Using the scripts

### Install

First, install torchrl and tensordict following the instructions in their repositories.

Then install vmas and the other dependencies of the scripts:

```bash
pip install vmas
pip install wandb moviepy
pip install hydra-core
```

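As an optional sanity check (assuming the packages above installed cleanly), verify that the main dependencies import correctly:

```bash
python -c "import torchrl, tensordict, vmas; print(torchrl.__version__)"
```
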
### Run

To run a script, execute the corresponding Python file after modifying its config according to your needs.
The config can be found in the .yaml file with the same name.

For example:
```bash
python mappo_ippo.py
```

You can also override the config from the command line, like:

```bash
python mappo_ippo.py env.scenario_name=navigation
```

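If you want to launch several configurations in one go, Hydra's standard multirun flag can be used; a sketch (the scenario names come from VMAS and are only illustrative):

```bash
python mappo_ippo.py --multirun env.scenario_name=navigation,balance
```
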
### Computational demand

The scripts are set up to collect many frames. If your compute is limited, you can reduce the `frames_per_batch` and `num_epochs` parameters to lower the compute requirements, as shown in the example below.

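For example, a much lighter run can be requested directly from the command line; the values below are purely illustrative and not tuned:

```bash
python mappo_ippo.py collector.frames_per_batch=6000 collector.n_iters=10 train.num_epochs=10
```
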
### Script structure

The scripts are self-contained: all the code you will need to look at is in the script file, and no helper functions are used.

Each script follows this structure:
- Configuration dictionary for the script
- Environment creation
- Module creation
- Collector instantiation
- Replay buffer instantiation
- Loss module creation
- Training loop (with inner minibatch loops)
- Evaluation run (at the desired frequency)

Logging is done by default to wandb.
The logging backend can be changed in the config files to one of "wandb", "tensorboard", "csv", "mlflow".

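Since the backend is just a config entry, it can also be switched or disabled straight from the command line. A couple of hedged examples, relying only on the Hydra overrides these scripts already use:

```bash
python mappo_ippo.py logger.backend=csv   # log to local CSV files instead of wandb
python mappo_ippo.py logger.backend=      # an empty value disables logging entirely
```
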
All the scripts follow the same on-policy training structure so that results can be compared across different algorithms.

examples/multiagent/iql.py

Lines changed: 231 additions & 0 deletions

@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import hydra
import torch

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training


def rendering_callback(env, td):
    env.frames.append(env.render(mode="rgb_array", agent_index_focus=None))


@hydra.main(version_base="1.1", config_path=".", config_name="iql")
def train(cfg: "DictConfig"):  # noqa: F821
    # Device
    cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0"
    cfg.env.device = cfg.train.device

    # Seeding
    torch.manual_seed(cfg.seed)

    # Sampling
    cfg.env.vmas_envs = cfg.collector.frames_per_batch // cfg.env.max_steps
    cfg.collector.total_frames = cfg.collector.frames_per_batch * cfg.collector.n_iters
    cfg.buffer.memory_size = cfg.collector.frames_per_batch

    # Create env and env_test
    env = VmasEnv(
        scenario=cfg.env.scenario_name,
        num_envs=cfg.env.vmas_envs,
        continuous_actions=False,
        max_steps=cfg.env.max_steps,
        device=cfg.env.device,
        seed=cfg.seed,
        # Scenario kwargs
        **cfg.env.scenario,
    )
    env = TransformedEnv(
        env,
        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
    )

    env_test = VmasEnv(
        scenario=cfg.env.scenario_name,
        num_envs=cfg.eval.evaluation_episodes,
        continuous_actions=False,
        max_steps=cfg.env.max_steps,
        device=cfg.env.device,
        seed=cfg.seed,
        # Scenario kwargs
        **cfg.env.scenario,
    )

    # Policy
    net = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs=env.action_spec.space.n,
        n_agents=env.n_agents,
        centralised=False,
        share_params=cfg.model.shared_parameters,
        device=cfg.train.device,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    module = TensorDictModule(
        net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
    )
    value_module = QValueModule(
        action_value_key=("agents", "action_value"),
        out_keys=[
            env.action_key,
            ("agents", "action_value"),
            ("agents", "chosen_action_value"),
        ],
        spec=env.unbatched_action_spec,
        action_space=None,
    )
    qnet = SafeSequential(module, value_module)

    qnet_explore = EGreedyWrapper(
        qnet,
        eps_init=0.3,
        eps_end=0,
        annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
        action_key=env.action_key,
        spec=env.unbatched_action_spec[env.action_key],
    )

    collector = SyncDataCollector(
        env,
        qnet_explore,
        device=cfg.env.device,
        storing_device=cfg.train.device,
        frames_per_batch=cfg.collector.frames_per_batch,
        total_frames=cfg.collector.total_frames,
    )

    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(cfg.buffer.memory_size, device=cfg.train.device),
        sampler=SamplerWithoutReplacement(),
        batch_size=cfg.train.minibatch_size,
    )

    loss_module = DQNLoss(qnet, delay_value=True)
    loss_module.set_keys(
        action_value=("agents", "action_value"),
        action=env.action_key,
        value=("agents", "chosen_action_value"),
        reward=env.reward_key,
    )
    loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
    target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)

    optim = torch.optim.Adam(loss_module.parameters(), cfg.train.lr)

    # Logging
    if cfg.logger.backend:
        model_name = ("Het" if not cfg.model.shared_parameters else "") + "IQL"
        logger = init_logging(cfg, model_name)

    total_time = 0
    total_frames = 0
    sampling_start = time.time()
    for i, tensordict_data in enumerate(collector):
        print(f"\nIteration {i}")

        sampling_time = time.time() - sampling_start

        tensordict_data.set(
            ("next", "done"),
            tensordict_data.get(("next", "done"))
            .unsqueeze(-1)
            .expand(tensordict_data.get(("next", env.reward_key)).shape),
        )  # We need to expand the done to match the reward shape

        current_frames = tensordict_data.numel()
        total_frames += current_frames
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view)

        training_tds = []
        training_start = time.time()
        for _ in range(cfg.train.num_epochs):
            for _ in range(cfg.collector.frames_per_batch // cfg.train.minibatch_size):
                subdata = replay_buffer.sample()
                loss_vals = loss_module(subdata)
                training_tds.append(loss_vals.detach())

                loss_value = loss_vals["loss"]

                loss_value.backward()

                total_norm = torch.nn.utils.clip_grad_norm_(
                    loss_module.parameters(), cfg.train.max_grad_norm
                )
                training_tds[-1].set("grad_norm", total_norm.mean())

                optim.step()
                optim.zero_grad()
                target_net_updater.step()

        qnet_explore.step(frames=current_frames)  # Update exploration annealing
        collector.update_policy_weights_()

        training_time = time.time() - training_start

        iteration_time = sampling_time + training_time
        total_time += iteration_time
        training_tds = torch.stack(training_tds)

        # More logs
        if cfg.logger.backend:
            log_training(
                logger,
                training_tds,
                tensordict_data,
                sampling_time,
                training_time,
                total_time,
                i,
                current_frames,
                total_frames,
                step=i,
            )

        if (
            cfg.eval.evaluation_episodes > 0
            and i % cfg.eval.evaluation_interval == 0
            and cfg.logger.backend
        ):
            evaluation_start = time.time()
            # Enter both context managers: disable gradients and switch off exploration
            with torch.no_grad(), set_exploration_type(ExplorationType.MEAN):
                env_test.frames = []
                rollouts = env_test.rollout(
                    max_steps=cfg.env.max_steps,
                    policy=qnet,
                    callback=rendering_callback,
                    auto_cast_to_device=True,
                    break_when_any_done=False,
                    # We are running vectorized evaluation; we do not want it to stop when just one env is done
                )

                evaluation_time = time.time() - evaluation_start

                log_evaluation(logger, rollouts, env_test, evaluation_time, step=i)

        if cfg.logger.backend == "wandb":
            logger.experiment.log({}, commit=True)
        sampling_start = time.time()


if __name__ == "__main__":
    train()

examples/multiagent/iql.yaml

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
seed: 0

env:
  max_steps: 100
  scenario_name: "balance"
  scenario:
    n_agents: 3
  device: ???  # These values will be populated dynamically
  vmas_envs: ???

model:
  shared_parameters: True

collector:
  frames_per_batch: 60_000  # Frames sampled each sampling iteration
  n_iters: 500  # Number of sampling/training iterations
  total_frames: ???

buffer:
  memory_size: ???

loss:
  gamma: 0.9
  tau: 0.005  # For target net

train:
  num_epochs: 45  # optimization steps per batch of data collected
  minibatch_size: 4096  # size of minibatches used in each epoch
  lr: 5e-5
  max_grad_norm: 40.0
  device: ???

eval:
  evaluation_interval: 20
  evaluation_episodes: 200

logger:
  backend: wandb  # Delete to remove logging
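
Any of these keys can be overridden at launch time through Hydra rather than by editing the file; for instance, from the examples/multiagent folder (values purely illustrative):

```bash
python iql.py env.scenario_name=navigation seed=1 train.lr=1e-4
```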
