Skip to content

Commit c57d1d9

Browse files
committed
Add group_map support to MLAgents wrappers
1 parent ec04c35 commit c57d1d9

File tree

2 files changed

+93
-39
lines changed

2 files changed

+93
-39
lines changed

test/test_libs.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3950,43 +3950,55 @@ def test_chance_not_implemented(self):
39503950
class TestUnityMLAgents:
39513951
@mock.patch("mlagents_envs.env_utils.launch_executable")
39523952
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
3953-
def test_env(self, mock_communicator, mock_launcher):
3953+
@pytest.mark.parametrize(
3954+
"group_map",
3955+
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
3956+
)
3957+
def test_env(self, mock_communicator, mock_launcher, group_map):
39543958
from mlagents_envs.mock_communicator import MockCommunicator
39553959

39563960
mock_communicator.return_value = MockCommunicator(
39573961
discrete_action=False, visual_inputs=0
39583962
)
3959-
env = UnityMLAgentsEnv(" ")
3963+
env = UnityMLAgentsEnv(" ", group_map=group_map)
39603964
try:
39613965
check_env_specs(env)
39623966
finally:
39633967
env.close()
39643968

39653969
@mock.patch("mlagents_envs.env_utils.launch_executable")
39663970
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
3967-
def test_wrapper(self, mock_communicator, mock_launcher):
3971+
@pytest.mark.parametrize(
3972+
"group_map",
3973+
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
3974+
)
3975+
def test_wrapper(self, mock_communicator, mock_launcher, group_map):
39683976
from mlagents_envs.environment import UnityEnvironment
39693977
from mlagents_envs.mock_communicator import MockCommunicator
39703978

39713979
mock_communicator.return_value = MockCommunicator(
39723980
discrete_action=False, visual_inputs=0
39733981
)
3974-
env = UnityMLAgentsWrapper(UnityEnvironment(" "))
3982+
env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
39753983
try:
39763984
check_env_specs(env)
39773985
finally:
39783986
env.close()
39793987

39803988
@mock.patch("mlagents_envs.env_utils.launch_executable")
39813989
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
3982-
def test_rollout(self, mock_communicator, mock_launcher):
3990+
@pytest.mark.parametrize(
3991+
"group_map",
3992+
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
3993+
)
3994+
def test_rollout(self, mock_communicator, mock_launcher, group_map):
39833995
from mlagents_envs.environment import UnityEnvironment
39843996
from mlagents_envs.mock_communicator import MockCommunicator
39853997

39863998
mock_communicator.return_value = MockCommunicator(
39873999
discrete_action=False, visual_inputs=0
39884000
)
3989-
env = UnityMLAgentsWrapper(UnityEnvironment(" "))
4001+
env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
39904002
try:
39914003
env.rollout(
39924004
max_steps=500, break_when_any_done=False, break_when_all_done=False
@@ -4031,10 +4043,15 @@ def test_with_editor(self):
40314043
5,
40324044
)
40334045
@pytest.mark.parametrize("registered_name", _mlagents_registered_envs)
4034-
def test_registered_envs(self, registered_name):
4046+
@pytest.mark.parametrize(
4047+
"group_map",
4048+
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
4049+
)
4050+
def test_registered_envs(self, registered_name, group_map):
40354051
env = UnityMLAgentsEnv(
40364052
registered_name=registered_name,
40374053
no_graphics=True,
4054+
group_map=group_map,
40384055
)
40394056
try:
40404057
check_env_specs(env)

torchrl/envs/libs/unity_mlagents.py

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
import importlib.util
9-
from typing import Dict, Optional
9+
from typing import Dict, List, Optional
1010

1111
import torch
1212
from tensordict import TensorDict, TensorDictBase
@@ -20,7 +20,7 @@
2020
Unbounded,
2121
)
2222
from torchrl.envs.common import _EnvWrapper
23-
from torchrl.envs.utils import _classproperty, check_marl_grouping
23+
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
2424

2525
_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None
2626

