From 7d291a73e1b2d1fc4799e222d845862d7f6876bb Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 27 Jul 2023 16:00:59 +0100 Subject: [PATCH 01/35] init --- docs/source/reference/envs.rst | 1 + test/test_specs.py | 61 ++++++++++ torchrl/data/tensor_specs.py | 141 ++++++++++++++++++---- torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 122 +++++++++++++++++++ torchrl/modules/distributions/discrete.py | 7 +- 7 files changed, 307 insertions(+), 27 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 50519dc85fa..d42a735a208 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -445,6 +445,7 @@ to be able to create this other composition: Transform TransformedEnv + ActionMask BinarizeReward CatFrames CatTensors diff --git a/test/test_specs.py b/test/test_specs.py index 10adac74bdc..f0ef4a66e3b 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -2648,6 +2648,67 @@ def test_composite_contains(): assert ("a", ("b", ("c",))) in spec.keys(True, True) +@pytest.mark.parametrize("shape", ((), (1,), (2, 3), (2, 3, 4))) +@pytest.mark.parametrize("one_hot", [True, False]) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("rand_shape", ((), (2,), (2, 3))) +class TestSpecMasking: + def _make_mask(self, shape): + torch.manual_seed(0) + mask = torch.zeros(shape, dtype=torch.bool).bernoulli_() + if len(shape) == 1: + while not mask.any() or mask.all(): + mask = torch.zeros(shape, dtype=torch.bool).bernoulli_() + return mask + mask_view = mask.view(-1, shape[-1]) + for i in range(mask_view.shape[0]): + t = mask_view[i] + while not t.any() or t.all(): + t.copy_(torch.zeros_like(t).bernoulli_()) + return mask + + def _one_hot_spec(self, shape, device, n): + shape = torch.Size([*shape, n]) + mask = self._make_mask(shape).to(device) + return OneHotDiscreteTensorSpec(n, shape, device, mask=mask) + + def _discrete_spec(self, shape, device, n): + mask = self._make_mask(torch.Size([*shape, n])).to(device) + return DiscreteTensorSpec(n, shape, device, mask=mask) + + def test_is_in(self, shape, device, one_hot, rand_shape, n=5): + shape = torch.Size(shape) + rand_shape = torch.Size(rand_shape) + spec = ( + self._one_hot_spec(shape, device, n=n) + if one_hot + else self._discrete_spec(shape, device, n=n) + ) + s = spec.rand(rand_shape) + assert spec.is_in(s) + spec.update_mask(~spec.mask) + assert not spec.is_in(s) + + def test_project(self, shape, device, one_hot, rand_shape, n=5): + shape = torch.Size(shape) + rand_shape = torch.Size(rand_shape) + spec = ( + self._one_hot_spec(shape, device, n=n) + if one_hot + else self._discrete_spec(shape, device, n=n) + ) + s = spec.rand(rand_shape) + assert (spec.project(s) == s).all() + spec.update_mask(~spec.mask) + sp = spec.project(s) + assert sp.shape == s.shape + if one_hot: + assert (sp != s).any(-1).all() + assert (sp.any(-1)).all() + else: + assert (sp != s).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4d69949b964..57117bea44f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1030,6 +1030,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, use_register: bool = False, + mask: torch.Tensor | None = None, ): 
dtype, device = _default_dtype_and_device(dtype, device) @@ -1045,6 +1046,17 @@ def __init__( f"Got n={space.n} and shape={shape}." ) super().__init__(shape, space, device, dtype, "discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(self.shape) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): @@ -1084,8 +1096,15 @@ def expand(self, *shape): f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." ) + mask = self.mask + if mask is not None: + mask = mask.expand(shape) return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def squeeze(self, dim=None): @@ -1097,13 +1116,16 @@ def squeeze(self, dim=None): shape = _squeezed_shape(self.shape, dim) if shape is None: return self - + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, ) def unsqueeze(self, dim: int): @@ -1113,12 +1135,16 @@ def unsqueeze(self, dim: int): ) shape = _unsqueezed_shape(self.shape, dim) + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, ) def rand(self, shape=None) -> torch.Tensor: @@ -1126,9 +1152,19 @@ def rand(self, shape=None) -> torch.Tensor: shape = self.shape[:-1] else: shape = torch.Size([*shape, *self.shape[:-1]]) - n = self.space.n - m = torch.randint(n, (*shape, 1), device=self.device) - out = torch.zeros((*shape, n), device=self.device, dtype=self.dtype) + mask = self.mask + if mask is None: + n = self.space.n + m = torch.randint(n, (*shape, 1), device=self.device) + else: + mask = mask.expand(*shape, mask.shape[-1]) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + m = torch.multinomial(mask_flat.float(), 1).reshape(*shape_out, 1) + out = torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype) out.scatter_(-1, m, 1) return out @@ -1200,13 +1236,29 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) def _project(self, val: torch.Tensor) -> torch.Tensor: - # idx = val.sum(-1) != 1 - out = torch.nn.functional.gumbel_softmax(val.to(torch.float)) - out = (out == out.max(dim=-1, keepdim=True)[0]).to(torch.long) - return out + if self.mask is None: + out = torch.multinomial(val.to(torch.float), 1) + out = (out == out.max(dim=-1, keepdim=True)[0]).to(self.dtype) + return out + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + oob = ~gathered.any(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1) + val = val.clone() + val[oob] = 0 + val[oob] = torch.scatter(val[oob], -1, new_val, 1) + return val def is_in(self, val: torch.Tensor) -> bool: - return (val.sum(-1) == 1).all() + if self.mask is None: + return (val.sum(-1) == 1).all() + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + 
mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + return gathered.any(-1).all() def __eq__(self, other): return ( @@ -1956,34 +2008,73 @@ class DiscreteTensorSpec(TensorSpec): def __init__( self, n: int, - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype = torch.long, + mask: torch.Tensor | None = None, ): if shape is None: shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) super().__init__(shape, space, device, dtype, domain="discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(*self.shape, self.space.n) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - return torch.randint( - 0, - self.space.n, - torch.Size([*shape, *self.shape]), - device=self.device, - dtype=self.dtype, - ) + if self.mask is None: + return torch.randint( + 0, + self.space.n, + torch.Size([*shape, *self.shape]), + device=self.device, + dtype=self.dtype, + ) + mask = self.mask + mask = mask.expand(*shape, *mask.shape) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + print("mask in spec", mask_flat) + out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + print("out", out) + return out def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): val = torch.round(val) - return val.clamp_(min=0, max=self.space.n - 1) + if self.mask is None: + return val.clamp_(min=0, max=self.space.n - 1) + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + oob = ~gathered.all(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1).squeeze(-1) + val = torch.masked_scatter(val, oob, new_val) + return val def is_in(self, val: torch.Tensor) -> bool: - return (0 <= val).all() and (val < self.space.n).all() + if self.mask is None: + return (0 <= val).all() and (val < self.space.n).all() + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" @@ -2030,7 +2121,11 @@ def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: """Converts the spec to the equivalent one-hot spec.""" shape = [*self.shape, self.space.n] return OneHotDiscreteTensorSpec( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=self.mask, ) def expand(self, *shape): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 827479bf9c0..25e2f6b09d1 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -8,6 +8,7 @@ from .gym_like import default_info_dict_reader, GymLikeEnv from .model_based import ModelBasedEnvBase from .transforms import 
( + ActionMask, BinarizeReward, CatFrames, CatTensors, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 5ee87e2c0eb..fb0d738ee5d 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -6,6 +6,7 @@ from .r3m import R3MTransform from .rlhf import KLRewardTransform from .transforms import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7156c4cf571..de41b8f85e1 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -370,8 +370,49 @@ def clone(self): self_copy.__dict__.update(state) return self_copy + @property + def container(self): + """Returns the env containing the transform. + + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[0].container is env + True + """ + if "_container" not in self.__dict__: + raise AttributeError("transform parent uninitialized") + container = self.__dict__["_container"] + if container is None: + return container + while not isinstance(container, EnvBase): + # if it's not an env, it should be a Compose transform + if not isinstance(container, Compose): + raise ValueError( + "A transform parent must be either another Compose transform or an environment object." + ) + compose = container + container = compose.__dict__.get("_container", None) + return container + @property def parent(self) -> Optional[EnvBase]: + """Returns the parent env of the transform. + + The parent env is the env that contains all the transforms up until the current one. + + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[1].parent + TransformedEnv( + env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu), + transform=Compose( + RewardSum(keys=['reward']))) + + """ if self.__dict__.get("_parent", None) is None: if "_container" not in self.__dict__: raise AttributeError("transform parent uninitialized") @@ -4353,3 +4394,84 @@ def _inv_apply_transform( def set_container(self, container): if isinstance(container, EnvBase) or container.parent is not None: raise ValueError(self.ENV_ERR) + + +class ActionMask(Transform): + """An adaptive action masker. + + This transform reads the mask from the input tensordict after the step is executed, + and adapts the mask of the one-hot / categorical action spec. + + .. note:: This transform will fail when used without an environment. + + Examples: + >>> import torch + >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec + >>> from torchrl.envs.transforms import ActionMask, TransformedEnv, EnvBase + >>> class MaskedEnv(EnvBase): + ... def __init__(self, *args, **kwargs): + ... super().__init__(*args, **kwargs) + ... self.action_spec = DiscreteTensorSpec(4) + ... self.state_spec = CompositeSpec(mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)) + ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) + ... self.reward_spec = UnboundedContinuousTensorSpec(1) + ... + ... def _reset(self, data): + ... td = self.observation_spec.rand() + ... 
td.update(torch.ones_like(self.state_spec.rand())) + ... return td + ... + ... def _step(self, data): + ... td = self.observation_spec.rand() + ... mask = data.get("mask") + ... action = data.get("action") + ... mask = mask.scatter(-1, action.unsqueeze(-1), 0) + ... + ... td.set("mask", mask) + ... td.set("reward", self.reward_spec.rand()) + ... td.set("done", ~mask.any().view(1)) + ... return td.empty().set("next", td) + ... + ... def _set_seed(self, seed): + ... return seed + >>> base_env = MaskedEnv() + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> r["mask"] + + """ + + def __init__(self, action_key="action", mask_key="mask"): + if not isinstance(action_key, str): + raise ValueError( + f"The action key must be a string. Got {type(action_key)} instead." + ) + if not isinstance(mask_key, str): + raise ValueError( + f"The mask key must be a string. Got {type(mask_key)} instead." + ) + super().__init__( + in_keys=[action_key, mask_key], out_keys=[], in_keys_inv=[], out_keys_inv=[] + ) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + mask = tensordict.get(self.in_keys[1]) + parent = self.parent + if parent is None: + raise RuntimeError( + f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment." + ) + action_spec = self._container.action_spec + if not isinstance(action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)): + raise ValueError( + f"The action spec must be one of (OneHotDiscreteTensorSpec, DiscreteTensorSpec). Got {type(action_spec)} instead." + ) + action_spec.update_mask(mask) + return tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + action_spec = self._container.action_spec + action_spec.update_mask(tensordict.get(self.in_keys[1], None)) + return action_spec diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 52d52e3113e..9562dab1643 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -201,13 +201,12 @@ def sample( # Python 3.7 doesn't support math.prod # outer_dim = prod(sample_shape) # inner_dim = prod(self._mask.size()[:-1]) - outer_dim = torch.empty(sample_shape, device="meta").numel() - inner_dim = self._mask.numel() // self._mask.size(-1) + outer_dim = sample_shape.numel() + inner_dim = self._mask.shape[:-1].numel() idx_3d = self._mask.expand(outer_dim, inner_dim, -1) ret = idx_3d.gather(dim=-1, index=ret.view(outer_dim, inner_dim, 1)) - return ret.view(size) + return ret.reshape(size) - # # # TODO: Improve performance here. def log_prob(self, value: torch.Tensor) -> torch.Tensor: if not self._sparse_mask: return super().log_prob(value) From 32d7dda6bbb6e62131a1ce448e0900df3224ce08 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 18 Aug 2023 17:24:20 +0100 Subject: [PATCH 02/35] amend Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 torchrl/envs/libs/smacv2.py diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/torchrl/envs/libs/smacv2.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
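
The masking changes in the first patch give OneHotDiscreteTensorSpec and DiscreteTensorSpec an optional boolean mask over the category dimension: rand() samples only allowed categories, is_in() rejects masked-out values, project() re-draws out-of-mask entries, and update_mask() swaps the mask in place (which is what the ActionMask transform does after each step). A minimal usage sketch, assuming the patched torchrl API introduced above; the spec shapes and mask values are illustrative only:

    import torch
    from torchrl.data import DiscreteTensorSpec, OneHotDiscreteTensorSpec

    # Categorical spec over 4 actions; category 1 is masked out.
    mask = torch.tensor([True, False, True, True])
    spec = DiscreteTensorSpec(4, shape=torch.Size([]), mask=mask)
    sample = spec.rand()              # only ever draws 0, 2 or 3
    assert spec.is_in(sample)

    # The mask can be refreshed in place, e.g. from an env-provided action mask.
    spec.update_mask(torch.tensor([False, True, False, False]))
    assert spec.is_in(spec.project(sample))   # project() re-draws masked-out samples

    # One-hot variant: the mask matches the trailing (category) dimension of the spec.
    one_hot = OneHotDiscreteTensorSpec(4, shape=torch.Size([4]), mask=mask)
    assert one_hot.rand()[1] == 0     # the masked category is never sampled
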
From 58cbe438c868255cb16bc0ba78cf9c66837b70ff Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 18 Aug 2023 18:09:42 +0100 Subject: [PATCH 03/35] amend Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 274 ++++++++++++++++++++++++++++++++++++ 1 file changed, 274 insertions(+) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 7bec24cb17b..6cc90bbbae1 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -2,3 +2,277 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + +import torch + +from torchrl.data import ( + CompositeSpec, + DiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs.common import _EnvWrapper + +IMPORT_ERR = None +try: + import smacv2 + from smacv2.env.starcraft2.maps import smac_maps + + _has_smacv2 = True +except ImportError as err: + _has_smacv2 = False + IMPORT_ERR = err + + +def _get_envs(): + if not _has_smacv2: + return [] + return smac_maps.get_smac_map_registry().keys() + + +class SMACv2Wrapper(_EnvWrapper): + """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + + Examples: + >>> env = smac.env.StarCraft2Env("8m") + >>> env = SMACv2Wrapper(env) + >>> td = env.reset() + >>> td["action"] = env.action_spec.rand() + >>> td = env.step(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(torch.Size([8, 14]), dtype=torch.int64), + done: Tensor(torch.Size([1]), dtype=torch.bool), + next: TensorDict( + fields={ + observation: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(torch.Size([8, 80]), dtype=torch.float32), + reward: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(env.available_envs) + ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] + """ + + git_url = "https://github.com/oxwhirl/smacv2" + libname = "smacv2" + available_envs = _get_envs() + + def __init__( + self, + env: "smacv2.env.StarCraft2Env" = None, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + super().__init__(**kwargs) + + @property + def lib(self): + return smacv2 + + def _check_kwargs(self, kwargs: Dict): + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, smacv2.env.StarCraft2Env): + raise TypeError("env is not of type 'smacv2.env.StarCraft2Env'.") + + def _build_env( + self, + env: "smacv2.env.StarCraft2Env", + ): + if len(self.batch_size): + raise RuntimeError( + f"SMACv2 does not support custom batch_size {self.batch_size}." + ) + + return env + + def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: + # Extract specs from definition. + self.reward_spec = UnboundedContinuousTensorSpec( + shape=torch.Size((1,)), + device=self.device, + ) + self.done_spec = DiscreteTensorSpec( + n=2, + shape=torch.Size((1,)), + dtype=torch.bool, + device=self.device, + ) + + # Specs that require initialized environment are built in _init_env. + + def _init_env(self) -> None: + self._env.reset() + + # Before extracting environment specific specs, env.reset() must be executed. 
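+ # The agent/action counts and the observation/state sizes are read from the wrapped StarCraft2Env, which is why these specs are assembled here rather than in _make_specs().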
+ self.action_spec = self._make_action_spec() + self.observation_spec = self._make_observation_spec() + + def _make_action_spec(self) -> CompositeSpec: + # TODO masking + # mask = torch.tensor(env.get_avail_actions(), dtype=torch.bool, device=self.device) + action_spec = MultiOneHotDiscreteTensorSpec( + [self.n_actions], + shape=torch.Size([self.n_agents, self.n_actions]), + device=self.device, + ) + spec = CompositeSpec( + { + "agents": CompositeSpec( + {"action": action_spec}, shape=torch.Size((self.n_agents,)) + ) + } + ) + return spec + + def _make_observation_spec(self) -> CompositeSpec: + obs_spec = UnboundedContinuousTensorSpec( + torch.Size([self.n_agents, self.get_obs_size()]), device=self.device + ) + spec = CompositeSpec( + { + "agents": CompositeSpec( + {"observation": obs_spec}, shape=torch.Size((self.n_agents,)) + ), + "state": UnboundedContinuousTensorSpec( + torch.Size([self.n_agents, self.get_state_size()]), + device=self.device, + ), + } + ) + return spec + + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError( + "Seed cannot be changed once environment was created." + ) + + # def _action_transform(self, action: torch.Tensor): + # action_np = self.action_spec.to_numpy(action) + # return action_np + # + # def _read_state(self, state: np.ndarray) -> torch.Tensor: + # return self.state_spec.encode( + # torch.Tensor(state, device=self.device).expand(*self.state_spec.shape) + # ) + # + # + # def _reset( + # self, tensordict: Optional[TensorDictBase] = None, **kwargs + # ) -> TensorDictBase: + # env: smac.env.StarCraft2Env = self._env + # obs, state = env.reset() + # + # # collect outputs + # obs_dict = self.read_obs(obs) + # state = self._read_state(state) + # self._is_done = torch.zeros(self.batch_size, dtype=torch.bool) + # + # # build results + # tensordict_out = TensorDict( + # source=obs_dict, + # batch_size=self.batch_size, + # device=self.device, + # ) + # tensordict_out.set("done", self._is_done) + # tensordict_out["state"] = state + # + # self.input_spec = self._make_input_spec(env) + # + # return tensordict_out + # + # def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # env: smac.env.StarCraft2Env = self._env + # + # # perform actions + # action = tensordict.get("action") # this is a list of actions for each agent + # action_np = self._action_transform(action) + # + # # Actions are validated by the environment. + # reward, done, info = env.step(action_np) + # + # # collect outputs + # obs_dict = self.read_obs(env.get_obs()) + # # TODO: add centralized flag? + # state = self._read_state(env.get_state()) + # + # reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand( + # self.batch_size + # ) + # done = self._to_tensor(done, dtype=torch.bool).expand(self.batch_size) + # + # # build results + # tensordict_out = TensorDict( + # source=obs_dict, + # batch_size=tensordict.batch_size, + # device=self.device, + # ) + # tensordict_out.set("reward", reward) + # tensordict_out.set("done", done) + # tensordict_out["state"] = state + # + # # Update available actions mask. + # self.input_spec = self._make_input_spec(env) + # + # return tensordict_out + + +class SMACv2Env(SMACv2Wrapper): + """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + + Examples: + >>> env = SMACv2Env(map_name="8m") + >>> print(env.available_envs) + ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] 
+ """ + + def __init__( + self, + map_name: str, + capability_config: Optional[Dict] = None, + seed: Optional[int] = None, + **kwargs, + ): + if not _has_smacv2: + raise ImportError( + f"smacv2 python package was not found. Please install this dependency. " + f"More info: {self.git_url}." + ) from IMPORT_ERR + kwargs["map_name"] = map_name + kwargs["capability_config"] = capability_config + kwargs["seed"] = seed + + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + if "map_name" not in kwargs: + raise TypeError("Expected 'map_name' to be part of kwargs") + + def _build_env( + self, + map_name: str, + capability_config: Optional[Dict] = None, + seed: Optional[int] = None, + **kwargs, + ) -> "smacv2.env.StarCraft2Env": + + if capability_config is not None: + env = smacv2.env.StarCraftCapabilityEnvWrapper( + capability_config=capability_config, map_name=map_name, seed=seed + ) + else: + env = smacv2.env.StarCraft2Env( + capability_config=capability_config, map_name=map_name, seed=seed + ) + + return super()._build_env(env) From d8913ec7d35daf20998cd5d3a2f7b1b44b1b6730 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 11:07:29 +0100 Subject: [PATCH 04/35] amend Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 185 ++++++++++++++++++++---------------- torchrl/envs/utils.py | 5 +- 2 files changed, 106 insertions(+), 84 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 6cc90bbbae1..e7759638002 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -5,11 +5,12 @@ from typing import Dict, Optional import torch +from tensordict import TensorDict, TensorDictBase from torchrl.data import ( CompositeSpec, DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, + OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper @@ -67,10 +68,12 @@ class SMACv2Wrapper(_EnvWrapper): def __init__( self, env: "smacv2.env.StarCraft2Env" = None, + categorical_actions: bool = False, **kwargs, ): if env is not None: kwargs["env"] = env + self.categorical_actions = categorical_actions super().__init__(**kwargs) @@ -119,13 +122,20 @@ def _init_env(self) -> None: self.observation_spec = self._make_observation_spec() def _make_action_spec(self) -> CompositeSpec: - # TODO masking - # mask = torch.tensor(env.get_avail_actions(), dtype=torch.bool, device=self.device) - action_spec = MultiOneHotDiscreteTensorSpec( - [self.n_actions], - shape=torch.Size([self.n_agents, self.n_actions]), - device=self.device, - ) + if self.categorical_actions: + action_spec = DiscreteTensorSpec( + self.n_actions, + shape=torch.Size((self.n_agents,)), + device=self.device, + dtype=torch.long, + ) + else: + action_spec = OneHotDiscreteTensorSpec( + self.n_actions, + shape=torch.Size((self.n_agents, self.n_actions)), + device=self.device, + dtype=torch.long, + ) spec = CompositeSpec( { "agents": CompositeSpec( @@ -142,89 +152,100 @@ def _make_observation_spec(self) -> CompositeSpec: spec = CompositeSpec( { "agents": CompositeSpec( - {"observation": obs_spec}, shape=torch.Size((self.n_agents,)) + {"observation": obs_spec}, + shape=torch.Size((self.n_agents,), dtype=torch.float32), ), "state": UnboundedContinuousTensorSpec( torch.Size([self.n_agents, self.get_state_size()]), device=self.device, + dtype=torch.float32, ), } ) return spec def _set_seed(self, seed: Optional[int]): - raise NotImplementedError( - "Seed cannot be changed once environment was created." 
+ if seed is not None: + raise NotImplementedError( + "Seed cannot be changed once environment was created." + ) + + def get_obs(self): + obs = self.get_obs() + return self._to_tensor(obs) + + def get_state(self): + state = self.get_state() + return self._to_tensor(state) + + def _to_tensor(self, value): + return torch.tensor(value, device=self.device, dtype=torch.float32) + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + obs, state = self._env.reset() + + # collect outputs + obs = self._to_tensor(obs) + state = self._to_tensor(state) + + mask = self.get_action_mask() + + # build results + agents_td = TensorDict({"observation": obs}, batch_size=(self.n_agents,)) + tensordict_out = TensorDict( + source={"agents": agents_td, "state": state, "mask": mask}, + batch_size=(), + device=self.device, ) - # def _action_transform(self, action: torch.Tensor): - # action_np = self.action_spec.to_numpy(action) - # return action_np - # - # def _read_state(self, state: np.ndarray) -> torch.Tensor: - # return self.state_spec.encode( - # torch.Tensor(state, device=self.device).expand(*self.state_spec.shape) - # ) - # - # - # def _reset( - # self, tensordict: Optional[TensorDictBase] = None, **kwargs - # ) -> TensorDictBase: - # env: smac.env.StarCraft2Env = self._env - # obs, state = env.reset() - # - # # collect outputs - # obs_dict = self.read_obs(obs) - # state = self._read_state(state) - # self._is_done = torch.zeros(self.batch_size, dtype=torch.bool) - # - # # build results - # tensordict_out = TensorDict( - # source=obs_dict, - # batch_size=self.batch_size, - # device=self.device, - # ) - # tensordict_out.set("done", self._is_done) - # tensordict_out["state"] = state - # - # self.input_spec = self._make_input_spec(env) - # - # return tensordict_out - # - # def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - # env: smac.env.StarCraft2Env = self._env - # - # # perform actions - # action = tensordict.get("action") # this is a list of actions for each agent - # action_np = self._action_transform(action) - # - # # Actions are validated by the environment. - # reward, done, info = env.step(action_np) - # - # # collect outputs - # obs_dict = self.read_obs(env.get_obs()) - # # TODO: add centralized flag? - # state = self._read_state(env.get_state()) - # - # reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand( - # self.batch_size - # ) - # done = self._to_tensor(done, dtype=torch.bool).expand(self.batch_size) - # - # # build results - # tensordict_out = TensorDict( - # source=obs_dict, - # batch_size=tensordict.batch_size, - # device=self.device, - # ) - # tensordict_out.set("reward", reward) - # tensordict_out.set("done", done) - # tensordict_out["state"] = state - # - # # Update available actions mask. - # self.input_spec = self._make_input_spec(env) - # - # return tensordict_out + self.get_action_mask() + + return tensordict_out + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + # perform actions + action = tensordict.get(("agents", "action")) + action_np = self.action_spec.to_numpy(action) + + # Actions are validated by the environment. 
+ reward, done, info = self._env.step(action_np) + + # collect outputs + obs = self.get_obs() + state = self.get_state() + + reward = torch.tensor(reward, device=self.device, dtype=torch.float32) + done = torch.tensor(done, device=self.device, dtype=torch.bool) + + mask = self.get_action_mask() + + # build results + agents_td = TensorDict( + {"observation": obs, "mask": mask}, batch_size=(self.n_agents,) + ) + + tensordict_out = TensorDict( + source={ + "next": { + "agents": agents_td, + "state": state, + "reward": reward, + "done": done, + } + }, + batch_size=(), + device=self.device, + ) + + return tensordict_out + + def get_action_mask(self): + return torch.tensor( + self.get_avail_actions(), dtype=torch.bool, device=self.device + ) class SMACv2Env(SMACv2Wrapper): @@ -241,6 +262,7 @@ def __init__( map_name: str, capability_config: Optional[Dict] = None, seed: Optional[int] = None, + categorical_actions: bool = False, **kwargs, ): if not _has_smacv2: @@ -251,6 +273,7 @@ def __init__( kwargs["map_name"] = map_name kwargs["capability_config"] = capability_config kwargs["seed"] = seed + kwargs["categorical_actions"] = categorical_actions super().__init__(**kwargs) @@ -271,8 +294,6 @@ def _build_env( capability_config=capability_config, map_name=map_name, seed=seed ) else: - env = smacv2.env.StarCraft2Env( - capability_config=capability_config, map_name=map_name, seed=seed - ) + env = smacv2.env.StarCraft2Env(map_name=map_name, seed=seed) return super()._build_env(env) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 1f671dda1b2..8808e8190e0 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -406,8 +406,9 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): of an experiment and as such should be kept out of training scripts. 
""" - torch.manual_seed(seed) - env.set_seed(seed) + if seed is not None: + torch.manual_seed(seed) + env.set_seed(seed) fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) From 266a84d4143b26b5f0704879effe48c6ea2c4b1d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 11:35:07 +0100 Subject: [PATCH 05/35] amend Signed-off-by: Matteo Bettini --- docs/source/reference/envs.rst | 1 + test/test_libs.py | 25 ++++ torchrl/data/tensor_specs.py | 141 ++++++++++++++++++---- torchrl/envs/__init__.py | 1 + torchrl/envs/libs/smacv2.py | 38 +++--- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 122 +++++++++++++++++++ torchrl/modules/distributions/discrete.py | 7 +- 8 files changed, 295 insertions(+), 41 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 1150eea3e99..c069670e2f3 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -445,6 +445,7 @@ to be able to create this other composition: Transform TransformedEnv + ActionMask BinarizeReward CatFrames CatTensors diff --git a/test/test_libs.py b/test/test_libs.py index dc03c092c68..93ab6dad0b2 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -73,6 +73,7 @@ _has_sklearn = importlib.util.find_spec("sklearn") is not None +from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env if _has_gym: try: @@ -1615,6 +1616,30 @@ def test_env(self, task, num_envs, device): # break +@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") +class TestSmacv2: + def test(self): + distribution_config = { + "n_units": 5, + "n_enemies": 10, + "team_gen": { + "dist_type": "weighted_teams", + "unit_types": ["marine", "marauder", "medivac"], + "exception_unit_types": ["medivac"], + "weights": [0.45, 0.55, 0.0], + "observe": True, + }, + } + env = SMACv2Env( + map_name="10gen_terran", + capability_config=distribution_config, + seed=2, + render=False, + ) + check_env_specs(env, seed=None) + # env.reset() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b11a2e2ebef..88567a7c607 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1119,6 +1119,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, use_register: bool = False, + mask: torch.Tensor | None = None, ): dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register @@ -1133,6 +1134,17 @@ def __init__( f"Got n={space.n} and shape={shape}." ) super().__init__(shape, space, device, dtype, "discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = mask.expand(self.shape) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): @@ -1172,8 +1184,15 @@ def expand(self, *shape): f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." 
) + mask = self.mask + if mask is not None: + mask = mask.expand(shape) return self.__class__( - n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + mask=mask, ) def squeeze(self, dim=None): @@ -1185,13 +1204,16 @@ def squeeze(self, dim=None): shape = _squeezed_shape(self.shape, dim) if shape is None: return self - + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, ) def unsqueeze(self, dim: int): @@ -1201,12 +1223,16 @@ def unsqueeze(self, dim: int): ) shape = _unsqueezed_shape(self.shape, dim) + mask = self.mask + if mask is not None: + mask = mask.reshape(shape) return self.__class__( n=shape[-1], shape=shape, device=self.device, dtype=self.dtype, use_register=self.use_register, + mask=mask, ) def rand(self, shape=None) -> torch.Tensor: @@ -1214,9 +1240,19 @@ def rand(self, shape=None) -> torch.Tensor: shape = self.shape[:-1] else: shape = torch.Size([*shape, *self.shape[:-1]]) - n = self.space.n - m = torch.randint(n, (*shape, 1), device=self.device) - out = torch.zeros((*shape, n), device=self.device, dtype=self.dtype) + mask = self.mask + if mask is None: + n = self.space.n + m = torch.randint(n, (*shape, 1), device=self.device) + else: + mask = mask.expand(*shape, mask.shape[-1]) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + m = torch.multinomial(mask_flat.float(), 1).reshape(*shape_out, 1) + out = torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype) out.scatter_(-1, m, 1) return out @@ -1288,13 +1324,29 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) def _project(self, val: torch.Tensor) -> torch.Tensor: - # idx = val.sum(-1) != 1 - out = torch.nn.functional.gumbel_softmax(val.to(torch.float)) - out = (out == out.max(dim=-1, keepdim=True)[0]).to(torch.long) - return out + if self.mask is None: + out = torch.multinomial(val.to(torch.float), 1) + out = (out == out.max(dim=-1, keepdim=True)[0]).to(self.dtype) + return out + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + oob = ~gathered.any(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1) + val = val.clone() + val[oob] = 0 + val[oob] = torch.scatter(val[oob], -1, new_val, 1) + return val def is_in(self, val: torch.Tensor) -> bool: - return (val.sum(-1) == 1).all() + if self.mask is None: + return (val.sum(-1) == 1).all() + shape = self.mask.shape + shape = torch.broadcast_shapes(shape, val.shape) + mask_expand = self.mask.expand(shape) + gathered = mask_expand & val + return gathered.any(-1).all() def __eq__(self, other): return ( @@ -2044,34 +2096,73 @@ class DiscreteTensorSpec(TensorSpec): def __init__( self, n: int, - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype = torch.long, + mask: torch.Tensor | None = None, ): if shape is None: shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) space = DiscreteBox(n) super().__init__(shape, space, device, dtype, domain="discrete") + self.update_mask(mask) + + def update_mask(self, mask): + if mask is not None: + try: + mask = 
mask.expand(*self.shape, self.space.n) + except RuntimeError as err: + raise RuntimeError("Cannot expand mask to the desired shape.") from err + if mask.dtype != torch.bool: + raise ValueError("Only boolean masks are accepted.") + self.mask = mask def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - return torch.randint( - 0, - self.space.n, - torch.Size([*shape, *self.shape]), - device=self.device, - dtype=self.dtype, - ) + if self.mask is None: + return torch.randint( + 0, + self.space.n, + torch.Size([*shape, *self.shape]), + device=self.device, + dtype=self.dtype, + ) + mask = self.mask + mask = mask.expand(*shape, *mask.shape) + if mask.ndim > 2: + mask_flat = torch.flatten(mask, 0, -2) + else: + mask_flat = mask + shape_out = mask.shape[:-1] + print("mask in spec", mask_flat) + out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + print("out", out) + return out def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): val = torch.round(val) - return val.clamp_(min=0, max=self.space.n - 1) + if self.mask is None: + return val.clamp_(min=0, max=self.space.n - 1) + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + oob = ~gathered.all(-1) + new_val = torch.multinomial(mask_expand[oob].float(), 1).squeeze(-1) + val = torch.masked_scatter(val, oob, new_val) + return val def is_in(self, val: torch.Tensor) -> bool: - return (0 <= val).all() and (val < self.space.n).all() + if self.mask is None: + return (0 <= val).all() and (val < self.space.n).all() + shape = self.mask.shape + shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + mask_expand = self.mask.expand(shape) + gathered = mask_expand.gather(-1, val.unsqueeze(-1)) + return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" @@ -2118,7 +2209,11 @@ def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: """Converts the spec to the equivalent one-hot spec.""" shape = [*self.shape, self.space.n] return OneHotDiscreteTensorSpec( - n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + n=self.space.n, + shape=shape, + device=self.device, + dtype=self.dtype, + mask=self.mask, ) def expand(self, *shape): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 9422e60a5d5..551c09d8415 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -8,6 +8,7 @@ from .gym_like import default_info_dict_reader, GymLikeEnv from .model_based import ModelBasedEnvBase from .transforms import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index e7759638002..deaf6ad4c55 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -147,16 +147,24 @@ def _make_action_spec(self) -> CompositeSpec: def _make_observation_spec(self) -> CompositeSpec: obs_spec = UnboundedContinuousTensorSpec( - torch.Size([self.n_agents, self.get_obs_size()]), device=self.device + torch.Size([self.n_agents, self.get_obs_size()]), + device=self.device, + dtype=torch.float32, + ) + mask_spec = DiscreteTensorSpec( + 2, + torch.Size([self.n_agents, self.n_actions]), + device=self.device, + dtype=torch.bool, ) spec = CompositeSpec( { "agents": CompositeSpec( - {"observation": obs_spec}, - 
shape=torch.Size((self.n_agents,), dtype=torch.float32), + {"observation": obs_spec, "mask": mask_spec}, + shape=torch.Size((self.n_agents,)), ), "state": UnboundedContinuousTensorSpec( - torch.Size([self.n_agents, self.get_state_size()]), + torch.Size((self.get_state_size(),)), device=self.device, dtype=torch.float32, ), @@ -171,11 +179,11 @@ def _set_seed(self, seed: Optional[int]): ) def get_obs(self): - obs = self.get_obs() + obs = self._env.get_obs() return self._to_tensor(obs) def get_state(self): - state = self.get_state() + state = self._env.get_state() return self._to_tensor(state) def _to_tensor(self, value): @@ -191,18 +199,18 @@ def _reset( obs = self._to_tensor(obs) state = self._to_tensor(state) - mask = self.get_action_mask() + mask = self.update_action_mask() # build results - agents_td = TensorDict({"observation": obs}, batch_size=(self.n_agents,)) + agents_td = TensorDict( + {"observation": obs, "mask": mask}, batch_size=(self.n_agents,) + ) tensordict_out = TensorDict( - source={"agents": agents_td, "state": state, "mask": mask}, + source={"agents": agents_td, "state": state}, batch_size=(), device=self.device, ) - self.get_action_mask() - return tensordict_out def _step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -220,7 +228,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = torch.tensor(reward, device=self.device, dtype=torch.float32) done = torch.tensor(done, device=self.device, dtype=torch.bool) - mask = self.get_action_mask() + mask = self.update_action_mask() # build results agents_td = TensorDict( @@ -242,10 +250,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict_out - def get_action_mask(self): - return torch.tensor( + def update_action_mask(self): + mask = torch.tensor( self.get_avail_actions(), dtype=torch.bool, device=self.device ) + self.action_spec.update_mask(mask) + return mask class SMACv2Env(SMACv2Wrapper): diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index f14df66f9ed..3e25adac0bc 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -6,6 +6,7 @@ from .r3m import R3MTransform from .rlhf import KLRewardTransform from .transforms import ( + ActionMask, BinarizeReward, CatFrames, CatTensors, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 268a2e258a7..0d29e693145 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -368,8 +368,49 @@ def clone(self): self_copy.__dict__.update(state) return self_copy + @property + def container(self): + """Returns the env containing the transform. + + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[0].container is env + True + """ + if "_container" not in self.__dict__: + raise AttributeError("transform parent uninitialized") + container = self.__dict__["_container"] + if container is None: + return container + while not isinstance(container, EnvBase): + # if it's not an env, it should be a Compose transform + if not isinstance(container, Compose): + raise ValueError( + "A transform parent must be either another Compose transform or an environment object." 
+ ) + compose = container + container = compose.__dict__.get("_container", None) + return container + @property def parent(self) -> Optional[EnvBase]: + """Returns the parent env of the transform. + + The parent env is the env that contains all the transforms up until the current one. + + Examples: + >>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter + >>> from torchrl.envs.libs.gym import GymEnv + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter())) + >>> env.transform[1].parent + TransformedEnv( + env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu), + transform=Compose( + RewardSum(keys=['reward']))) + + """ if self.__dict__.get("_parent", None) is None: if "_container" not in self.__dict__: raise AttributeError("transform parent uninitialized") @@ -4619,3 +4660,84 @@ def _inv_apply_transform( def set_container(self, container): if isinstance(container, EnvBase) or container.parent is not None: raise ValueError(self.ENV_ERR) + + +class ActionMask(Transform): + """An adaptive action masker. + + This transform reads the mask from the input tensordict after the step is executed, + and adapts the mask of the one-hot / categorical action spec. + + .. note:: This transform will fail when used without an environment. + + Examples: + >>> import torch + >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec + >>> from torchrl.envs.transforms import ActionMask, TransformedEnv, EnvBase + >>> class MaskedEnv(EnvBase): + ... def __init__(self, *args, **kwargs): + ... super().__init__(*args, **kwargs) + ... self.action_spec = DiscreteTensorSpec(4) + ... self.state_spec = CompositeSpec(mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)) + ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) + ... self.reward_spec = UnboundedContinuousTensorSpec(1) + ... + ... def _reset(self, data): + ... td = self.observation_spec.rand() + ... td.update(torch.ones_like(self.state_spec.rand())) + ... return td + ... + ... def _step(self, data): + ... td = self.observation_spec.rand() + ... mask = data.get("mask") + ... action = data.get("action") + ... mask = mask.scatter(-1, action.unsqueeze(-1), 0) + ... + ... td.set("mask", mask) + ... td.set("reward", self.reward_spec.rand()) + ... td.set("done", ~mask.any().view(1)) + ... return td.empty().set("next", td) + ... + ... def _set_seed(self, seed): + ... return seed + >>> base_env = MaskedEnv() + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> env = TransformedEnv(base_env, ActionMask()) + >>> r = env.rollout(10) + >>> r["mask"] + + """ + + def __init__(self, action_key="action", mask_key="mask"): + if not isinstance(action_key, str): + raise ValueError( + f"The action key must be a string. Got {type(action_key)} instead." + ) + if not isinstance(mask_key, str): + raise ValueError( + f"The mask key must be a string. Got {type(mask_key)} instead." + ) + super().__init__( + in_keys=[action_key, mask_key], out_keys=[], in_keys_inv=[], out_keys_inv=[] + ) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + mask = tensordict.get(self.in_keys[1]) + parent = self.parent + if parent is None: + raise RuntimeError( + f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment." 
+ ) + action_spec = self._container.action_spec + if not isinstance(action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)): + raise ValueError( + f"The action spec must be one of (OneHotDiscreteTensorSpec, DiscreteTensorSpec). Got {type(action_spec)} instead." + ) + action_spec.update_mask(mask) + return tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + action_spec = self._container.action_spec + action_spec.update_mask(tensordict.get(self.in_keys[1], None)) + return action_spec diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 52d52e3113e..9562dab1643 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -201,13 +201,12 @@ def sample( # Python 3.7 doesn't support math.prod # outer_dim = prod(sample_shape) # inner_dim = prod(self._mask.size()[:-1]) - outer_dim = torch.empty(sample_shape, device="meta").numel() - inner_dim = self._mask.numel() // self._mask.size(-1) + outer_dim = sample_shape.numel() + inner_dim = self._mask.shape[:-1].numel() idx_3d = self._mask.expand(outer_dim, inner_dim, -1) ret = idx_3d.gather(dim=-1, index=ret.view(outer_dim, inner_dim, 1)) - return ret.view(size) + return ret.reshape(size) - # # # TODO: Improve performance here. def log_prob(self, value: torch.Tensor) -> torch.Tensor: if not self._sparse_mask: return super().log_prob(value) From 7e6b9c90dd478bc91be9de9da64fb44372237329 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 11:36:52 +0100 Subject: [PATCH 06/35] amend Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index deaf6ad4c55..e67f4725f62 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -121,6 +121,8 @@ def _init_env(self) -> None: self.action_spec = self._make_action_spec() self.observation_spec = self._make_observation_spec() + self.update_action_mask() + def _make_action_spec(self) -> CompositeSpec: if self.categorical_actions: action_spec = DiscreteTensorSpec( From 646e2339e582c3b4c8a65fd1da95782c02a23e4e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 11:58:13 +0100 Subject: [PATCH 07/35] add info Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 41 ++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index e67f4725f62..f1839ad77b8 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -8,6 +8,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl.data import ( + BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -153,6 +154,30 @@ def _make_observation_spec(self) -> CompositeSpec: device=self.device, dtype=torch.float32, ) + info_spec = CompositeSpec( + { + "battle_won": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device + ), + "episode_limit": DiscreteTensorSpec( + 2, dtype=torch.bool, device=self.device + ), + "dead_allies": BoundedTensorSpec( + minimum=0, + maximum=self.n_agents, + dtype=torch.long, + device=self.device, + shape=(), + ), + "dead_enemies": BoundedTensorSpec( + minimum=0, + maximum=self.n_enemies, + dtype=torch.long, + device=self.device, + shape=(), + ), + } + ) mask_spec = DiscreteTensorSpec( 2, torch.Size([self.n_agents, self.n_actions]), @@ -170,6 +195,7 @@ def _make_observation_spec(self) -> 
CompositeSpec: device=self.device, dtype=torch.float32, ), + "info": info_spec, } ) return spec @@ -200,6 +226,7 @@ def _reset( # collect outputs obs = self._to_tensor(obs) state = self._to_tensor(state) + info = self.observation_spec["info"].zero() mask = self.update_action_mask() @@ -208,7 +235,7 @@ def _reset( {"observation": obs, "mask": mask}, batch_size=(self.n_agents,) ) tensordict_out = TensorDict( - source={"agents": agents_td, "state": state}, + source={"agents": agents_td, "state": state, "info": info}, batch_size=(), device=self.device, ) @@ -226,9 +253,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # collect outputs obs = self.get_obs() state = self.get_state() + info = self.observation_spec["info"].encode(info) + if "episode_limit" not in info.keys(): + info["episode_limit"] = self.observation_spec["info"][ + "episode_limit" + ].zero() - reward = torch.tensor(reward, device=self.device, dtype=torch.float32) - done = torch.tensor(done, device=self.device, dtype=torch.bool) + reward = torch.tensor( + reward, device=self.device, dtype=torch.float32 + ).unsqueeze(-1) + done = torch.tensor(done, device=self.device, dtype=torch.bool).unsqueeze(-1) mask = self.update_action_mask() @@ -242,6 +276,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: "next": { "agents": agents_td, "state": state, + "info": info, "reward": reward, "done": done, } From e9f52575d64d6bc1ad63925b62c96c967cfbac92 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 14:02:56 +0100 Subject: [PATCH 08/35] add ci Signed-off-by: Matteo Bettini --- .../linux_libs/scripts_smacv2/environment.yml | 21 ++ .../linux_libs/scripts_smacv2/install.sh | 46 +++ .../linux_libs/scripts_smacv2/post_process.sh | 6 + .../scripts_smacv2/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_smacv2/run_test.sh | 30 ++ .../linux_libs/scripts_smacv2/setup_env.sh | 63 ++++ test/test_libs.py | 1 - 7 files changed, 522 insertions(+), 1 deletion(-) create mode 100644 .circleci/unittest/linux_libs/scripts_smacv2/environment.yml create mode 100755 .circleci/unittest/linux_libs/scripts_smacv2/install.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smacv2/post_process.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smacv2/run-clang-format.py create mode 100755 .circleci/unittest/linux_libs/scripts_smacv2/run_test.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/environment.yml b/.circleci/unittest/linux_libs/scripts_smacv2/environment.yml new file mode 100644 index 00000000000..d1e1e1f5edc --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/environment.yml @@ -0,0 +1,21 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - cloudpickle + - gym + - gym-notices + - importlib-metadata + - zipp + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - numpy==1.23.0 + - git+https://github.com/oxwhirl/smacv2.git diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/install.sh b/.circleci/unittest/linux_libs/scripts_smacv2/install.sh new file mode 100755 index 00000000000..cb36c7cc48a --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. 
+# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git + +# smoke test +python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/post_process.sh b/.circleci/unittest/linux_libs/scripts_smacv2/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/run-clang-format.py b/.circleci/unittest/linux_libs/scripts_smacv2/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh new file mode 100755 index 00000000000..6356e7cb4ed --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import smacv2" + +python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestSmacv2 --error-for-skips +coverage combine +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh new file mode 100755 index 00000000000..e13af74d8df --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
+ +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +# 5. Install StarCraft 2 with SMACv2 maps +# SC2PATH is set in run_test.sh +printf "* Installing StarCraft 2 and SMACv2 maps into '${root_dir}/smacv2/StarCraftII'\n" +mkdir "${root_dir}/smacv2" +cd "${root_dir}/smacv2" +# TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. +# e.g adding this into the image learn( which one is used and how it is maintained) +wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip +# The archive contains StarCraftII folder. Password comes from the documentation. +unzip -qo -P iagreetotheeula SC2.4.10.zip +rm -rf SC2.4.10.zip +# Install Maps +wget https://github.com/oxwhirl/smacv2/releases/download/maps/SMAC_Maps.zip +unzip -qo SMAC_Maps.zip -d ./StarCraftII/Maps +printf "StarCraft II and SMAC are installed." diff --git a/test/test_libs.py b/test/test_libs.py index 93ab6dad0b2..8150843b2df 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1637,7 +1637,6 @@ def test(self): render=False, ) check_env_specs(env, seed=None) - # env.reset() if __name__ == "__main__": From 60a156397211231517f87167bf34fbcb716e3b2c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 14:08:15 +0100 Subject: [PATCH 09/35] add ci Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 40 +++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/test-linux-smacv2.yml diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml new file mode 100644 index 00000000000..35dd4ff2409 --- /dev/null +++ b/.github/workflows/test-linux-smacv2.yml @@ -0,0 +1,40 @@ +name: SMACv2 Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="11.7" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + nvidia-smi + + bash .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh + bash .circleci/unittest/linux_libs/scripts_smacv2/install.sh + bash .circleci/unittest/linux_libs/scripts_smacv2/run_test.sh + bash .circleci/unittest/linux_libs/scripts_smacv2/post_process.sh From 79ce182d72f8c45dbe817053243144bd1333c7c5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 15:12:01 +0100 Subject: [PATCH 10/35] amend Signed-off-by: Matteo Bettini --- test/test_libs.py | 1 + torchrl/envs/libs/smacv2.py | 72 ++++++++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8150843b2df..5b841e02db0 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1637,6 +1637,7 @@ def test(self): render=False, ) check_env_specs(env, seed=None) + env.close() if __name__ == "__main__": diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index f1839ad77b8..350a12ac081 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -49,11 +49,11 @@ class SMACv2Wrapper(_EnvWrapper): done: Tensor(torch.Size([1]), dtype=torch.bool), next: TensorDict( fields={ - observation: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, + obs: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, batch_size=torch.Size([]), device=cpu, is_shared=False), - observation: Tensor(torch.Size([8, 80]), dtype=torch.float32), + obs: Tensor(torch.Size([8, 80]), dtype=torch.float32), reward: Tensor(torch.Size([1]), dtype=torch.float32)}, batch_size=torch.Size([]), device=cpu, @@ -101,7 +101,6 @@ def _build_env( return env def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: - # Extract specs from definition. self.reward_spec = UnboundedContinuousTensorSpec( shape=torch.Size((1,)), device=self.device, @@ -112,16 +111,11 @@ def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: dtype=torch.bool, device=self.device, ) - - # Specs that require initialized environment are built in _init_env. - - def _init_env(self) -> None: - self._env.reset() - - # Before extracting environment specific specs, env.reset() must be executed. 
self.action_spec = self._make_action_spec() self.observation_spec = self._make_observation_spec() + def _init_env(self) -> None: + self._env.reset() self.update_action_mask() def _make_action_spec(self) -> CompositeSpec: @@ -149,8 +143,10 @@ def _make_action_spec(self) -> CompositeSpec: return spec def _make_observation_spec(self) -> CompositeSpec: - obs_spec = UnboundedContinuousTensorSpec( - torch.Size([self.n_agents, self.get_obs_size()]), + obs_spec = BoundedTensorSpec( + minimum=-1.0, + maximum=1.0, + shape=torch.Size([self.n_agents, self.get_obs_size()]), device=self.device, dtype=torch.float32, ) @@ -187,11 +183,13 @@ def _make_observation_spec(self) -> CompositeSpec: spec = CompositeSpec( { "agents": CompositeSpec( - {"observation": obs_spec, "mask": mask_spec}, + {"obs": obs_spec, "action_mask": mask_spec}, shape=torch.Size((self.n_agents,)), ), - "state": UnboundedContinuousTensorSpec( - torch.Size((self.get_state_size(),)), + "state": BoundedTensorSpec( + minimum=-1.0, + maximum=1.0, + shape=torch.Size((self.get_state_size(),)), device=self.device, dtype=torch.float32, ), @@ -232,7 +230,7 @@ def _reset( # build results agents_td = TensorDict( - {"observation": obs, "mask": mask}, batch_size=(self.n_agents,) + {"obs": obs, "action_mask": mask}, batch_size=(self.n_agents,) ) tensordict_out = TensorDict( source={"agents": agents_td, "state": state, "info": info}, @@ -268,7 +266,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # build results agents_td = TensorDict( - {"observation": obs, "mask": mask}, batch_size=(self.n_agents,) + {"obs": obs, "action_mask": mask}, batch_size=(self.n_agents,) ) tensordict_out = TensorDict( @@ -294,6 +292,46 @@ def update_action_mask(self): self.action_spec.update_mask(mask) return mask + def close(self): + # Closes StarCraft II + self._env.close() + + def get_agent_type(self, agent_index: int) -> str: + """Get the agent type string. + + Given the agent index, get its unit type name. + + Args: + agent_index (int): the index of the agent to get the type of + + """ + if agent_index < 0 or agent_index >= self.n_agents: + raise ValueError(f"Agent index out of range, {self.n_agents} available") + + agent_info = self.agents[agent_index] + if agent_info.unit_type == self.marine_id: + agent_type = "marine" + elif agent_info.unit_type == self.marauder_id: + agent_type = "marauder" + elif agent_info.unit_type == self.medivac_id: + agent_type = "medivac" + elif agent_info.unit_type == self.hydralisk_id: + agent_type = "hydralisk" + elif agent_info.unit_type == self.zergling_id: + agent_type = "zergling" + elif agent_info.unit_type == self.baneling_id: + agent_type = "baneling" + elif agent_info.unit_type == self.stalker_id: + agent_type = "stalker" + elif agent_info.unit_type == self.colossus_id: + agent_type = "colossus" + elif agent_info.unit_type == self.zealot_id: + agent_type = "zealot" + else: + raise AssertionError(f"Agent type {agent_info.unit_type} unidentified") + + return agent_type + class SMACv2Env(SMACv2Wrapper): """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. 
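The `update_action_mask` helper a few hunks above is what keeps the wrapper's action spec in sync with SMACv2's availability information, through the same `update_mask` call used by the `ActionMask` transform earlier in the series. As a rough, self-contained sketch of what that masked-spec API then gives downstream code (the agent and action counts below are invented for illustration, and the snippet assumes the `mask=` constructor argument and the `update_mask`/`project` behaviour introduced by this patch series):

    import torch
    from torchrl.data import DiscreteTensorSpec

    # Hypothetical availability mask: 3 agents, 4 discrete actions each.
    mask = torch.tensor(
        [
            [True, False, True, False],
            [True, True, False, False],
            [False, True, True, True],
        ]
    )
    # The mask travels with the spec, so anything sampling from the spec
    # (random policies, check_env_specs, ...) only sees available actions.
    spec = DiscreteTensorSpec(4, shape=torch.Size([3]), dtype=torch.long, mask=mask)

    sample = spec.rand()     # samples only among masked-in actions
    assert spec.is_in(sample)

    spec.update_mask(~mask)  # availability changed after an env step
    projected = spec.project(sample)  # re-maps the sample onto the new mask

The same pattern is what makes a plain random policy usable with the masked SMACv2 wrapper: as long as the mask held by the spec is refreshed after every reset and step, `action_spec.rand()` never proposes an unavailable action.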
From a88e13b4d2dc0c60374a8d72b6d386528138cc2e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 15:14:43 +0100 Subject: [PATCH 11/35] amend Signed-off-by: Matteo Bettini --- .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh index e13af74d8df..e5fafeb6430 100755 --- a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -49,8 +49,7 @@ conda env update --file "${this_dir}/environment.yml" --prune # 5. Install StarCraft 2 with SMACv2 maps # SC2PATH is set in run_test.sh printf "* Installing StarCraft 2 and SMACv2 maps into '${root_dir}/smacv2/StarCraftII'\n" -mkdir "${root_dir}/smacv2" -cd "${root_dir}/smacv2" +cd "${root_dir}" # TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. # e.g adding this into the image learn( which one is used and how it is maintained) wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip From 28cd2b55877da967819437c0a3d6ce344d861c9a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 15:16:16 +0100 Subject: [PATCH 12/35] amend Signed-off-by: Matteo Bettini --- .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh index e5fafeb6430..60943c7e2ee 100755 --- a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -47,8 +47,7 @@ pip install pip --upgrade conda env update --file "${this_dir}/environment.yml" --prune # 5. Install StarCraft 2 with SMACv2 maps -# SC2PATH is set in run_test.sh -printf "* Installing StarCraft 2 and SMACv2 maps into '${root_dir}/smacv2/StarCraftII'\n" +printf "* Installing StarCraft 2 and SMACv2 maps into '${root_dir}/StarCraftII'\n" cd "${root_dir}" # TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. 
# e.g adding this into the image learn( which one is used and how it is maintained) From d0cf059bf441ebb329dd612554c355e1d6531180 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 15:38:57 +0100 Subject: [PATCH 13/35] amend Signed-off-by: Matteo Bettini --- .../linux_libs/scripts_smacv2/run_test.sh | 1 + test/test_libs.py | 36 ++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh index 6356e7cb4ed..9e01d97bc1f 100755 --- a/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -15,6 +15,7 @@ git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" env_dir="${root_dir}/env" lib_dir="${env_dir}/lib" +export SC2PATH="${root_dir}/StarCraftII" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir diff --git a/test/test_libs.py b/test/test_libs.py index 5b841e02db0..f210ac84787 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1618,23 +1618,49 @@ def test_env(self, task, num_envs, device): @pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") class TestSmacv2: - def test(self): + def test_env_procedural(self): distribution_config = { "n_units": 5, - "n_enemies": 10, + "n_enemies": 6, "team_gen": { "dist_type": "weighted_teams", "unit_types": ["marine", "marauder", "medivac"], "exception_unit_types": ["medivac"], - "weights": [0.45, 0.55, 0.0], + "weights": [0.5, 0.2, 0.3], "observe": True, }, + "start_positions": { + "dist_type": "surrounded_and_reflect", + "p": 0.5, + "n_enemies": 5, + "map_x": 32, + "map_y": 32, + }, } env = SMACv2Env( map_name="10gen_terran", capability_config=distribution_config, - seed=2, - render=False, + seed=0, + ) + check_env_specs(env, seed=None) + env.close() + + @pytest.mark.parametrize("map", ["MMM2", "3s_vs_5z"]) + def test_env(self, map: str): + env = SMACv2Env( + map_name=map, + seed=0, + ) + check_env_specs(env, seed=None) + env.close() + + def test_vec_env(self): + env = ParallelEnv( + num_workers=2, + create_env_fn=lambda: SMACv2Env( + map_name="3s_vs_5z", + seed=0, + ), ) check_env_specs(env, seed=None) env.close() From 9014bc28cc5882d506d109e046aff6e44ee5d342 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 16:07:44 +0100 Subject: [PATCH 14/35] amend Signed-off-by: Matteo Bettini --- .circleci/unittest/linux_libs/scripts_smacv2/run_test.sh | 1 + .../unittest/linux_libs/scripts_smacv2/setup_env.sh | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh index 9e01d97bc1f..cb7425cf640 100755 --- a/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -16,6 +16,7 @@ root_dir="$(git rev-parse --show-toplevel)" env_dir="${root_dir}/env" lib_dir="${env_dir}/lib" export SC2PATH="${root_dir}/StarCraftII" +echo 'SC2PATH is set to ' "$SC2PATH" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh index 60943c7e2ee..06d560d793d 100755 --- 
a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -47,15 +47,18 @@ pip install pip --upgrade conda env update --file "${this_dir}/environment.yml" --prune # 5. Install StarCraft 2 with SMACv2 maps -printf "* Installing StarCraft 2 and SMACv2 maps into '${root_dir}/StarCraftII'\n" +starcraft_path="${root_dir}/StarCraftII" +map_dir="${starcraft_path}/Maps" +printf "* Installing StarCraft 2 and SMACv2 maps into ${starcraft_path}\n" cd "${root_dir}" # TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. # e.g adding this into the image learn( which one is used and how it is maintained) wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip # The archive contains StarCraftII folder. Password comes from the documentation. unzip -qo -P iagreetotheeula SC2.4.10.zip -rm -rf SC2.4.10.zip +mkdir -p "${map_dir}" # Install Maps wget https://github.com/oxwhirl/smacv2/releases/download/maps/SMAC_Maps.zip -unzip -qo SMAC_Maps.zip -d ./StarCraftII/Maps + +unzip -qo SMAC_Maps.zip -d ."${map_dir}" printf "StarCraft II and SMAC are installed." From 662dbb7f6058bc8b54597ab2c89f424db896ca10 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 16:11:28 +0100 Subject: [PATCH 15/35] amend Signed-off-by: Matteo Bettini --- .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh index 06d560d793d..04080cc8932 100755 --- a/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh @@ -59,6 +59,7 @@ unzip -qo -P iagreetotheeula SC2.4.10.zip mkdir -p "${map_dir}" # Install Maps wget https://github.com/oxwhirl/smacv2/releases/download/maps/SMAC_Maps.zip - -unzip -qo SMAC_Maps.zip -d ."${map_dir}" +unzip SMAC_Maps.zip +mkdir "${map_dir}/SMAC_Maps" +mv *.SC2Map "${map_dir}/SMAC_Maps" printf "StarCraft II and SMAC are installed." From 1b4325de70e758f80f8adaecda0a21debbdcafcf Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 21 Aug 2023 16:51:24 +0100 Subject: [PATCH 16/35] docs Signed-off-by: Matteo Bettini --- docs/source/reference/envs.rst | 2 + torchrl/envs/libs/smacv2.py | 275 ++++++++++++++++++++++++++++++--- 2 files changed, 258 insertions(+), 19 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index c069670e2f3..56f045f72b7 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -624,5 +624,7 @@ the following function will return ``1`` when queried: jumanji.JumanjiEnv jumanji.JumanjiWrapper openml.OpenMLEnv + smacv2.SMACv2Wrapper + smacv2.SMACv2Env vmas.VmasEnv vmas.VmasWrapper diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 350a12ac081..96d44179a38 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -30,36 +30,146 @@ def _get_envs(): if not _has_smacv2: return [] - return smac_maps.get_smac_map_registry().keys() + return list(smac_maps.get_smac_map_registry().keys()) class SMACv2Wrapper(_EnvWrapper): """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + To install the environment follow the following `guide `__. 
+ Examples: - >>> env = smac.env.StarCraft2Env("8m") - >>> env = SMACv2Wrapper(env) - >>> td = env.reset() - >>> td["action"] = env.action_spec.rand() - >>> td = env.step(td) - >>> print(td) + >>> from torchrl.envs.libs.smacv2 import SMACv2Wrapper + >>> import smacv2 + >>> print(SMACv2Wrapper.available_envs) + ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', + '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', + '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg'] + >>> # You can use old SMAC maps + >>> env = SMACv2Wrapper(smacv2.env.StarCraft2Env(map_name="MMM2")) + >>> print(env.rollout(5) TensorDict( fields={ - action: Tensor(torch.Size([8, 14]), dtype=torch.int64), - done: Tensor(torch.Size([1]), dtype=torch.bool), + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), next: TensorDict( fields={ - obs: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, - batch_size=torch.Size([]), + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), device=cpu, is_shared=False), - obs: Tensor(torch.Size([8, 80]), dtype=torch.float32), - reward: Tensor(torch.Size([1]), dtype=torch.float32)}, - batch_size=torch.Size([]), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + >>> # Or the new features for procedural generation + >>> distribution_config = { + ... "n_units": 5, + ... "n_enemies": 6, + ... "team_gen": { + ... 
"dist_type": "weighted_teams", + ... "unit_types": ["marine", "marauder", "medivac"], + ... "exception_unit_types": ["medivac"], + ... "weights": [0.5, 0.2, 0.3], + ... "observe": True, + ... }, + ... "start_positions": { + ... "dist_type": "surrounded_and_reflect", + ... "p": 0.5, + ... "n_enemies": 5, + ... "map_x": 32, + ... "map_y": 32, + ... }, + ... } + >>> env = SMACv2Wrapper( + ... smacv2.env.StarCraft2Env( + ... map_name="10gen_terran", + ... capability_config=distribution_config, + ... ) + ... ) + >>> print(env.rollout(4)) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), device=cpu, is_shared=False) - >>> print(env.available_envs) - ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] """ git_url = "https://github.com/oxwhirl/smacv2" @@ -336,10 +446,137 @@ def get_agent_type(self, agent_index: int) -> str: class SMACv2Env(SMACv2Wrapper): """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. + To install the environment follow the following `guide `__. + Examples: - >>> env = SMACv2Env(map_name="8m") - >>> print(env.available_envs) - ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] 
+ >>> from torchrl.envs.libs.smacv2 import SMACv2Env + >>> print(SMACv2Env.available_envs) + ['10gen_terran', '10gen_zerg', '10gen_protoss', '3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', '10m_vs_11m', + '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', + '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg'] + >>> # You can use old SMAC maps + >>> env = SMACv2Env(map_name="MMM2") + >>> print(env.rollout(5) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 10]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([5, 322]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + >>> # Or the new features for procedural generation + >>> distribution_config = { + ... "n_units": 5, + ... "n_enemies": 6, + ... "team_gen": { + ... "dist_type": "weighted_teams", + ... "unit_types": ["marine", "marauder", "medivac"], + ... "exception_unit_types": ["medivac"], + ... "weights": [0.5, 0.2, 0.3], + ... "observe": True, + ... }, + ... "start_positions": { + ... "dist_type": "surrounded_and_reflect", + ... "p": 0.5, + ... "n_enemies": 5, + ... "map_x": 32, + ... "map_y": 32, + ... }, + ... } + >>> env = SMACv2Env( + ... map_name="10gen_terran", + ... capability_config=distribution_config, + ... 
) + >>> print(env.rollout(4)) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 5]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + battle_won: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False), + dead_allies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + dead_enemies: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False), + episode_limit: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False), + state: Tensor(shape=torch.Size([4, 131]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=cpu, + is_shared=False) """ def __init__( From b6ea62730ff0aa7ef7d3d004ccd8472a8ac0c90f Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 22 Aug 2023 13:59:48 +0100 Subject: [PATCH 17/35] add tests Signed-off-by: Matteo Bettini --- test/test_libs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index f210ac84787..ef42e4c3a02 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1645,10 +1645,12 @@ def test_env_procedural(self): check_env_specs(env, seed=None) env.close() + @pytest.mark.parametrize("categorical_actions", [True, False]) @pytest.mark.parametrize("map", ["MMM2", "3s_vs_5z"]) - def test_env(self, map: str): + def test_env(self, map: str, categorical_actions): env = SMACv2Env( map_name=map, + categorical_actions=categorical_actions, seed=0, ) check_env_specs(env, seed=None) From 9d9eb12a8a2abc4ed7a150015c4191f32d755172 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 29 Aug 2023 09:12:13 +0100 Subject: [PATCH 18/35] add group map Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 
96d44179a38..639f0c6a48c 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -54,7 +54,7 @@ class SMACv2Wrapper(_EnvWrapper): fields={ action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5, 10]), device=cpu, is_shared=False), @@ -73,7 +73,7 @@ class SMACv2Wrapper(_EnvWrapper): agents: TensorDict( fields={ action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5, 10]), device=cpu, is_shared=False), @@ -128,7 +128,7 @@ class SMACv2Wrapper(_EnvWrapper): fields={ action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 5]), device=cpu, is_shared=False), @@ -147,7 +147,7 @@ class SMACv2Wrapper(_EnvWrapper): agents: TensorDict( fields={ action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 5]), device=cpu, is_shared=False), @@ -211,6 +211,7 @@ def _build_env( return env def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: + self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} self.reward_spec = UnboundedContinuousTensorSpec( shape=torch.Size((1,)), device=self.device, @@ -293,7 +294,7 @@ def _make_observation_spec(self) -> CompositeSpec: spec = CompositeSpec( { "agents": CompositeSpec( - {"obs": obs_spec, "action_mask": mask_spec}, + {"observation": obs_spec, "action_mask": mask_spec}, shape=torch.Size((self.n_agents,)), ), "state": BoundedTensorSpec( @@ -340,7 +341,7 @@ def _reset( # build results agents_td = TensorDict( - {"obs": obs, "action_mask": mask}, batch_size=(self.n_agents,) + {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,) ) tensordict_out = TensorDict( source={"agents": agents_td, "state": state, "info": info}, @@ -376,7 +377,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # build results agents_td = TensorDict( - {"obs": obs, "action_mask": mask}, batch_size=(self.n_agents,) + {"observation": obs, "action_mask": mask}, batch_size=(self.n_agents,) ) tensordict_out = TensorDict( @@ -463,7 +464,7 @@ class SMACv2Env(SMACv2Wrapper): fields={ action: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.int64, is_shared=False), action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: 
Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5, 10]), device=cpu, is_shared=False), @@ -482,7 +483,7 @@ class SMACv2Env(SMACv2Wrapper): agents: TensorDict( fields={ action_mask: Tensor(shape=torch.Size([5, 10, 18]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([5, 10, 176]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5, 10]), device=cpu, is_shared=False), @@ -535,7 +536,7 @@ class SMACv2Env(SMACv2Wrapper): fields={ action: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.int64, is_shared=False), action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 5]), device=cpu, is_shared=False), @@ -554,7 +555,7 @@ class SMACv2Env(SMACv2Wrapper): agents: TensorDict( fields={ action_mask: Tensor(shape=torch.Size([4, 5, 12]), device=cpu, dtype=torch.bool, is_shared=False), - obs: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([4, 5, 88]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 5]), device=cpu, is_shared=False), From be12ac179721601454065438c4467555f2d9a6c3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 09:04:30 +0100 Subject: [PATCH 19/35] fixes Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 639f0c6a48c..2829deeb165 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -382,13 +382,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict( source={ - "next": { - "agents": agents_td, - "state": state, - "info": info, - "reward": reward, - "done": done, - } + "agents": agents_td, + "state": state, + "info": info, + "reward": reward, + "done": done, }, batch_size=(), device=self.device, @@ -614,9 +612,12 @@ def _build_env( if capability_config is not None: env = smacv2.env.StarCraftCapabilityEnvWrapper( - capability_config=capability_config, map_name=map_name, seed=seed + capability_config=capability_config, + map_name=map_name, + seed=seed, + **kwargs, ) else: - env = smacv2.env.StarCraft2Env(map_name=map_name, seed=seed) + env = smacv2.env.StarCraft2Env(map_name=map_name, seed=seed, **kwargs) return super()._build_env(env) From 74d7449f76f2163aeed4a68dd30e89267717e001 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 09:06:42 +0100 Subject: [PATCH 20/35] fix import Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 2829deeb165..764c618220b 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import importlib from typing import Dict, Optional import torch @@ -16,16 +17,12 @@ ) from torchrl.envs.common import _EnvWrapper -IMPORT_ERR = None -try: + +_has_smacv2 = importlib.util.find_spec("smacv2") is not None +if _has_smacv2: import smacv2 from smacv2.env.starcraft2.maps import smac_maps - _has_smacv2 = True -except ImportError as err: - _has_smacv2 = False - IMPORT_ERR = err - def _get_envs(): if not _has_smacv2: @@ -590,7 +587,7 @@ def __init__( raise ImportError( f"smacv2 python package was not found. Please install this dependency. " f"More info: {self.git_url}." - ) from IMPORT_ERR + ) kwargs["map_name"] = map_name kwargs["capability_config"] = capability_config kwargs["seed"] = seed From 111bbeb7fdcaca0c749e5a0311d7b71b0eaa3f46 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 09:43:28 +0100 Subject: [PATCH 21/35] collector test Signed-off-by: Matteo Bettini --- test/test_libs.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/test_libs.py b/test/test_libs.py index 6383dec83dd..8d4cfefa57d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. import importlib +from torchrl.modules import MaskedCategorical + _has_isaac = importlib.util.find_spec("isaacgym") is not None if _has_isaac: @@ -35,6 +37,11 @@ ) from packaging import version from tensordict import LazyStackedTensorDict +from tensordict.nn import ( + ProbabilisticTensorDictModule, + TensorDictModule, + TensorDictSequential, +) from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import implement_for @@ -1686,6 +1693,30 @@ def test_vec_env(self): check_env_specs(env, seed=None) env.close() + def test_collector(self): + env = SMACv2Env(map_name="MMM2", seed=0, categorical_actions=True) + in_feats = env.observation_spec["agents", "observation"].shape[-1] + out_feats = env.action_spec.space.n + + module = TensorDictModule( + nn.Linear(in_feats, out_feats), + in_keys=[("agents", "observation")], + out_keys=[("agents", "logits")], + ) + prob = ProbabilisticTensorDictModule( + in_keys={"logits": ("agents", "logits"), "mask": ("agents", "action_mask")}, + out_keys=[("agents", "action")], + distribution_class=MaskedCategorical, + ) + actor = TensorDictSequential(module, prob) + + collector = SyncDataCollector( + env, policy=actor, frames_per_batch=20, total_frames=40 + ) + for _ in collector: + break + collector.shutdown() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() From d6daa7ce552ca7424fb9e906744420c2147a695d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 09:49:32 +0100 Subject: [PATCH 22/35] review fixes Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 764c618220b..eb1bfe75fa1 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import importlib +import typing from typing import Dict, Optional import torch @@ -19,14 +20,16 @@ _has_smacv2 = importlib.util.find_spec("smacv2") is not None -if _has_smacv2: + +if typing.TYPE_CHECKING and _has_smacv2: import smacv2 - from smacv2.env.starcraft2.maps import smac_maps def _get_envs(): if not _has_smacv2: return [] + from smacv2.env.starcraft2.maps import smac_maps + return list(smac_maps.get_smac_map_registry().keys()) @@ -187,9 +190,13 @@ def __init__( @property def lib(self): + import smacv2 + return smacv2 def _check_kwargs(self, kwargs: Dict): + import smacv2 + if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") env = kwargs["env"] @@ -416,28 +423,26 @@ def get_agent_type(self, agent_index: int) -> str: agent_info = self.agents[agent_index] if agent_info.unit_type == self.marine_id: - agent_type = "marine" + return "marine" elif agent_info.unit_type == self.marauder_id: - agent_type = "marauder" + return "marauder" elif agent_info.unit_type == self.medivac_id: - agent_type = "medivac" + return "medivac" elif agent_info.unit_type == self.hydralisk_id: - agent_type = "hydralisk" + return "hydralisk" elif agent_info.unit_type == self.zergling_id: - agent_type = "zergling" + return "zergling" elif agent_info.unit_type == self.baneling_id: - agent_type = "baneling" + return "baneling" elif agent_info.unit_type == self.stalker_id: - agent_type = "stalker" + return "stalker" elif agent_info.unit_type == self.colossus_id: - agent_type = "colossus" + return "colossus" elif agent_info.unit_type == self.zealot_id: - agent_type = "zealot" + return "zealot" else: raise AssertionError(f"Agent type {agent_info.unit_type} unidentified") - return agent_type - class SMACv2Env(SMACv2Wrapper): """SMACv2 (StarCraft Multi-Agent Challenge v2) environment wrapper. @@ -606,6 +611,7 @@ def _build_env( seed: Optional[int] = None, **kwargs, ) -> "smacv2.env.StarCraft2Env": + import smacv2 if capability_config is not None: env = smacv2.env.StarCraftCapabilityEnvWrapper( From 613b4b69f5cc29a4dbc059fcecfd6dc1aeded084 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 10:18:06 +0100 Subject: [PATCH 23/35] change default categorical actions to true due to absence of one hot masked distribution Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index eb1bfe75fa1..9c9d78cf5c0 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -46,8 +46,8 @@ class SMACv2Wrapper(_EnvWrapper): '27m_vs_30m', 'MMM', 'MMM2', '2s3z', '3s5z', '3s5z_vs_3s6z', '3s_vs_3z', '3s_vs_4z', '3s_vs_5z', '1c3s5z', '2m_vs_1z', 'corridor', '6h_vs_8z', '2s_vs_1sc', 'so_many_baneling', 'bane_vs_bane', '2c_vs_64zg'] >>> # You can use old SMAC maps - >>> env = SMACv2Wrapper(smacv2.env.StarCraft2Env(map_name="MMM2")) - >>> print(env.rollout(5) + >>> env = SMACv2Wrapper(smacv2.env.StarCraft2Env(map_name="MMM2"), categorical_actions=False) + >>> print(env.rollout(5)) TensorDict( fields={ agents: TensorDict( @@ -179,7 +179,7 @@ class SMACv2Wrapper(_EnvWrapper): def __init__( self, env: "smacv2.env.StarCraft2Env" = None, - categorical_actions: bool = False, + categorical_actions: bool = True, **kwargs, ): if env is not None: @@ -528,6 +528,7 @@ class SMACv2Env(SMACv2Wrapper): >>> env = SMACv2Env( ... map_name="10gen_terran", ... capability_config=distribution_config, + ... categorical_actions=False, ... 
)
         >>> print(env.rollout(4))
         TensorDict(
@@ -585,7 +586,7 @@ def __init__(
         map_name: str,
         capability_config: Optional[Dict] = None,
         seed: Optional[int] = None,
-        categorical_actions: bool = False,
+        categorical_actions: bool = True,
         **kwargs,
     ):
         if not _has_smacv2:
From d6dd19bd0e88d976e5caf1a2a240441fd9842e6e Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 5 Sep 2023 17:48:23 +0100
Subject: [PATCH 24/35] add docs

Signed-off-by: Matteo Bettini
---
 docs/source/reference/envs.rst | 48 ++++++++++++++++++++++++++++++++++
 test/test_libs.py | 18 ++++++++-----
 torchrl/envs/libs/smacv2.py | 19 ++++++++++----
 3 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 48afa6168ac..034ff1218a2 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -488,6 +488,54 @@ to be able to create this other composition:
     VIPRewardTransform
     VIPTransform
 
+Environments with masked actions
+--------------------------------
+
+In some environments with discrete actions, the actions available to the agent might change throughout execution.
+In such cases, the environment will output an action mask (under the ``"action_mask"`` key by default).
+This mask must be used to filter out the actions that are unavailable at that step.
+
+If you are using a custom policy, you can pass this mask to your probability distribution like so:
+
+.. code-block::
+  :caption: Categorical policy with action mask
+
+  >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential
+  >>> import torch.nn as nn
+  >>> from torchrl.modules import MaskedCategorical
+  >>> module = TensorDictModule(
+  >>>     nn.Linear(in_feats, out_feats),
+  >>>     in_keys=["observation"],
+  >>>     out_keys=["logits"],
+  >>> )
+  >>> dist = ProbabilisticTensorDictModule(
+  >>>     in_keys={"logits": "logits", "mask": "action_mask"},
+  >>>     out_keys=["action"],
+  >>>     distribution_class=MaskedCategorical,
+  >>> )
+  >>> actor = TensorDictSequential(module, dist)
+
+If you want to use a default policy, you will need to wrap your environment in the :class:`~torchrl.envs.transforms.ActionMask`
+transform. This transform updates the action mask in the action spec so that the default policy
+always knows which actions are currently available. You can do this like so:
+
+.. code-block::
+  :caption: How to use the action mask transform
+
+  >>> from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, TensorDictSequential
+  >>> import torch.nn as nn
+  >>> from torchrl.envs.transforms import TransformedEnv, ActionMask
+  >>> env = TransformedEnv(
+  >>>     your_base_env,
+  >>>     ActionMask(action_key="action", mask_key="action_mask"),
+  >>> )
+
+.. note::
+  If you are using a parallel environment, it is important to add the transform to the parallel environment itself
+  and not to its sub-environments.
+
+
+
 Recorders
 ---------
 
diff --git a/test/test_libs.py b/test/test_libs.py
index 8d4cfefa57d..aae72bce953 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
import importlib
 
+from envs import ActionMask, TransformedEnv
 from torchrl.modules import MaskedCategorical
 
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
@@ -1682,12 +1683,17 @@ def test_env(self, map: str, categorical_actions):
         check_env_specs(env, seed=None)
         env.close()
 
-    def test_vec_env(self):
-        env = ParallelEnv(
-            num_workers=2,
-            create_env_fn=lambda: SMACv2Env(
-                map_name="3s_vs_5z",
-                seed=0,
+    def test_parallel_env(self):
+        env = TransformedEnv(
+            ParallelEnv(
+                num_workers=2,
+                create_env_fn=lambda: SMACv2Env(
+                    map_name="3s_vs_5z",
+                    seed=0,
+                ),
+            ),
+            ActionMask(
+                action_key=("agents", "action"), mask_key=("agents", "action_mask")
             ),
         )
         check_env_specs(env, seed=None)
diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py
index 9c9d78cf5c0..4e177804097 100644
--- a/torchrl/envs/libs/smacv2.py
+++ b/torchrl/envs/libs/smacv2.py
@@ -3,10 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import importlib
+import re
 import typing
 from typing import Dict, Optional
 
 import torch
+
+from envs.utils import ACTION_MASK_ERROR
 from tensordict import TensorDict, TensorDictBase
 
 from torchrl.data import (
@@ -231,7 +234,7 @@ def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None:
 
     def _init_env(self) -> None:
         self._env.reset()
-        self.update_action_mask()
+        self._update_action_mask()
 
     def _make_action_spec(self) -> CompositeSpec:
         if self.categorical_actions:
@@ -341,7 +344,7 @@ def _reset(
         state = self._to_tensor(state)
         info = self.observation_spec["info"].zero()
 
-        mask = self.update_action_mask()
+        mask = self._update_action_mask()
 
         # build results
         agents_td = TensorDict(
@@ -361,7 +364,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         action_np = self.action_spec.to_numpy(action)
 
         # Actions are validated by the environment.
-        reward, done, info = self._env.step(action_np)
+        try:
+            reward, done, info = self._env.step(action_np)
+        except AssertionError as err:
+            if re.match(r"Agent . cannot perform action .", str(err)):
+                raise ACTION_MASK_ERROR
+            else:
+                raise err
 
         # collect outputs
         obs = self.get_obs()
@@ -377,7 +386,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         ).unsqueeze(-1)
         done = torch.tensor(done, device=self.device, dtype=torch.bool).unsqueeze(-1)
 
-        mask = self.update_action_mask()
+        mask = self._update_action_mask()
 
         # build results
         agents_td = TensorDict(
@@ -398,7 +407,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
 
         return tensordict_out
 
-    def update_action_mask(self):
+    def _update_action_mask(self):
         mask = torch.tensor(
             self.get_avail_actions(), dtype=torch.bool, device=self.device
         )
From b399c40baf3b7c50a4be7757565eb93c3fc4de2f Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 5 Sep 2023 17:58:17 +0100
Subject: [PATCH 25/35] amend

Signed-off-by: Matteo Bettini
---
 torchrl/envs/utils.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index a3a3549695b..eef62b81354 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -44,6 +44,15 @@
 DONE_AFTER_RESET_ERROR = RuntimeError(
     "Env was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
 )
+ACTION_MASK_ERROR = RuntimeError(
+    "An out-of-bounds action has been provided to an env with an 'action_mask' output."
+    " If you are using a custom policy, make sure to take the action mask into account when computing the output."
+ " If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment." + "If you are using a ParallelEnv or another batched inventor, " + "make sure to add the transform to the ParallelEnv (and not to the sub-environments)." + " For more info on using action masks, see the docs at: " + "https://pytorch.org/rl/reference/envs.html#environments-with-masked-actions" +) def _convert_exploration_type(*, exploration_mode, exploration_type): From f43ffed29e6607cd6770aaecb5040b8ed9f6a726 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 5 Sep 2023 20:33:47 +0100 Subject: [PATCH 26/35] amend Signed-off-by: Matteo Bettini --- test/test_libs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index aae72bce953..3c9cb8e1494 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import importlib -from envs import ActionMask, TransformedEnv +from torchrl.envs.transforms import ActionMask, TransformedEnv from torchrl.modules import MaskedCategorical _has_isaac = importlib.util.find_spec("isaacgym") is not None From 64efb6953201e5c30ffd9d76e5b74fad3cfefc4e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 08:50:32 +0100 Subject: [PATCH 27/35] amend Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 4e177804097..7f41ad7ef9b 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -8,8 +8,6 @@ from typing import Dict, Optional import torch - -from envs.utils import ACTION_MASK_ERROR from tensordict import TensorDict, TensorDictBase from torchrl.data import ( @@ -21,6 +19,8 @@ ) from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import ACTION_MASK_ERROR + _has_smacv2 = importlib.util.find_spec("smacv2") is not None From 6a56db606084d03e81596a278c93f4e66560ee73 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 14 Sep 2023 16:49:01 +0100 Subject: [PATCH 28/35] ci Signed-off-by: Matteo Bettini --- .../unittest/linux_libs/scripts_pettingzoo/run_test.sh | 2 +- .github/unittest/linux_libs/scripts_smacv2/run_test.sh | 2 +- .github/workflows/test-linux-pettingzoo.yml | 8 ++++---- .github/workflows/test-linux-smacv2.yml | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh index d215b514081..1cdb653ede8 100755 --- a/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh +++ b/.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh @@ -25,6 +25,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON # this workflow only tests the libs python -c "import pettingzoo" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestPettingZoo --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestPettingZoo --error-for-skips coverage combine coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_smacv2/run_test.sh b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh index cb7425cf640..65fd7462df3 100755 --- a/.github/unittest/linux_libs/scripts_smacv2/run_test.sh +++ 
b/.github/unittest/linux_libs/scripts_smacv2/run_test.sh @@ -27,6 +27,6 @@ export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON # this workflow only tests the libs python -c "import smacv2" -python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestSmacv2 --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestSmacv2 --error-for-skips coverage combine coverage xml -i diff --git a/.github/workflows/test-linux-pettingzoo.yml b/.github/workflows/test-linux-pettingzoo.yml index bbf775f4c27..628be74beef 100644 --- a/.github/workflows/test-linux-pettingzoo.yml +++ b/.github/workflows/test-linux-pettingzoo.yml @@ -34,7 +34,7 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_pettingzoo/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_pettingzoo/install.sh - bash .circleci/unittest/linux_libs/scripts_pettingzoo/run_test.sh - bash .circleci/unittest/linux_libs/scripts_pettingzoo/post_process.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/setup_env.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/install.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/run_test.sh + bash .github/unittest/linux_libs/scripts_pettingzoo/post_process.sh diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index 35dd4ff2409..13255765693 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -34,7 +34,7 @@ jobs: nvidia-smi - bash .circleci/unittest/linux_libs/scripts_smacv2/setup_env.sh - bash .circleci/unittest/linux_libs/scripts_smacv2/install.sh - bash .circleci/unittest/linux_libs/scripts_smacv2/run_test.sh - bash .circleci/unittest/linux_libs/scripts_smacv2/post_process.sh + bash .github/unittest/linux_libs/scripts_smacv2/setup_env.sh + bash .github/unittest/linux_libs/scripts_smacv2/install.sh + bash .github/unittest/linux_libs/scripts_smacv2/run_test.sh + bash .github/unittest/linux_libs/scripts_smacv2/post_process.sh From e2e90ca2ea8966b83f70ae445846f4b701b13ebb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 14 Sep 2023 16:55:02 +0100 Subject: [PATCH 29/35] add conditional ci Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index 13255765693..db9d191face 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -17,6 +17,7 @@ concurrency: jobs: unittests: + if: ${{ github.event.label.name == 'Environments' }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl From ba1022012b49d540b118244ca21c5a08112871a4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 14 Sep 2023 19:44:45 +0100 Subject: [PATCH 30/35] add conditional ci Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index db9d191face..652718d17f9 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -17,7 +17,7 @@ concurrency: jobs: unittests: - if: ${{ github.event.label.name == 'Environments' }} + if: contains(github.event.pull_request.labels.*.name, 'Environments') uses: 
pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl From 7778cca13fb22cc625c83a60b6ebeadc6ba33384 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 15 Sep 2023 08:45:31 +0100 Subject: [PATCH 31/35] import Signed-off-by: Matteo Bettini --- torchrl/envs/libs/smacv2.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index 7f41ad7ef9b..0c0bda35d28 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import importlib import re -import typing + from typing import Dict, Optional import torch @@ -24,9 +24,6 @@ _has_smacv2 = importlib.util.find_spec("smacv2") is not None -if typing.TYPE_CHECKING and _has_smacv2: - import smacv2 - def _get_envs(): if not _has_smacv2: @@ -181,7 +178,7 @@ class SMACv2Wrapper(_EnvWrapper): def __init__( self, - env: "smacv2.env.StarCraft2Env" = None, + env: "smacv2.env.StarCraft2Env" = None, # noqa: F821 categorical_actions: bool = True, **kwargs, ): @@ -208,7 +205,7 @@ def _check_kwargs(self, kwargs: Dict): def _build_env( self, - env: "smacv2.env.StarCraft2Env", + env: "smacv2.env.StarCraft2Env", # noqa: F821 ): if len(self.batch_size): raise RuntimeError( @@ -217,7 +214,7 @@ def _build_env( return env - def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: + def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821 self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} self.reward_spec = UnboundedContinuousTensorSpec( shape=torch.Size((1,)), @@ -620,7 +617,7 @@ def _build_env( capability_config: Optional[Dict] = None, seed: Optional[int] = None, **kwargs, - ) -> "smacv2.env.StarCraft2Env": + ) -> "smacv2.env.StarCraft2Env": # noqa: F821 import smacv2 if capability_config is not None: From abd828194d208f77d65c7544e9dd32b264c3e737 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 15 Sep 2023 09:03:30 +0100 Subject: [PATCH 32/35] test Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index 652718d17f9..fff847ad26b 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -17,7 +17,7 @@ concurrency: jobs: unittests: - if: contains(github.event.pull_request.labels.*.name, 'Environments') + if: ${{ (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'Environments')) || (github.event_name == 'push') }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl From a3c2334c04539ab7185e30b184b7f9e61bc2f694 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 15 Sep 2023 09:05:23 +0100 Subject: [PATCH 33/35] test Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index fff847ad26b..f7e820dcfa6 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -17,7 +17,7 @@ concurrency: jobs: unittests: - if: ${{ (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'Environments')) || (github.event_name == 'push') }} + if: ${{ contains(github.event.pull_request.labels.*.name, 
'Environments') || github.event_name == 'push' }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl From b8aa329f70548d22666c7cc2c440b1630f6252cc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 15 Sep 2023 09:05:54 +0100 Subject: [PATCH 34/35] test Signed-off-by: Matteo Bettini --- .github/workflows/test-linux-smacv2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-smacv2.yml b/.github/workflows/test-linux-smacv2.yml index f7e820dcfa6..159c93fb1a1 100644 --- a/.github/workflows/test-linux-smacv2.yml +++ b/.github/workflows/test-linux-smacv2.yml @@ -17,7 +17,7 @@ concurrency: jobs: unittests: - if: ${{ contains(github.event.pull_request.labels.*.name, 'Environments') || github.event_name == 'push' }} + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl From 382d06be5ca829ad3c90c959837a41e7df7abb10 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 15 Sep 2023 09:10:51 +0100 Subject: [PATCH 35/35] empty