66from __future__ import annotations
77
88import importlib .util
9- from typing import Dict , Optional
9+ from typing import Dict , List , Optional
1010
1111import torch
1212from tensordict import TensorDict , TensorDictBase
2020 Unbounded ,
2121)
2222from torchrl .envs .common import _EnvWrapper
23- from torchrl .envs .utils import _classproperty , check_marl_grouping
23+ from torchrl .envs .utils import _classproperty , check_marl_grouping , MarlGroupMapType
2424
2525_has_unity_mlagents = importlib .util .find_spec ("mlagents_envs" ) is not None
2626
@@ -56,6 +56,11 @@ class UnityMLAgentsWrapper(_EnvWrapper):
5656 allow_done_after_reset (bool, optional): if ``True``, it is tolerated
5757 for envs to be ``done`` just after :meth:`~.reset` is called.
5858 Defaults to ``False``.
59+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
60+ group agents in tensordicts for input/output. See
61+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
62+ specified, agents are grouped according to the group ID given by the
63+ Unity environment. Defaults to ``None``.
5964 categorical_actions (bool, optional): if ``True``, categorical specs
6065 will be converted to the TorchRL equivalent
6166 (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
@@ -92,12 +97,14 @@ def __init__(
9297 self ,
9398 env = None ,
9499 * ,
100+ group_map : MarlGroupMapType | Dict [str , List [str ]] | None = None ,
95101 categorical_actions : bool = False ,
96102 ** kwargs ,
97103 ):
98104 if env is not None :
99105 kwargs ["env" ] = env
100106
107+ self .group_map = group_map
101108 self .categorical_actions = categorical_actions
102109 super ().__init__ (** kwargs )
103110
@@ -118,12 +125,11 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs):
118125 def _init_env (self ):
119126 self ._update_action_mask ()
120127
121- # Creates a group map where agents are grouped by their group_id.
122- def _make_group_map (self , env ):
123- group_map = {}
124- agent_names = []
128+ # Collects the agents reported by the Unity environment, mapping each agent
129+ # name to its behavior and to the group_id given by the Unity environment.
130+ def _collect_agents (self , env ):
125131 agent_name_to_behavior_map = {}
126- agent_name_to_id_map = {}
132+ agent_name_to_group_id_map = {}
127133
128134 for steps_idx in [0 , 1 ]:
129135 for behavior in env .behavior_specs .keys ():
@@ -134,22 +140,41 @@ def _make_group_map(self, env):
134140
135141 for agent_id , group_id in zip (agent_ids , group_ids ):
136142 agent_name = f"agent_{ agent_id } "
137- group_name = f"group_{ group_id } "
138- if group_name not in group_map .keys ():
139- group_map [group_name ] = []
140- if agent_name in group_map [group_name ]:
143+ if agent_name in agent_name_to_behavior_map :
141144 # Sometimes in an MLAgents environment, an agent may
142145 # show up in both the decision steps and the terminal
143146 # steps. When that happens, just skip the duplicate.
144147 assert is_terminal
145148 continue
146- group_map [group_name ].append (agent_name )
147- agent_names .append (agent_name )
148149 agent_name_to_behavior_map [agent_name ] = behavior
149- agent_name_to_id_map [agent_name ] = agent_id
150+ agent_name_to_group_id_map [agent_name ] = group_id
151+
152+ return (
153+ agent_name_to_behavior_map ,
154+ agent_name_to_group_id_map ,
155+ )
150156
151- check_marl_grouping (group_map , agent_names )
152- return group_map , agent_name_to_behavior_map , agent_name_to_id_map
157+ # Creates a group map where agents are grouped by their group_id.
158+ def _make_default_group_map (self , agent_name_to_group_id_map ):
159+ group_map = {}
160+ for agent_name , group_id in agent_name_to_group_id_map .items ():
161+ group_name = f"group_{ group_id } "
162+ if group_name not in group_map :
163+ group_map [group_name ] = []
164+ group_map [group_name ].append (agent_name )
165+ return group_map
166+
def _make_group_map(self, group_map, agent_name_to_group_id_map):
    """Resolve the agent grouping and build the agent-to-group lookup.

    Args:
        group_map: ``None``, a ``MarlGroupMapType`` member, or an explicit
            mapping from group name to the list of agent names in it.
        agent_name_to_group_id_map: mapping from agent name to group ID; its
            keys are used as the full set of agent names.

    Returns:
        A ``(group_map, agent_name_to_group_name_map)`` tuple: the resolved
        group map, and a dict mapping every agent name listed in it to the
        name of its group.

    The resolved map is passed through ``check_marl_grouping`` before the
    lookup is built (presumably raising on an invalid grouping -- confirm
    against its documentation).
    """
    # Materialise the names once as a real list. The helpers below are
    # handed agent names as a sequence; passing the live ``dict_keys`` view
    # risks that view being stored inside the resulting group map, where it
    # is neither indexable nor independent of the source dict.
    agent_names = list(agent_name_to_group_id_map)

    if group_map is None:
        group_map = self._make_default_group_map(agent_name_to_group_id_map)
    elif isinstance(group_map, MarlGroupMapType):
        group_map = group_map.get_group_map(agent_names)
    check_marl_grouping(group_map, agent_names)

    # Invert group -> agents into agent -> group for per-agent lookups.
    agent_name_to_group_name_map = {
        agent_name: group_name
        for group_name, agents in group_map.items()
        for agent_name in agents
    }
    return group_map, agent_name_to_group_name_map
