Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3950,43 +3950,55 @@ def test_chance_not_implemented(self):
class TestUnityMLAgents:
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_env(self, mock_communicator, mock_launcher):
@pytest.mark.parametrize(
"group_map",
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
)
def test_env(self, mock_communicator, mock_launcher, group_map):
from mlagents_envs.mock_communicator import MockCommunicator

mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityMLAgentsEnv(" ")
env = UnityMLAgentsEnv(" ", group_map=group_map)
try:
check_env_specs(env)
finally:
env.close()

@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_wrapper(self, mock_communicator, mock_launcher):
@pytest.mark.parametrize(
"group_map",
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
)
def test_wrapper(self, mock_communicator, mock_launcher, group_map):
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.mock_communicator import MockCommunicator

mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityMLAgentsWrapper(UnityEnvironment(" "))
env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
try:
check_env_specs(env)
finally:
env.close()

@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_rollout(self, mock_communicator, mock_launcher):
@pytest.mark.parametrize(
"group_map",
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
)
def test_rollout(self, mock_communicator, mock_launcher, group_map):
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.mock_communicator import MockCommunicator

mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityMLAgentsWrapper(UnityEnvironment(" "))
env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
try:
env.rollout(
max_steps=500, break_when_any_done=False, break_when_all_done=False
Expand Down Expand Up @@ -4031,10 +4043,15 @@ def test_with_editor(self):
5,
)
@pytest.mark.parametrize("registered_name", _mlagents_registered_envs)
def test_registered_envs(self, registered_name):
@pytest.mark.parametrize(
"group_map",
[None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
)
def test_registered_envs(self, registered_name, group_map):
env = UnityMLAgentsEnv(
registered_name=registered_name,
no_graphics=True,
group_map=group_map,
)
try:
check_env_specs(env)
Expand Down
101 changes: 69 additions & 32 deletions torchrl/envs/libs/unity_mlagents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import importlib.util
from typing import Dict, Optional
from typing import Dict, List, Optional

import torch
from tensordict import TensorDict, TensorDictBase
Expand All @@ -20,7 +20,7 @@
Unbounded,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty, check_marl_grouping
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None

Expand Down Expand Up @@ -56,6 +56,11 @@ class UnityMLAgentsWrapper(_EnvWrapper):
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`~.reset` is called.
Defaults to ``False``.
group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
group agents in tensordicts for input/output. See
:class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
specified, agents are grouped according to the group ID given by the
Unity environment. Defaults to ``None``.
categorical_actions (bool, optional): if ``True``, categorical specs
will be converted to the TorchRL equivalent
(:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
Expand Down Expand Up @@ -92,12 +97,14 @@ def __init__(
self,
env=None,
*,
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
categorical_actions: bool = False,
**kwargs,
):
if env is not None:
kwargs["env"] = env

self.group_map = group_map
self.categorical_actions = categorical_actions
super().__init__(**kwargs)

Expand All @@ -118,12 +125,11 @@ def _build_env(self, env, requires_grad: bool = False, **kwargs):
def _init_env(self):
self._update_action_mask()

# Creates a group map where agents are grouped by their group_id.
def _make_group_map(self, env):
group_map = {}
agent_names = []
# Collects the agents reported by the Unity environment, mapping each agent
# name to its behavior and to the group_id given by the Unity environment.
def _collect_agents(self, env):
agent_name_to_behavior_map = {}
agent_name_to_id_map = {}
agent_name_to_group_id_map = {}

for steps_idx in [0, 1]:
for behavior in env.behavior_specs.keys():
Expand All @@ -134,22 +140,41 @@ def _make_group_map(self, env):

for agent_id, group_id in zip(agent_ids, group_ids):
agent_name = f"agent_{agent_id}"
group_name = f"group_{group_id}"
if group_name not in group_map.keys():
group_map[group_name] = []
if agent_name in group_map[group_name]:
if agent_name in agent_name_to_behavior_map:
# Sometimes in an MLAgents environment, an agent may
# show up in both the decision steps and the terminal
# steps. When that happens, just skip the duplicate.
assert is_terminal
continue
group_map[group_name].append(agent_name)
agent_names.append(agent_name)
agent_name_to_behavior_map[agent_name] = behavior
agent_name_to_id_map[agent_name] = agent_id
agent_name_to_group_id_map[agent_name] = group_id

return (
agent_name_to_behavior_map,
agent_name_to_group_id_map,
)

check_marl_grouping(group_map, agent_names)
return group_map, agent_name_to_behavior_map, agent_name_to_id_map
# Creates a group map where agents are grouped by their group_id.
def _make_default_group_map(self, agent_name_to_group_id_map):
group_map = {}
for agent_name, group_id in agent_name_to_group_id_map.items():
group_name = f"group_{group_id}"
if group_name not in group_map:
group_map[group_name] = []
group_map[group_name].append(agent_name)
return group_map

def _make_group_map(self, group_map, agent_name_to_group_id_map):
    """Resolve the requested grouping into a concrete group map.

    Args:
        group_map: ``None``, a ``MarlGroupMapType`` member, or an explicit
            dict mapping group names to lists of agent names.
        agent_name_to_group_id_map: dict mapping each agent name to its
            group ID.

    Returns:
        A tuple ``(group_map, agent_name_to_group_name_map)``: the resolved
        group map, and a dict mapping each agent name found in it to the
        name of its group.
    """
    # Materialize the names once as a list: the grouping helpers are handed a
    # plain list (matching the ``List[str]`` in the group_map annotation)
    # rather than a live view of the dict, which could otherwise end up
    # stored inside the returned group map.
    agent_names = list(agent_name_to_group_id_map)

    if group_map is None:
        group_map = self._make_default_group_map(agent_name_to_group_id_map)
    elif isinstance(group_map, MarlGroupMapType):
        group_map = group_map.get_group_map(agent_names)

    # NOTE(review): check_marl_grouping is presumably what rejects an invalid
    # grouping (e.g. a user-supplied dict) - confirm against its docs.
    check_marl_grouping(group_map, agent_names)

    # Reverse lookup: agent name -> name of the group containing it.
    agent_name_to_group_name_map = {
        agent_name: group_name
        for group_name, agents in group_map.items()
        for agent_name in agents
    }
    return group_map, agent_name_to_group_name_map

def _make_specs(
self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821
Expand All @@ -163,10 +188,13 @@ def _make_specs(
# will need to detect changes to the behaviors and agents on each step.
env.reset()
(
self.group_map,
self.agent_name_to_behavior_map,
self.agent_name_to_id_map,
) = self._make_group_map(env)
self.agent_name_to_group_id_map,
) = self._collect_agents(env)

(self.group_map, self.agent_name_to_group_name_map) = self._make_group_map(
self.group_map, self.agent_name_to_group_id_map
)

action_spec = {}
observation_spec = {}
Expand Down Expand Up @@ -257,17 +285,21 @@ def _set_seed(self, seed):
if seed is not None:
raise NotImplementedError("This environment has no seed.")

def _check_agent_exists(self, agent_name, group_name):
if (
group_name not in self.full_action_spec.keys()
or agent_name not in self.full_action_spec[group_name].keys()
):
def _check_agent_exists(self, agent_name, group_id):
if agent_name not in self.agent_name_to_group_id_map:
raise RuntimeError(
(
"Unity environment added a new agent. This is not yet "
"supported in torchrl."
)
)
if self.agent_name_to_group_id_map[agent_name] != group_id:
raise RuntimeError(
(
"Unity environment changed the group of an agent. This "
"is not yet supported in torchrl."
)
)

def _update_action_mask(self):
for behavior, behavior_spec in self._env.behavior_specs.items():
Expand All @@ -290,8 +322,8 @@ def _update_action_mask(self):
steps.agent_id, steps.group_id, combined_action_mask
):
agent_name = f"agent_{agent_id}"
group_name = f"group_{group_id}"
self._check_agent_exists(agent_name, group_name)
self._check_agent_exists(agent_name, group_id)
group_name = self.agent_name_to_group_name_map[agent_name]
self.full_action_spec[
group_name, agent_name, "discrete_action"
].update_mask(agent_action_mask)
Expand All @@ -305,8 +337,8 @@ def _make_td_out(self, tensordict_in, is_reset=False):
zip(steps.agent_id, steps.group_id)
):
agent_name = f"agent_{agent_id}"
group_name = f"group_{group_id}"
self._check_agent_exists(agent_name, group_name)
self._check_agent_exists(agent_name, group_id)
group_name = self.agent_name_to_group_name_map[agent_name]
if group_name not in source:
source[group_name] = {}
if agent_name not in source[group_name]:
Expand Down Expand Up @@ -427,13 +459,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

for agent_id, group_id in zip(steps.agent_id, steps.group_id):
agent_name = f"agent_{agent_id}"
group_name = f"group_{group_id}"

self._check_agent_exists(agent_name, group_name)
self._check_agent_exists(agent_name, group_id)
group_name = self.agent_name_to_group_name_map[agent_name]

agent_action_spec = self.full_action_spec[group_name, agent_name]
action_tuple = self.lib.base_env.ActionTuple()
agent_id = self.agent_name_to_id_map[agent_name]
discrete_branches = env_action_spec.discrete_branches
continuous_size = env_action_spec.continuous_size

Expand Down Expand Up @@ -511,6 +541,11 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper):
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`~.reset` is called.
Defaults to ``False``.
group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
group agents in tensordicts for input/output. See
:class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
specified, agents are grouped according to the group ID given by the
Unity environment. Defaults to ``None``.
categorical_actions (bool, optional): if ``True``, categorical specs
will be converted to the TorchRL equivalent
(:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
Expand Down Expand Up @@ -804,12 +839,14 @@ def __init__(
file_name: Optional[str] = None,
registered_name: Optional[str] = None,
*,
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
categorical_actions=False,
**kwargs,
):
kwargs["file_name"] = file_name
kwargs["registered_name"] = registered_name
super().__init__(
group_map=group_map,
categorical_actions=categorical_actions,
**kwargs,
)
Expand Down
Loading