@@ -56,6 +56,11 @@ class UnityMLAgentsWrapper(_EnvWrapper):
5656
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
5757
for envs to be ``done`` just after :meth:`~.reset` is called.
5858
Defaults to ``False``.
59+
group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to
60+
group agents in tensordicts for input/output. See
61+
:class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
62+
specified, agents are grouped according to the group ID given by the
63+
Unity environment. Defaults to ``None``.
5964
categorical_actions (bool, optional): if ``True``, categorical specs
6065
will be converted to the TorchRL equivalent
6166
(:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
@@ -92,12 +97,14 @@ def __init__(
9297
self,
9398
env=None,
9499
*,
100+
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
95101
categorical_actions: bool = False,
96102
**kwargs,
97103
):
98104
if env is not None:
99105
kwargs["env"] = env
100106

107+
self.group_map = group_map
101108
self.categorical_actions = categorical_actions
102109
super().__init__(**kwargs)
103110

@@ -118,12 +125,11 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs):
118125
def _init_env(self):
119126
self._update_action_mask()
120127

121-
# Creates a group map where agents are grouped by their group_id.
122-
def _make_group_map(self, env):
123-
group_map = {}
124-
agent_names = []
128+
# Creates a group map where agents are grouped by the group_id given by the
129+
# Unity environment.
130+
def _collect_agents(self, env):
125131
agent_name_to_behavior_map = {}
126-
agent_name_to_id_map = {}
132+
agent_name_to_group_id_map = {}
127133

128134
for steps_idx in [0, 1]:
129135
for behavior in env.behavior_specs.keys():
@@ -134,22 +140,41 @@ def _make_group_map(self, env):
134140

135141
for agent_id, group_id in zip(agent_ids, group_ids):
136142
agent_name = f"agent_{agent_id}"
137-
group_name = f"group_{group_id}"
138-
if group_name not in group_map.keys():
139-
group_map[group_name] = []
140-
if agent_name in group_map[group_name]:
143+
if agent_name in agent_name_to_behavior_map:
141144
# Sometimes in an MLAgents environment, an agent may
142145
# show up in both the decision steps and the terminal
143146
# steps. When that happens, just skip the duplicate.
144147
assert is_terminal
145148
continue
146-
group_map[group_name].append(agent_name)
147-
agent_names.append(agent_name)
148149
agent_name_to_behavior_map[agent_name] = behavior
149-
agent_name_to_id_map[agent_name] = agent_id
150+
agent_name_to_group_id_map[agent_name] = group_id
151+
152+
return (
153+
agent_name_to_behavior_map,
154+
agent_name_to_group_id_map,
155+
)
150156

151-
check_marl_grouping(group_map, agent_names)
152-
return group_map, agent_name_to_behavior_map, agent_name_to_id_map
157+
# Creates a group map where agents are grouped by their group_id.
158+
def _make_default_group_map(self, agent_name_to_group_id_map):
159+
group_map = {}
160+
for agent_name, group_id in agent_name_to_group_id_map.items():
161+
group_name = f"group_{group_id}"
162+
if group_name not in group_map:
163+
group_map[group_name] = []
164+
group_map[group_name].append(agent_name)
165+
return group_map
166+
167+
def _make_group_map(self, group_map, agent_name_to_group_id_map):
168+
if group_map is None:
169+
group_map = self._make_default_group_map(agent_name_to_group_id_map)
170+
elif isinstance(group_map, MarlGroupMapType):
171+
group_map = group_map.get_group_map(agent_name_to_group_id_map.keys())
172+
check_marl_grouping(group_map, agent_name_to_group_id_map.keys())
173+
agent_name_to_group_name_map = {}
174+
for group_name, agents in group_map.items():
175+
for agent_name in agents:
176+
agent_name_to_group_name_map[agent_name] = group_name
177+
return group_map, agent_name_to_group_name_map
153178