153178
154179 def _make_specs (
155180 self , env : "mlagents_envs.environment.UnityEnvironment" # noqa: F821
@@ -163,10 +188,13 @@ def _make_specs(
163188 # will need to detect changes to the behaviors and agents on each step.
164189 env .reset ()
165190 (
166- self .group_map ,
167191 self .agent_name_to_behavior_map ,
168- self .agent_name_to_id_map ,
169- ) = self ._make_group_map (env )
192+ self .agent_name_to_group_id_map ,
193+ ) = self ._collect_agents (env )
194+
195+ (self .group_map , self .agent_name_to_group_name_map ) = self ._make_group_map (
196+ self .group_map , self .agent_name_to_group_id_map
197+ )
170198
171199 action_spec = {}
172200 observation_spec = {}
@@ -257,17 +285,21 @@ def _set_seed(self, seed):
257285 if seed is not None :
258286 raise NotImplementedError ("This environment has no seed." )
259287
260- def _check_agent_exists (self , agent_name , group_name ):
261- if (
262- group_name not in self .full_action_spec .keys ()
263- or agent_name not in self .full_action_spec [group_name ].keys ()
264- ):
288+ def _check_agent_exists (self , agent_name , group_id ):
289+ if agent_name not in self .agent_name_to_group_id_map :
265290 raise RuntimeError (
266291 (
267292 "Unity environment added a new agent. This is not yet "
268293 "supported in torchrl."
269294 )
270295 )
296+ if self .agent_name_to_group_id_map [agent_name ] != group_id :
297+ raise RuntimeError (
298+ (
299+ "Unity environment changed the group of an agent. This "
300+ "is not yet supported in torchrl."
301+ )
302+ )
271303
272304 def _update_action_mask (self ):
273305 for behavior , behavior_spec in self ._env .behavior_specs .items ():
@@ -290,8 +322,8 @@ def _update_action_mask(self):
290322 steps .agent_id , steps .group_id , combined_action_mask
291323 ):
292324 agent_name = f"agent_{ agent_id } "
293- group_name = f"group_ { group_id } "
294- self ._check_agent_exists ( agent_name , group_name )
325+ self . _check_agent_exists ( agent_name , group_id )
326+ group_name = self .agent_name_to_group_name_map [ agent_name ]
295327 self .full_action_spec [
296328 group_name , agent_name , "discrete_action"
297329 ].update_mask (agent_action_mask )
@@ -305,8 +337,8 @@ def _make_td_out(self, tensordict_in, is_reset=False):
305337 zip (steps .agent_id , steps .group_id )
306338 ):
307339 agent_name = f"agent_{ agent_id } "
308- group_name = f"group_ { group_id } "
309- self ._check_agent_exists ( agent_name , group_name )
340+ self . _check_agent_exists ( agent_name , group_id )
341+ group_name = self .agent_name_to_group_name_map [ agent_name ]
310342 if group_name not in source :
311343 source [group_name ] = {}
312344 if agent_name not in source [group_name ]:
@@ -427,13 +459,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
427459
428460 for agent_id , group_id in zip (steps .agent_id , steps .group_id ):
429461 agent_name = f"agent_{ agent_id } "
430- group_name = f"group_{ group_id } "
431-
432- self ._check_agent_exists (agent_name , group_name )
462+ self ._check_agent_exists (agent_name , group_id )
463+ group_name = self .agent_name_to_group_name_map [agent_name ]
433464
434465 agent_action_spec = self .full_action_spec [group_name , agent_name ]
435466 action_tuple = self .lib .base_env .ActionTuple ()
436- agent_id = self .agent_name_to_id_map [agent_name ]
437467 discrete_branches = env_action_spec .discrete_branches
438468 continuous_size = env_action_spec .continuous_size
439469
@@ -511,6 +541,11 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper):
511541 allow_done_after_reset (bool, optional): if ``True``, it is tolerated
512542 for envs to be ``done`` just after :meth:`~.reset` is called.
513543 Defaults to ``False``.
544+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
545+ group agents in tensordicts for input/output. See
546+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
547+ specified, agents are grouped according to the group ID given by the
548+ Unity environment. Defaults to ``None``.
514549 categorical_actions (bool, optional): if ``True``, categorical specs
515550 will be converted to the TorchRL equivalent
516551 (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
@@ -804,12 +839,14 @@ def __init__(
804839 file_name : Optional [str ] = None ,
805840 registered_name : Optional [str ] = None ,
806841 * ,
842+ group_map : MarlGroupMapType | Dict [str , List [str ]] | None = None ,
807843 categorical_actions = False ,
808844 ** kwargs ,
809845 ):
810846 kwargs ["file_name" ] = file_name
811847 kwargs ["registered_name" ] = registered_name
812848 super ().__init__ (
849+ group_map = group_map ,
813850 categorical_actions = categorical_actions ,
814851 ** kwargs ,
815852 )
0 commit comments