From c57d1d913eb5fecb66c7941b4451cd23ab38d1ee Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 11 Oct 2024 09:58:03 -0700 Subject: [PATCH] Add `group_map` support to MLAgents wrappers --- test/test_libs.py | 31 +++++++-- torchrl/envs/libs/unity_mlagents.py | 101 +++++++++++++++++++--------- 2 files changed, 93 insertions(+), 39 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a165c6916fb..3d04648fd4e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3950,13 +3950,17 @@ def test_chance_not_implemented(self): class TestUnityMLAgents: @mock.patch("mlagents_envs.env_utils.launch_executable") @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") - def test_env(self, mock_communicator, mock_launcher): + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_env(self, mock_communicator, mock_launcher, group_map): from mlagents_envs.mock_communicator import MockCommunicator mock_communicator.return_value = MockCommunicator( discrete_action=False, visual_inputs=0 ) - env = UnityMLAgentsEnv(" ") + env = UnityMLAgentsEnv(" ", group_map=group_map) try: check_env_specs(env) finally: @@ -3964,14 +3968,18 @@ def test_env(self, mock_communicator, mock_launcher): @mock.patch("mlagents_envs.env_utils.launch_executable") @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") - def test_wrapper(self, mock_communicator, mock_launcher): + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_wrapper(self, mock_communicator, mock_launcher, group_map): from mlagents_envs.environment import UnityEnvironment from mlagents_envs.mock_communicator import MockCommunicator mock_communicator.return_value = MockCommunicator( discrete_action=False, visual_inputs=0 ) - env = UnityMLAgentsWrapper(UnityEnvironment(" ")) + env = 
UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map) try: check_env_specs(env) finally: @@ -3979,14 +3987,18 @@ def test_wrapper(self, mock_communicator, mock_launcher): @mock.patch("mlagents_envs.env_utils.launch_executable") @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") - def test_rollout(self, mock_communicator, mock_launcher): + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_rollout(self, mock_communicator, mock_launcher, group_map): from mlagents_envs.environment import UnityEnvironment from mlagents_envs.mock_communicator import MockCommunicator mock_communicator.return_value = MockCommunicator( discrete_action=False, visual_inputs=0 ) - env = UnityMLAgentsWrapper(UnityEnvironment(" ")) + env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map) try: env.rollout( max_steps=500, break_when_any_done=False, break_when_all_done=False @@ -4031,10 +4043,15 @@ def test_with_editor(self): 5, ) @pytest.mark.parametrize("registered_name", _mlagents_registered_envs) - def test_registered_envs(self, registered_name): + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_registered_envs(self, registered_name, group_map): env = UnityMLAgentsEnv( registered_name=registered_name, no_graphics=True, + group_map=group_map, ) try: check_env_specs(env) diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 6ed019c2332..95c2460bc83 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -6,7 +6,7 @@ from __future__ import annotations import importlib.util -from typing import Dict, Optional +from typing import Dict, List, Optional import torch from tensordict import TensorDict, TensorDictBase @@ -20,7 +20,7 @@ Unbounded, ) from torchrl.envs.common import _EnvWrapper -from 
torchrl.envs.utils import _classproperty, check_marl_grouping +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType _has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None @@ -56,6 +56,11 @@ class UnityMLAgentsWrapper(_EnvWrapper): allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`~.reset` is called. Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to + group agents in tensordicts for input/output. See + :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not + specified, agents are grouped according to the group ID given by the + Unity environment. Defaults to ``None``. categorical_actions (bool, optional): if ``True``, categorical specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding @@ -92,12 +97,14 @@ def __init__( self, env=None, *, + group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, categorical_actions: bool = False, **kwargs, ): if env is not None: kwargs["env"] = env + self.group_map = group_map self.categorical_actions = categorical_actions super().__init__(**kwargs) @@ -118,12 +125,11 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs): def _init_env(self): self._update_action_mask() - # Creates a group map where agents are grouped by their group_id. - def _make_group_map(self, env): - group_map = {} - agent_names = [] + # Collects every agent in the environment, mapping each agent name to its + # behavior and to its group_id. 
+ def _collect_agents(self, env): agent_name_to_behavior_map = {} - agent_name_to_id_map = {} + agent_name_to_group_id_map = {} for steps_idx in [0, 1]: for behavior in env.behavior_specs.keys(): @@ -134,22 +140,41 @@ def _make_group_map(self, env): for agent_id, group_id in zip(agent_ids, group_ids): agent_name = f"agent_{agent_id}" - group_name = f"group_{group_id}" - if group_name not in group_map.keys(): - group_map[group_name] = [] - if agent_name in group_map[group_name]: + if agent_name in agent_name_to_behavior_map: # Sometimes in an MLAgents environment, an agent may # show up in both the decision steps and the terminal # steps. When that happens, just skip the duplicate. assert is_terminal continue - group_map[group_name].append(agent_name) - agent_names.append(agent_name) agent_name_to_behavior_map[agent_name] = behavior - agent_name_to_id_map[agent_name] = agent_id + agent_name_to_group_id_map[agent_name] = group_id + + return ( + agent_name_to_behavior_map, + agent_name_to_group_id_map, + ) - check_marl_grouping(group_map, agent_names) - return group_map, agent_name_to_behavior_map, agent_name_to_id_map + # Creates a group map where agents are grouped by their group_id. 
+ def _make_default_group_map(self, agent_name_to_group_id_map): + group_map = {} + for agent_name, group_id in agent_name_to_group_id_map.items(): + group_name = f"group_{group_id}" + if group_name not in group_map: + group_map[group_name] = [] + group_map[group_name].append(agent_name) + return group_map + + def _make_group_map(self, group_map, agent_name_to_group_id_map): + if group_map is None: + group_map = self._make_default_group_map(agent_name_to_group_id_map) + elif isinstance(group_map, MarlGroupMapType): + group_map = group_map.get_group_map(agent_name_to_group_id_map.keys()) + check_marl_grouping(group_map, agent_name_to_group_id_map.keys()) + agent_name_to_group_name_map = {} + for group_name, agents in group_map.items(): + for agent_name in agents: + agent_name_to_group_name_map[agent_name] = group_name + return group_map, agent_name_to_group_name_map def _make_specs( self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821 @@ -163,10 +188,13 @@ def _make_specs( # will need to detect changes to the behaviors and agents on each step. env.reset() ( - self.group_map, self.agent_name_to_behavior_map, - self.agent_name_to_id_map, - ) = self._make_group_map(env) + self.agent_name_to_group_id_map, + ) = self._collect_agents(env) + + (self.group_map, self.agent_name_to_group_name_map) = self._make_group_map( + self.group_map, self.agent_name_to_group_id_map + ) action_spec = {} observation_spec = {} @@ -257,17 +285,21 @@ def _set_seed(self, seed): if seed is not None: raise NotImplementedError("This environment has no seed.") - def _check_agent_exists(self, agent_name, group_name): - if ( - group_name not in self.full_action_spec.keys() - or agent_name not in self.full_action_spec[group_name].keys() - ): + def _check_agent_exists(self, agent_name, group_id): + if agent_name not in self.agent_name_to_group_id_map: raise RuntimeError( ( "Unity environment added a new agent. This is not yet " "supported in torchrl." 
) ) + if self.agent_name_to_group_id_map[agent_name] != group_id: + raise RuntimeError( + ( + "Unity environment changed the group of an agent. This " + "is not yet supported in torchrl." + ) + ) def _update_action_mask(self): for behavior, behavior_spec in self._env.behavior_specs.items(): @@ -290,8 +322,8 @@ def _update_action_mask(self): steps.agent_id, steps.group_id, combined_action_mask ): agent_name = f"agent_{agent_id}" - group_name = f"group_{group_id}" - self._check_agent_exists(agent_name, group_name) + self._check_agent_exists(agent_name, group_id) + group_name = self.agent_name_to_group_name_map[agent_name] self.full_action_spec[ group_name, agent_name, "discrete_action" ].update_mask(agent_action_mask) @@ -305,8 +337,8 @@ def _make_td_out(self, tensordict_in, is_reset=False): zip(steps.agent_id, steps.group_id) ): agent_name = f"agent_{agent_id}" - group_name = f"group_{group_id}" - self._check_agent_exists(agent_name, group_name) + self._check_agent_exists(agent_name, group_id) + group_name = self.agent_name_to_group_name_map[agent_name] if group_name not in source: source[group_name] = {} if agent_name not in source[group_name]: @@ -427,13 +459,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for agent_id, group_id in zip(steps.agent_id, steps.group_id): agent_name = f"agent_{agent_id}" - group_name = f"group_{group_id}" - - self._check_agent_exists(agent_name, group_name) + self._check_agent_exists(agent_name, group_id) + group_name = self.agent_name_to_group_name_map[agent_name] agent_action_spec = self.full_action_spec[group_name, agent_name] action_tuple = self.lib.base_env.ActionTuple() - agent_id = self.agent_name_to_id_map[agent_name] discrete_branches = env_action_spec.discrete_branches continuous_size = env_action_spec.continuous_size @@ -511,6 +541,11 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper): allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`~.reset` 
is called. Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to + group agents in tensordicts for input/output. See + :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not + specified, agents are grouped according to the group ID given by the + Unity environment. Defaults to ``None``. categorical_actions (bool, optional): if ``True``, categorical specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding @@ -804,12 +839,14 @@ def __init__( file_name: Optional[str] = None, registered_name: Optional[str] = None, *, + group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, categorical_actions=False, **kwargs, ): kwargs["file_name"] = file_name kwargs["registered_name"] = registered_name super().__init__( + group_map=group_map, categorical_actions=categorical_actions, **kwargs, )