154179
def _make_specs(
155180
self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821
@@ -163,10 +188,13 @@ def _make_specs(
163188
# will need to detect changes to the behaviors and agents on each step.
164189
env.reset()
165190
(
166-
self.group_map,
167191
self.agent_name_to_behavior_map,
168-
self.agent_name_to_id_map,
169-
) = self._make_group_map(env)
192+
self.agent_name_to_group_id_map,
193+
) = self._collect_agents(env)
194+
195+
(self.group_map, self.agent_name_to_group_name_map) = self._make_group_map(
196+
self.group_map, self.agent_name_to_group_id_map
197+
)
170198

171199
action_spec = {}
172200
observation_spec = {}
@@ -257,17 +285,21 @@ def _set_seed(self, seed):
257285
if seed is not None:
258286
raise NotImplementedError("This environment has no seed.")
259287

260-
def _check_agent_exists(self, agent_name, group_name):
261-
if (
262-
group_name not in self.full_action_spec.keys()
263-
or agent_name not in self.full_action_spec[group_name].keys()
264-
):
288+
def _check_agent_exists(self, agent_name, group_id):
289+
if agent_name not in self.agent_name_to_group_id_map:
265290
raise RuntimeError(
266291
(
267292
"Unity environment added a new agent. This is not yet "
268293
"supported in torchrl."
269294
)
270295
)
296+
if self.agent_name_to_group_id_map[agent_name] != group_id:
297+
raise RuntimeError(
298+
(
299+
"Unity environment changed the group of an agent. This "
300+
"is not yet supported in torchrl."
301+
)
302+
)
271303

272304
def _update_action_mask(self):
273305
for behavior, behavior_spec in self._env.behavior_specs.items():
@@ -290,8 +322,8 @@ def _update_action_mask(self):
290322
steps.agent_id, steps.group_id, combined_action_mask
291323
):
292324
agent_name = f"agent_{agent_id}"
293-
group_name = f"group_{group_id}"
294-
self._check_agent_exists(agent_name, group_name)
325+
self._check_agent_exists(agent_name, group_id)
326+
group_name = self.agent_name_to_group_name_map[agent_name]
295327
self.full_action_spec[
296328
group_name, agent_name, "discrete_action"
297329
].update_mask(agent_action_mask)
@@ -305,8 +337,8 @@ def _make_td_out(self, tensordict_in, is_reset=False):
305337
zip(steps.agent_id, steps.group_id)
306338
):
307339
agent_name = f"agent_{agent_id}"
308-
group_name = f"group_{group_id}"
309-
self._check_agent_exists(agent_name, group_name)
340+
self._check_agent_exists(agent_name, group_id)
341+
group_name = self.agent_name_to_group_name_map[agent_name]
310342
if group_name not in source:
311343
source[group_name] = {}
312344
if agent_name not in source[group_name]:
@@ -427,13 +459,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
427459

428460
for agent_id, group_id in zip(steps.agent_id, steps.group_id):
429461
agent_name = f"agent_{agent_id}"
430-
group_name = f"group_{group_id}"
431-
432-
self._check_agent_exists(agent_name, group_name)
462+
self._check_agent_exists(agent_name, group_id)
463+
group_name = self.agent_name_to_group_name_map[agent_name]
433464

434465
agent_action_spec = self.full_action_spec[group_name, agent_name]
435466
action_tuple = self.lib.base_env.ActionTuple()
436-
agent_id = self.agent_name_to_id_map[agent_name]
437467
discrete_branches = env_action_spec.discrete_branches
438468
continuous_size = env_action_spec.continuous_size
439469

@@ -511,6 +541,11 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper):
511541
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
512542
for envs to be ``done`` just after :meth:`~.reset` is called.
513543
Defaults to ``False``.
544+
group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to
545+
group agents in tensordicts for input/output. See
546+
:class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
547+
specified, agents are grouped according to the group ID given by the
548+
Unity environment. Defaults to ``None``.
514549
categorical_actions (bool, optional): if ``True``, categorical specs
515550
will be converted to the TorchRL equivalent
516551
(:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
@@ -804,12 +839,14 @@ def __init__(
804839
file_name: Optional[str] = None,
805840
registered_name: Optional[str] = None,
806841
*,
842+
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
807843
categorical_actions=False,
808844
**kwargs,
809845
):
810846
kwargs["file_name"] = file_name
811847
kwargs["registered_name"] = registered_name
812848
super().__init__(
849+
group_map=group_map,
813850
categorical_actions=categorical_actions,
814851
**kwargs,
815852
)

0 commit comments

Comments
 (0)