From ee77a11b9908483b60abe28e62d516bf38bef7f2 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 11 Jul 2023 11:07:22 +0100
Subject: [PATCH 01/27] init

Signed-off-by: Matteo Bettini
---
 torchrl/envs/transforms/transforms.py       |   3 +-
 torchrl/modules/models/__init__.py          |   1 +
 torchrl/modules/models/multiagent.py        | 473 ++++++++++++++++++++
 torchrl/modules/tensordict_module/common.py |  11 +-
 torchrl/objectives/__init__.py              |   1 +
 torchrl/objectives/dqn.py                   |  10 +-
 torchrl/objectives/multiagent/__init__.py   |   6 +
 torchrl/objectives/multiagent/qmixer.py     | 382 ++++++++++++++++
 torchrl/objectives/value/advantages.py      |   2 +-
 9 files changed, 879 insertions(+), 10 deletions(-)
 create mode 100644 torchrl/modules/models/multiagent.py
 create mode 100644 torchrl/objectives/multiagent/__init__.py
 create mode 100644 torchrl/objectives/multiagent/qmixer.py

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index ebc913a1214..0f17c37b5f4 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -3460,6 +3460,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
         """Transforms the observation spec, adding the new keys generated by RewardSum."""
         # Retrieve parent reward spec
         reward_spec = self.parent.reward_spec
+        reward_key = self.parent.reward_key if self.parent else "reward"
 
         episode_specs = {}
         if isinstance(reward_spec, CompositeSpec):
@@ -3478,7 +3479,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
 
         else:
             # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
-            if set(self.in_keys) != {"reward"}:
+            if set(unravel_key_list(self.in_keys)) != {unravel_key(reward_key)}:
                 raise KeyError(
                     "reward_spec is not a CompositeSpec class, in_keys should only include ´reward´"
                 )
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index 8654d338c18..8e5d0c2f9c9 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -17,4 +17,5 @@
     LSTMNet,
     MLP,
 )
+from .multiagent import MultiAgentMLP, QMixer, VDNMixer
 from .utils import Squeeze2dLayer, SqueezeLayer
diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py
new file mode 100644
index 00000000000..c015bd6910d
--- /dev/null
+++ b/torchrl/modules/models/multiagent.py
@@ -0,0 +1,473 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Sequence, Tuple, Type, Union
+
+import numpy as np
+
+import torch
+from torch import nn
+
+from .models import MLP
+
+
+class MultiAgentMLP(nn.Module):
+    """Multi-agent MLP.
+
+    This is an MLP that can be used in multi-agent contexts.
+    For example, as a policy or as a value function.
+    See `examples/multiagent` for examples.
+
+    It expects inputs with shape (*B, n_agents, n_agent_inputs).
+    It returns outputs with shape (*B, n_agents, n_agent_outputs).
+
+    If `share_params` is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies).
+    Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).
+
+    If `centralised` is True, each agent will use the inputs of all agents to compute its output
+    (n_agent_inputs * n_agents will be the number of inputs for one agent).
+    Otherwise, each agent will only use its own data as input.
+ + **kwargs for :class:`~torchrl.modules.models.MLP` can be passed to customize the MLPs. + + """ + + def __init__( + self, + n_agent_inputs: int, + n_agent_outputs: int, + n_agents: int, + centralised: bool, + share_params: bool, + device: Union[torch.device, str], + depth: Optional[int] = None, + num_cells: Optional[Union[Sequence, int]] = None, + activation_class: Type[nn.Module] = nn.Tanh, + **kwargs, + ): + super().__init__() + + self.n_agents = n_agents + self.n_agent_inputs = n_agent_inputs + self.n_agent_outputs = n_agent_outputs + self.share_params = share_params + self.centralised = centralised + + self.agent_networks = nn.ModuleList( + [ + MLP( + in_features=n_agent_inputs + if not centralised + else n_agent_inputs * n_agents, + out_features=n_agent_outputs, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + device=device, + **kwargs, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = (torch.cat([*inputs], -1),) + inputs = inputs[0] + + if inputs.shape[-2:] != (self.n_agents, self.n_agent_inputs): + raise ValueError( + f"Multi-agent network expected input with last 2 dimensions {[self.n_agents, self.n_agent_inputs]}," + f" but got {inputs.shape}" + ) + + # If the model is centralized, agents have full observability + if self.centralised: + inputs = inputs.view( + *inputs.shape[:-2], self.n_agents * self.n_agent_inputs + ) + + # If parameters are not shared, each agent has its own network + if not self.share_params: + if self.centralised: + output = torch.stack( + [net(inputs) for i, net in enumerate(self.agent_networks)], + dim=-2, + ) + else: + output = torch.stack( + [ + net(inputs[..., i, :]) + for i, net in enumerate(self.agent_networks) + ], + dim=-2, + ) + # If parameters are shared, agents use the same network + else: + output = self.agent_networks[0](inputs) + + if self.centralised: + # If the parameters are shared, and it is centralised, all agents will have the same output + # We expand it to maintain the agent dimension, but values will be the same for all agents + output = ( + output.view(*output.shape[:-1], self.n_agent_outputs) + .unsqueeze(-2) + .expand(*output.shape[:-1], self.n_agents, self.n_agent_outputs) + ) + + if output.shape[-2:] != (self.n_agents, self.n_agent_outputs): + raise ValueError( + f"Multi-agent network expected output with last 2 dimensions {[self.n_agents, self.n_agent_outputs]}," + f" but got {output.shape}" + ) + + return output + + +class Mixer(nn.Module): + """A multi-agent value mixer. + + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`~torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + + Args: + n_agents (int): number of agents, + device (str or torch.Device): torch device for the network + needs_state (bool): whether the mixer takes a global state as input + state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions) + + Examples: + Creating a VDN mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import VDNMixer + >>> n_agents = 4 + >>> vdn = TensorDictModule( + ... module=VDNMixer( + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... 
in_keys=[("agents","chosen_action_value")], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + + + Creating a QMix mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import QMixer + >>> n_agents = 4 + >>> qmix = TensorDictModule( + ... module=QMixer( + ... state_shape=(64, 64, 3), + ... mixing_embed_dim=32, + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents", "chosen_action_value"), "state"], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + n_agents: int, + device, + needs_state: bool, + state_shape: Union[Tuple[int, ...], torch.Size], + ): + super().__init__() + + self.n_agents = n_agents + self.device = device + self.needs_state = needs_state + self.state_shape = state_shape + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + """Forward pass of the mixer. + + Args: + *inputs: The first input should be the value of the chosen action of shape (*B, self.n_agents, 1), + representing the local q value of each agent. + The second input (optional, used only in some mixers) + is the shared state of all agents of shape (*B, *self.state_shape). 
+ + Returns: + The global value of the chosen actions obtained after mixing, with shape (*B, 1) + + """ + if not self.needs_state: + if len(inputs) > 1: + raise ValueError( + "Mixer that doesn't need state was passed more than 1 input" + ) + chosen_action_value = inputs[0] + else: + if len(inputs) > 2: + raise ValueError("Mixer that needs state was passed more than 2 inputs") + + chosen_action_value, state = inputs + + if state.shape[-len(self.state_shape) :] != self.state_shape: + raise ValueError( + f"Mixer network expected state with ending shape {self.state_shape}," + f" but got state shape {state.shape}" + ) + + if chosen_action_value.shape[-2:] != (self.n_agents, 1): + raise ValueError( + f"Mixer network expected chosen_action_value with last 2 dimensions {(self.n_agents,1)}," + f" but got {chosen_action_value.shape}" + ) + batch_dims = chosen_action_value.shape[:-2] + + if not self.needs_state: + output = self.mix(chosen_action_value, None) + else: + output = self.mix(chosen_action_value, state) + + if output.shape != (*batch_dims, 1): + raise ValueError( + f"Mixer network expected output with same shape as input minus the multi-agent dimension," + f" but got {output.shape}" + ) + + return output + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + """Forward pass for the mixer. + + Args: + chosen_action_value: Tensor of shape [*B, n_agents] + + Returns: + chosen_action_value: Tensor of shape [*B] + """ + raise NotImplementedError + + +class VDNMixer(Mixer): + """Mixer from https://arxiv.org/abs/1706.05296 . + + Examples: + Creating a VDN mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import VDNMixer + >>> n_agents = 4 + >>> vdn = TensorDictModule( + ... module=VDNMixer( + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents","chosen_action_value")], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + n_agents: int, + device, + ): + super().__init__( + needs_state=False, + state_shape=torch.Size([]), + n_agents=n_agents, + device=device, + ) + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + return chosen_action_value.sum(dim=-2) + + +class QMixer(Mixer): + """Mixer from https://arxiv.org/abs/1803.11485 . + + Examples: + Creating a QMix mixer + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.multiagent import QMixer + >>> n_agents = 4 + >>> qmix = TensorDictModule( + ... module=QMixer( + ... state_shape=(64, 64, 3), + ... 
mixing_embed_dim=32, + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents", "chosen_action_value"), "state"], + ... out_keys=["chosen_action_value"], + ... ) + >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32]) + >>> td + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + >>> vdn(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32, 4]), + device=None, + is_shared=False), + chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([32]), + device=None, + is_shared=False) + """ + + def __init__( + self, + state_shape, + mixing_embed_dim, + n_agents: int, + device, + ): + super().__init__( + needs_state=True, state_shape=state_shape, n_agents=n_agents, device=device + ) + + self.embed_dim = mixing_embed_dim + self.state_dim = int(np.prod(state_shape)) + + self.hyper_w_1 = nn.Linear( + self.state_dim, self.embed_dim * self.n_agents, device=self.device + ) + self.hyper_w_final = nn.Linear( + self.state_dim, self.embed_dim, device=self.device + ) + + # State dependent bias for hidden layer + self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim, device=self.device) + + # V(s) instead of a bias for the last layers + self.V = nn.Sequential( + nn.Linear(self.state_dim, self.embed_dim, device=self.device), + nn.ReLU(), + nn.Linear(self.embed_dim, 1, device=self.device), + ) + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + bs = chosen_action_value.shape[:-2] + state = state.view(-1, self.state_dim) + chosen_action_value = chosen_action_value.view(-1, 1, self.n_agents) + # First layer + w1 = torch.abs(self.hyper_w_1(state)) + b1 = self.hyper_b_1(state) + w1 = w1.view(-1, self.n_agents, self.embed_dim) + b1 = b1.view(-1, 1, self.embed_dim) + hidden = nn.functional.elu( + torch.bmm(chosen_action_value, w1) + b1 + ) # [-1, 1, self.embed_dim] + # Second layer + w_final = torch.abs(self.hyper_w_final(state)) + w_final = w_final.view(-1, self.embed_dim, 1) + # State-dependent bias + v = self.V(state).view(-1, 1, 1) + # Compute final output + y = torch.bmm(hidden, w_final) + v # [-1, 1, 1] + # Reshape and return + q_tot = y.view(*bs, 1) + return q_tot diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 56c185272b9..c5f34a7774d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,7 +9,7 @@ import inspect import re import warnings -from typing import Iterable, Optional, Type, Union +from typing import Iterable, List, Optional, Type, Union import torch @@ -17,6 +17,7 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.tensordict import TensorDictBase +from tensordict.utils import NestedKey from torch import nn @@ -364,12 +365,16 @@ def 
ensure_tensordict_compatible( module: Union[ FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module ], - in_keys: Optional[Iterable[str]] = None, - out_keys: Optional[Iterable[str]] = None, + in_keys: Optional[List[NestedKey]] = None, + out_keys: Optional[List[NestedKey]] = None, safe: bool = False, wrapper_type: Optional[Type] = TensorDictModule, **kwargs, ): + """Ensures module is compatible with TensorDictModule and, if not, it wraps it.""" + in_keys = unravel_key_list(in_keys) if in_keys else in_keys + out_keys = unravel_key_list(out_keys) if out_keys else out_keys + """Checks and ensures an object with forward method is TensorDict compatible.""" if is_tensordict_compatible(module): if in_keys is not None and set(in_keys) != set(module.in_keys): diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 5755fc2a27c..a1d6bce40e9 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -10,6 +10,7 @@ from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss +from .multiagent import * from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq import REDQLoss from .reinforce import ReinforceLoss diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index c8a1ccbb390..d7407c89c73 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -51,7 +51,7 @@ class DQNLoss(LossModule): :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). If not provided, an attempt to retrieve it from the value network will be made. - priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. @@ -123,10 +123,10 @@ class _AcceptedKeys: Will be used for the underlying value estimator. Defaults to ``"advantage"``. value_target (NestedKey): The input tensordict key where the target state value is expected. Will be used for the underlying value estimator Defaults to ``"value_target"``. - value (NestedKey): The input tensordict key where the state value is expected. - Will be used for the underlying value estimator. Defaults to ``"state_value"``. - state_action_value (NestedKey): The input tensordict key where the state action value is expected. - Defaults to ``"state_action_value"``. + value (NestedKey): The input tensordict key where the chosen action value is expected. + Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``. + action_value (NestedKey): The input tensordict key where the action value is expected. + Defaults to ``"action_value"``. action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. priority (NestedKey): The input tensordict key where the target priority is written to. diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py new file mode 100644 index 00000000000..7340cffd841 --- /dev/null +++ b/torchrl/objectives/multiagent/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .qmixer import QMixerLoss
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
new file mode 100644
index 00000000000..4500406a52f
--- /dev/null
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -0,0 +1,382 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import warnings
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Union
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
+from tensordict.utils import NestedKey
+from torch import nn
+
+from torchrl.data.tensor_specs import TensorSpec
+
+from torchrl.modules import SafeSequential
+from torchrl.modules.tensordict_module.actors import QValueActor
+from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
+
+from torchrl.modules.utils.utils import _find_action_space
+
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+    _cache_values,
+    _GAMMA_LMBDA_DEPREC_WARNING,
+    default_value_kwargs,
+    distance_loss,
+    ValueEstimators,
+)
+from torchrl.objectives.value import TDLambdaEstimator
+from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator
+
+
+class QMixerLoss(LossModule):
+    """The QMixer loss class.
+
+    Mixes local agent q values into a global q value according to a mixing network and then
+    uses DQN updates on the global value.
+    This loss is for multi-agent applications, therefore it expects the 'local_value', 'action_value' and 'action' keys
+    to have an agent dimension (this is visible in the default AcceptedKeys).
+    This dimension will be mixed by the mixer which will compute a 'global_value' key, used for a DQN objective.
+    The premade mixers of type :class:`~torchrl.modules.models.multiagent.Mixer` will expect the multi-agent
+    dimension to be the penultimate one.
+
+    Args:
+        local_value_network (QValueActor or nn.Module): a local Q value operator.
+        mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values
+            and an optional state to the global Q value.
+            It is suggested to provide a TensorDictModule wrapping a mixer from `torchrl.modules.models.multiagent.Mixer`.
+
+    Keyword Args:
+        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+        delay_value (bool, optional): whether to duplicate the value network
+            into a new target value network to
+            create a double DQN. Default is ``False``.
+        action_space (str or TensorSpec, optional): Action space. Must be one of
+            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
+            or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
+            :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+            If not provided, an attempt to retrieve it from the value network
+            will be made.
+        priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+            The key at which priority is assumed to be stored within TensorDicts added
+            to this ReplayBuffer.
This is to be used when the sampler is of type + :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules import QValueModule, SafeSequential + >>> from torchrl.modules.models.multiagent import QMixer + >>> from torchrl.objectives.multiagent import QMixerLoss + >>> n_agents = 4 + >>> module = TensorDictModule( + ... nn.Linear(10,3), in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")] + ... ) + >>> value_module = QValueModule( + ... action_value_key=("agents", "action_value"), + ... out_keys=[ + ... ("agents", "action"), + ... ("agents", "action_value"), + ... ("agents", "chosen_action_value"), + ... ], + ... action_space="categorical", + ... ) + >>> qnet = SafeSequential(module, value_module) + >>> qmixer = TensorDictModule( + ... module=QMixer( + ... state_shape=(64, 64, 3), + ... mixing_embed_dim=32, + ... n_agents=n_agents, + ... device="cpu", + ... ), + ... in_keys=[("agents", "chosen_action_value"), "state"], + ... out_keys=["chosen_action_value"], + ... ) + >>> loss = QMixerLoss(qnet, qmixer, action_space="categorical") + >>> td = TensorDict( + ... { + ... "agents": TensorDict( + ... {"observation": torch.zeros(32, n_agents, 10)}, [32, n_agents] + ... ), + ... "state": torch.zeros(32, 64, 64, 3), + ... "next": TensorDict( + ... { + ... "agents": TensorDict( + ... {"observation": torch.zeros(32, n_agents, 10)}, [32, n_agents] + ... ), + ... "state": torch.zeros(32, 64, 64, 3), + ... "reward": torch.zeros(32, 1), + ... "done": torch.zeros(32, 1, dtype=torch.bool), + ... }, + ... [32], + ... ), + ... }, + ... [32], + >>> loss(qnet(td)) + TensorDict( + fields={ + loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + advantage (NestedKey): The input tensordict key where the advantage is expected. + Will be used for the underlying value estimator. Defaults to ``"advantage"``. + value_target (NestedKey): The input tensordict key where the target state value is expected. + Will be used for the underlying value estimator Defaults to ``"value_target"``. + local_value (NestedKey): The input tensordict key where the local chosen action value is expected. + Will be used for the underlying value estimator. Defaults to ``("agents", "chosen_action_value")``. + global_value (NestedKey): The input tensordict key where the global chosen action value is expected. + Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``. + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``("agents", "action")``. + action_value (NestedKey): The input tensordict key where the action value is expected. + Defaults to ``("agents", "action_value")``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. 
+ done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + """ + + advantage: NestedKey = "advantage" + value_target: NestedKey = "value_target" + local_value: NestedKey = ("agents", "chosen_action_value") + global_value: NestedKey = "chosen_action_value" + action_value: NestedKey = ("agents", "action_value") + action: NestedKey = ("agents", "action") + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = ["loss"] + + def __init__( + self, + local_value_network: Union[QValueActor, nn.Module], + mixer_network: Union[TensorDictModule, nn.Module], + *, + loss_function: str = "l2", + delay_value: bool = False, + gamma: float = None, + action_space: Union[str, TensorSpec] = None, + priority_key: str = None, + ) -> None: + super().__init__() + self._in_keys = None + self._set_deprecated_ctor_keys(priority=priority_key) + self.delay_value = delay_value + local_value_network = ensure_tensordict_compatible( + module=local_value_network, + wrapper_type=QValueActor, + action_space=action_space, + ) + if not isinstance(mixer_network, TensorDictModule): + # If it is not a TensorDictModule we make it one with default keys + mixer_network = ensure_tensordict_compatible( + module=mixer_network, + in_keys=[self.tensor_keys.local_value], + out_keys=[self.tensor_keys.global_value], + ) + + global_value_network = SafeSequential(local_value_network, mixer_network) + params = make_functional(global_value_network) + self.global_value_network = deepcopy(global_value_network) + repopulate_module(local_value_network, params["module", "0"]) + repopulate_module(mixer_network, params["module", "1"]) + + self.convert_to_functional( + local_value_network, + "local_value_network", + create_target_params=self.delay_value, + ) + self.convert_to_functional( + mixer_network, + "mixer_network", + create_target_params=self.delay_value, + ) + self.global_value_network.module[0] = self.local_value_network + self.global_value_network.module[1] = self.mixer_network + + self.global_value_network_in_keys = global_value_network.in_keys + + self.loss_function = loss_function + if action_space is None: + # infer from value net + try: + action_space = local_value_network.spec + except AttributeError: + # let's try with action_space then + try: + action_space = local_value_network.action_space + except AttributeError: + raise ValueError(self.ACTION_SPEC_ERROR) + if action_space is None: + warnings.warn( + "action_space was not specified. DQNLoss will default to 'one-hot'." + "This behaviour will be deprecated soon and a space will have to be passed." + "Check the DQNLoss documentation to see how to pass the action space. 
" + ) + action_space = "one-hot" + + self.action_space = _find_action_space(action_space) + + if gamma is not None: + warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) + self.gamma = gamma + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + advantage=self.tensor_keys.advantage, + value_target=self.tensor_keys.value_target, + value=self.tensor_keys.global_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + ) + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + *self.global_value_network.in_keys, + *[("next", key) for key in self.global_value_network.in_keys], + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + if value_type is ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + **hp, value_network=self.global_value_network + ) + elif value_type is ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + **hp, value_network=self.global_value_network + ) + elif value_type is ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type is ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + **hp, value_network=self.global_value_network + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "advantage": self.tensor_keys.advantage, + "value_target": self.tensor_keys.value_target, + "value": self.tensor_keys.global_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + } + self._value_estimator.set_keys(**tensor_keys) + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDict: + device = self.device if self.device is not None else tensordict.device + tddevice = tensordict.to(device) + + td_copy = tddevice.clone(False) + self.local_value_network( + td_copy, + params=self.local_value_network_params, + ) + + action = tddevice.get(self.tensor_keys.action) + pred_val = td_copy.get( + self.tensor_keys.action_value + ) # [*B, n_agents, n_actions] + + if self.action_space == "categorical": + if action.shape != pred_val.shape: + # unsqueeze the action if it lacks on trailing singleton dim + action = action.unsqueeze(-1) + pred_val_index = torch.gather(pred_val, -1, index=action) + else: + action = action.to(torch.float) + pred_val_index = (pred_val * action).sum(-1, keepdim=True) + + td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1] + self.mixer_network(td_copy, params=self.mixer_network_params) + pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1) + # [*B] this is global and shared among the agents as will be the target + + target_value = self.value_estimator.value_estimate( + tddevice.clone(False), + target_params=self._cached_target_params, + ).squeeze( + -1 + ) # [*B] + + priority_tensor = (pred_val_index - target_value).pow(2) + priority_tensor 
= priority_tensor.detach().unsqueeze(-1) + if tddevice.device is not None: + priority_tensor = priority_tensor.to(tddevice.device) + + tensordict.set( + self.tensor_keys.priority, + priority_tensor, + inplace=True, + ) + loss = distance_loss(pred_val_index, target_value, self.loss_function) + return TensorDict({"loss": loss.mean()}, []) + + @property + @_cache_values + def _cached_target_params(self): + target_params = TensorDict( + { + "module": { + "0": self.target_local_value_network_params, + "1": self.target_mixer_network_params, + } + }, + batch_size=self.target_local_value_network_params.batch_size, + device=self.target_local_value_network_params.device, + ) + return target_params diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index bcb8f68cdaf..335b3c2e4e6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -365,7 +365,7 @@ def is_stateless(self): return self.value_network._is_stateless def _next_value(self, tensordict, target_params, kwargs): - step_td = step_mdp(tensordict) + step_td = step_mdp(tensordict, keep_other=False) if self.value_network is not None: if target_params is not None: kwargs["params"] = target_params From 42e6bb23baa7462bd88912865b2cf8128c5946e0 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 11 Jul 2023 13:50:32 +0100 Subject: [PATCH 02/27] tests Signed-off-by: Matteo Bettini --- test/test_cost.py | 437 +++++++++++++++++++++++- torchrl/objectives/multiagent/qmixer.py | 1 + 2 files changed, 425 insertions(+), 13 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index d087e7073bf..b88db3d50a5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -23,6 +23,8 @@ TensorDictSequential as Seq, ) +from torchrl.modules.models import QMixer + _has_functorch = True try: import functorch as ft # noqa @@ -46,6 +48,7 @@ # from torchrl.data.postprocs.utils import expand_as_right from tensordict.tensordict import assert_allclose_td, TensorDict +from tensordict.utils import unravel_key from torch import autograd, nn from torchrl.data import ( BoundedTensorSpec, @@ -80,6 +83,7 @@ ActorCriticOperator, ActorValueOperator, ProbabilisticActor, + QValueModule, ValueOperator, ) from torchrl.modules.utils import Buffer @@ -97,6 +101,7 @@ IQLLoss, KLPENPPOLoss, PPOLoss, + QMixerLoss, SACLoss, TD3Loss, ) @@ -719,6 +724,420 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): assert loss_fn.tensor_keys.priority in td.keys() +class TestQMixer(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + action_spec_type, + obs_dim=3, + action_dim=4, + device="cpu", + is_nn_module=False, + observation_key=("agents", "observation"), + action_key=("agents", "action"), + action_value_key=("agents", "action_value"), + chosen_action_value_key=("agents", "chosen_action_value"), + ): + # Actor + if action_spec_type == "one_hot": + action_spec = OneHotDiscreteTensorSpec(action_dim) + elif action_spec_type == "categorical": + action_spec = DiscreteTensorSpec(action_dim) + else: + raise ValueError(f"Wrong {action_spec_type}") + + module = nn.Linear(obs_dim, action_dim) + if is_nn_module: + return module.to(device) + module = TensorDictModule( + module, + in_keys=[observation_key], + out_keys=[action_value_key], + ) + value_module = QValueModule( + action_value_key=action_value_key, + out_keys=[ + action_key, + action_value_key, + chosen_action_value_key, + ], + spec=action_spec, + action_space=None, + ) + actor = SafeSequential(module, 
value_module) + + return actor + + def _create_mock_mixer( + self, + state_shape=(64, 64, 3), + n_agents=4, + device="cpu", + chosen_action_value_key=("agents", "chosen_action_value"), + state_key="state", + global_chosen_action_value_key="chosen_action_value", + ): + qmixer = TensorDictModule( + module=QMixer( + state_shape=state_shape, + mixing_embed_dim=32, + n_agents=n_agents, + device=device, + ), + in_keys=[chosen_action_value_key, state_key], + out_keys=[global_chosen_action_value_key], + ) + + return qmixer + + def _create_mock_data_dqn( + self, + action_spec_type, + batch=(2,), + T=None, + n_agents=4, + obs_dim=3, + state_shape=(64, 64, 3), + action_dim=4, + device="cpu", + action_key=("agents", "action"), + action_value_key=("agents", "action_value"), + ): + if T is not None: + batch = batch + (T,) + # create a tensordict + obs = torch.randn(*batch, n_agents, obs_dim, device=device) + state = torch.randn(*batch, *state_shape, device=device) + next_obs = torch.randn(*batch, n_agents, obs_dim, device=device) + next_state = torch.randn(*batch, *state_shape, device=device) + + action_value = torch.randn(*batch, n_agents, action_dim, device=device) + if action_spec_type == "one_hot": + action = (action_value == action_value.max(dim=-1, keepdim=True)[0]).to( + torch.long + ) + elif action_spec_type == "categorical": + action = torch.argmax(action_value, dim=-1).to(torch.long) + + reward = torch.randn(*batch, 1, device=device) + done = torch.zeros(*batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + { + "agents": TensorDict( + {"observation": obs}, + [*batch, n_agents], + device=device, + ), + "state": state, + "collector": { + "mask": torch.zeros(*batch, dtype=torch.bool, device=device) + }, + "next": TensorDict( + { + "agents": TensorDict( + {"observation": next_obs}, + [*batch, n_agents], + device=device, + ), + "state": next_state, + "reward": reward, + "done": done, + }, + batch_size=batch, + device=device, + ), + }, + batch_size=batch, + device=device, + ) + td.set(action_key, action) + td.set(action_value_key, action_value) + if T is not None: + td.refine_names(None, "time") + return td + + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_qmixer(self, delay_value, device, action_spec_type, td_est): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + mixer = self._create_mock_mixer(device=device) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, device=device + ) + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + if td_est is ValueEstimators.GAE: + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + sum([item for _, item in loss.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_local_value_network_params.clone() + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) 
+ target_value2 = loss_fn.target_local_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # Check param update effect on targets + target_value = loss_fn.target_mixer_network_params.clone() + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_value2 = loss_fn.target_mixer_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + mixer = self._create_mock_mixer(device=device) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, T=4, device=device + ) + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + ms_td = ms(td.clone()) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*td.keys(True, True))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_local_value_network_params.clone() + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_value2 = loss_fn.target_local_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # Check param update effect on targets + target_value = loss_fn.target_mixer_network_params.clone() + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_value2 = loss_fn.target_mixer_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_qmix_tensordict_keys(self, td_est): + 
torch.manual_seed(self.seed) + action_spec_type = "one_hot" + actor = self._create_mock_actor(action_spec_type=action_spec_type) + mixer = self._create_mock_mixer() + loss_fn = QMixerLoss(actor, mixer) + + default_keys = { + "advantage": "advantage", + "value_target": "value_target", + "local_value": ("agents", "chosen_action_value"), + "global_value": "chosen_action_value", + "priority": "td_error", + "action_value": ("agents", "action_value"), + "action": ("agents", "action"), + "reward": "reward", + "done": "done", + } + + self.tensordict_keys_test(loss_fn, default_keys=default_keys) + + loss_fn = QMixerLoss(actor, mixer) + key_mapping = { + "advantage": ("advantage", "advantage_2"), + "value_target": ("value_target", ("value_target", "nested")), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + actor = self._create_mock_actor( + action_spec_type=action_spec_type, + ) + mixer = self._create_mock_mixer( + global_chosen_action_value_key=("some", "nested") + ) + loss_fn = QMixerLoss(actor, mixer) + key_mapping = { + "global_value": ("value", ("some", "nested")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot")) + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_qmix_tensordict_run(self, action_spec_type, td_est): + torch.manual_seed(self.seed) + tensor_keys = { + "action_value": ("other", "action_value_test"), + "action": ("other", "action"), + "local_value": ("some", "local_v"), + "global_value": "global_v", + "priority": "priority_test", + } + actor = self._create_mock_actor( + action_spec_type=action_spec_type, + action_value_key=tensor_keys["action_value"], + action_key=tensor_keys["action"], + chosen_action_value_key=tensor_keys["local_value"], + ) + mixer = self._create_mock_mixer( + chosen_action_value_key=tensor_keys["local_value"], + global_chosen_action_value_key=tensor_keys["global_value"], + ) + td = self._create_mock_data_dqn( + action_spec_type=action_spec_type, + action_key=tensor_keys["action"], + action_value_key=tensor_keys["action_value"], + ) + + loss_fn = QMixerLoss(actor, mixer, loss_function="l2") + loss_fn.set_keys(**tensor_keys) + + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with _check_td_steady(td): + _ = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + @pytest.mark.parametrize( + "mixer_local_chosen_action_value_key", + [("agents", "chosen_action_value"), ("other")], + ) + @pytest.mark.parametrize( + "mixer_global_chosen_action_value_key", + ["chosen_action_value", ("nested", "other")], + ) + def test_mixer_keys( + self, + mixer_local_chosen_action_value_key, + mixer_global_chosen_action_value_key, + n_agents=4, + obs_dim=3, + ): + torch.manual_seed(0) + actor = self._create_mock_actor( + action_spec_type="categorical", + ) + mixer = self._create_mock_mixer( + chosen_action_value_key=mixer_local_chosen_action_value_key, + global_chosen_action_value_key=mixer_global_chosen_action_value_key, + n_agents=n_agents, + ) + + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.zeros(32, n_agents, obs_dim)}, [32, n_agents] + ), + "state": torch.zeros(32, 64, 64, 3), + "next": TensorDict( + { + "agents": TensorDict( + {"observation": torch.zeros(32, n_agents, obs_dim)}, + [32, n_agents], + ), + "state": torch.zeros(32, 64, 64, 3), + 
"reward": torch.zeros(32, 1), + "done": torch.zeros(32, 1, dtype=torch.bool), + }, + [32], + ), + }, + [32], + ) + td = actor(td) + + loss = QMixerLoss(actor, mixer) + + # Wthout etting the keys + if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): + with pytest.raises(RuntimeError): + loss(td) + elif unravel_key(mixer_global_chosen_action_value_key) != "chosen_action_value": + with pytest.raises( + KeyError, match='key "chosen_action_value" not found in TensorDict' + ): + loss(td) + else: + loss(td) + + loss = QMixerLoss(actor, mixer) + # When setting the key + loss.set_keys(global_value=mixer_global_chosen_action_value_key) + if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): + with pytest.raises( + RuntimeError + ): # The mixer in key still does not match the actor out_key + loss(td) + else: + loss(td) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) @@ -2699,7 +3118,6 @@ def test_discrete_sac( target_entropy, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_sac(device=device) @@ -3247,7 +3665,6 @@ def _create_seq_mock_data_redq( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -3342,7 +3759,6 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): @pytest.mark.parametrize("separate_losses", [False, True]) def test_redq_separate_losses(self, separate_losses): - torch.manual_seed(self.seed) actor, qvalue, common, td = self._create_mock_common_layer_setup() @@ -3431,7 +3847,6 @@ def test_redq_separate_losses(self, separate_losses): @pytest.mark.parametrize("separate_losses", [False, True]) def test_redq_deprecated_separate_losses(self, separate_losses): - torch.manual_seed(self.seed) actor, qvalue, common, td = self._create_mock_common_layer_setup() @@ -3520,7 +3935,6 @@ def test_redq_deprecated_separate_losses(self, separate_losses): @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) def test_redq_shared(self, delay_qvalue, num_qvalue, device): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -3585,7 +3999,6 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): - torch.manual_seed(self.seed) td = self._create_mock_data_redq(device=device) @@ -4111,7 +4524,6 @@ def test_cql_batcher( with_lagrange, device, ): - torch.manual_seed(self.seed) td = self._create_seq_mock_data_cql(device=device) @@ -6454,7 +6866,6 @@ def test_iql( expectile, td_est, ): - torch.manual_seed(self.seed) td = self._create_mock_data_iql(device=device) @@ -7048,7 +7459,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: # total dist d0 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7064,7 +7475,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: for i in range(value_network_update_interval + 1): # test that no update is occuring until value_network_update_interval d1 = 0.0 - for (key, source_val) in 
upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7079,7 +7490,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: assert upd.counter == 0 # test that a new update has occured d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7092,7 +7503,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: elif mode == "soft": upd.step() d1 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) @@ -7105,7 +7516,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: upd.init_() upd.step() d2 = 0.0 - for (key, source_val) in upd._sources.items(True, True): + for key, source_val in upd._sources.items(True, True): if not isinstance(key, tuple): key = (key,) key = ("target_" + key[0], *key[1:]) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 4500406a52f..a88fa07df0d 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -123,6 +123,7 @@ class QMixerLoss(LossModule): ... ), ... }, ... [32], + ... ) >>> loss(qnet(td)) TensorDict( fields={ From 2df23a4725fbe9ea3f1fe1446e0f45814bfd4b51 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 11 Jul 2023 15:13:52 +0100 Subject: [PATCH 03/27] docs Signed-off-by: Matteo Bettini --- docs/source/reference/modules.rst | 14 ++++++++++++++ docs/source/reference/objectives.rst | 15 +++++++++++++++ torchrl/modules/__init__.py | 4 ++++ 3 files changed, 33 insertions(+) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 56b1bfc7fea..0d9e23929a5 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -335,6 +335,20 @@ algorithms, such as DQN, DDPG or Dreamer. RSSMPrior RSSMPosterior +Multi-agent-specific modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +These networks implement models that can be used in +multi-agent contexts. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + MultiAgentMLP + QMixer + VDNMixer + Exploration ----------- diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ed2d5c3cff7..3325cd05fd6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -185,6 +185,21 @@ Dreamer DreamerModelLoss DreamerValueLoss +Multi-agent objectives +---------------------- +.. currentmodule:: torchrl.objectives.multiagent + +These objectives are specific to multi-agent algorithms. + +QMixer +~~~~~~ + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QMixerLoss + Returns ------- diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 51235216f4e..ebb73bcedf6 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -25,15 +25,18 @@ DuelingCnnDQNet, LSTMNet, MLP, + MultiAgentMLP, NoisyLazyLinear, NoisyLinear, ObsDecoder, ObsEncoder, + QMixer, reset_noise, RSSMPosterior, RSSMPrior, Squeeze2dLayer, SqueezeLayer, + VDNMixer, ) from .tensordict_module import ( Actor, @@ -58,6 +61,7 @@ SafeSequential, TanhModule, ValueOperator, + VmapModule, WorldModelWrapper, ) from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip From 010bccffdfd24052193419b4aa1010ecbd4b76ee Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 08:26:36 +0100 Subject: [PATCH 04/27] import Signed-off-by: Matteo Bettini --- torchrl/objectives/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index a1d6bce40e9..163365bdc75 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -10,7 +10,7 @@ from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss from .iql import IQLLoss -from .multiagent import * +from .multiagent import QMixerLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq import REDQLoss from .reinforce import ReinforceLoss From e678547dab5d8162a2ae9ed100944540c78070cc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 08:41:48 +0100 Subject: [PATCH 05/27] amend Signed-off-by: Matteo Bettini --- torchrl/objectives/dqn.py | 26 +++++----- torchrl/objectives/multiagent/qmixer.py | 67 ++++++++++++------------- 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index d7407c89c73..9b6a98de48f 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import warnings from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase @@ -40,7 +40,8 @@ class DQNLoss(LossModule): value_network (QValueActor or nn.Module): a Q value operator. Keyword Args: - loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + Defaults to "l2". delay_value (bool, optional): whether to duplicate the value network into a new target value network to create a double DQN. Default is ``False``. @@ -155,13 +156,12 @@ def __init__( self, value_network: Union[QValueActor, nn.Module], *, - loss_function: str = "l2", + loss_function: Optional[str] = "l2", delay_value: bool = False, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, ) -> None: - super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -282,16 +282,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a tensor containing the DQN loss. 
""" - device = self.device if self.device is not None else tensordict.device - tddevice = tensordict.to(device) - - td_copy = tddevice.clone(False) + td_copy = tensordict.clone(False) self.value_network( td_copy, params=self.value_network_params, ) - action = tddevice.get(self.tensor_keys.action) + action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get(self.tensor_keys.action_value) if self.action_space == "categorical": @@ -304,13 +301,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: pred_val_index = (pred_val * action).sum(-1) target_value = self.value_estimator.value_estimate( - tddevice.clone(False), target_params=self.target_value_network_params + td_copy, target_params=self.target_value_network_params ).squeeze(-1) - priority_tensor = (pred_val_index - target_value).pow(2) - priority_tensor = priority_tensor.detach().unsqueeze(-1) - if tddevice.device is not None: - priority_tensor = priority_tensor.to(tddevice.device) + with torch.no_grad(): + priority_tensor = (pred_val_index - target_value).pow(2) + priority_tensor = priority_tensor.unsqueeze(-1) + if tensordict.device is not None: + priority_tensor = priority_tensor.to(tensordict.device) tensordict.set( self.tensor_keys.priority, diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index a88fa07df0d..a711888a7af 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -8,7 +8,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch from tensordict import TensorDict, TensorDictBase @@ -49,27 +49,28 @@ class QMixerLoss(LossModule): dimension to be the penultimate one. Args: - local_value_network (QValueActor or nn.Module): a local Q value operator. - mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values - and an optional state to the global Q value. - It is suggested to provide a TensorDictModule wrapping a mixer from `torchrl.modules.models.multiagent.Mixer`. + local_value_network (QValueActor or nn.Module): a local Q value operator. + mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values + and an optional state to the global Q value. It is suggested to provide a TensorDictModule + wrapping a mixer from `torchrl.modules.models.multiagent.Mixer`. Keyword Args: - loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". - delay_value (bool, optional): whether to duplicate the value network - into a new target value network to - create a double DQN. Default is ``False``. - action_space (str or TensorSpec, optional): Action space. Must be one of - ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). - If not provided, an attempt to retrieve it from the value network - will be made. - priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] - The key at which priority is assumed to be stored within TensorDicts added - to this ReplayBuffer. This is to be used when the sampler is of type - :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. 
+ loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + Defaults to "l2". + delay_value (bool, optional): whether to duplicate the value network + into a new target value network to + create a double DQN. Default is ``False``. + action_space (str or TensorSpec, optional): Action space. Must be one of + ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, + or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, + :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, + :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + If not provided, an attempt to retrieve it from the value network + will be made. + priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + The key at which priority is assumed to be stored within TensorDicts added + to this ReplayBuffer. This is to be used when the sampler is of type + :class:`~torchrl.data.PrioritizedSampler`. Defaults to ``"td_error"``. Examples: >>> import torch @@ -181,7 +182,7 @@ def __init__( local_value_network: Union[QValueActor, nn.Module], mixer_network: Union[TensorDictModule, nn.Module], *, - loss_function: str = "l2", + loss_function: Optional[str] = "l2", delay_value: bool = False, gamma: float = None, action_space: Union[str, TensorSpec] = None, @@ -319,16 +320,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDict: - device = self.device if self.device is not None else tensordict.device - tddevice = tensordict.to(device) - - td_copy = tddevice.clone(False) + td_copy = tensordict.clone(False) self.local_value_network( td_copy, params=self.local_value_network_params, ) - action = tddevice.get(self.tensor_keys.action) + action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get( self.tensor_keys.action_value ) # [*B, n_agents, n_actions] @@ -348,16 +346,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: # [*B] this is global and shared among the agents as will be the target target_value = self.value_estimator.value_estimate( - tddevice.clone(False), + td_copy, target_params=self._cached_target_params, - ).squeeze( - -1 - ) # [*B] - - priority_tensor = (pred_val_index - target_value).pow(2) - priority_tensor = priority_tensor.detach().unsqueeze(-1) - if tddevice.device is not None: - priority_tensor = priority_tensor.to(tddevice.device) + ).squeeze(-1) + + with torch.no_grad(): + priority_tensor = (pred_val_index - target_value).pow(2) + priority_tensor = priority_tensor.unsqueeze(-1) + if tensordict.device is not None: + priority_tensor = priority_tensor.to(tensordict.device) tensordict.set( self.tensor_keys.priority, From f6047bfb6e4e7dfeb4ff2964b664686b9e653423 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 10:01:07 +0100 Subject: [PATCH 06/27] docs Signed-off-by: Matteo Bettini --- setup.py | 1 - torchrl/modules/models/multiagent.py | 158 ++++++++++++++++++++++++--- 2 files changed, 141 insertions(+), 18 deletions(-) diff --git a/setup.py b/setup.py index d162ee6164e..d8e0820c30f 100644 --- a/setup.py +++ b/setup.py @@ -254,5 +254,4 @@ def _main(argv): if __name__ == "__main__": - _main(sys.argv[1:]) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index c015bd6910d..668e9704348 100644 --- a/torchrl/modules/models/multiagent.py +++ 
b/torchrl/modules/models/multiagent.py @@ -10,6 +10,8 @@ import torch from torch import nn +from ...data import DEVICE_TYPING + from .models import MLP @@ -30,8 +32,112 @@ class MultiAgentMLP(nn.Module): (n_agent_inputs * n_agents will be the nu,ber of inputs for one agent). Otherwise, each agent will only use its data as input. - **kwargs for :class:`~torchrl.modules.models.MLP` can be passed to customize the MLPs. + Args: + n_agent_inputs (int): number of inputs for each agent. + n_agent_outputs (int): number of outputs for each agent. + n_agents (int): number of agents. + centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output + (n_agent_inputs * n_agents will be the nu,ber of inputs for one agent). + Otherwise, each agent will only use its data as input. + share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass + for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process + its input (heterogeneous policies). + device (str or toech.device, optional): device to create the module on. + depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the + desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated, + the depth information should be contained in the num_cells argument (see below). If num_cells is an + iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. + default: 3. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers out_features will match the content of num_cells. + default: 32. + activation_class (Type[nn.Module]): activation class to be used. + default: nn.Tanh. + **kwargs: for :class:`~torchrl.modules.models.MLP` can be passed to customize the MLPs. + Examples: + >>> from torchrl.modules import MultiAgentMLP + >>> import torch + >>> n_agents = 6 + >>> n_agent_inputs=3 + >>> n_agent_outputs=2 + >>> batch = 64 + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) + + First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=False, + ... share_params=True, + ... depth=2, + ... ) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0): MLP( + (0): Linear(in_features=3, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) + + Now let's instantiate a centralised network shared by all agents (e.g. a centalised value function) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=True, + ... share_params=True, + ... depth=2, + ... 
) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0): MLP( + (0): Linear(in_features=18, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + We can see that the input to the first layer is n_agents * n_agent_inputs, + this is because in the case the net acts as a centralised mlp (like a single huge agent) + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) + Outputs will be identical for all agents + + Now we can do both examples just shown but with an independent set of parameters for each agent + Let's show the centralised=False case. + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralised=False, + ... share_params=False, + ... depth=2, + ... ) + >>> print(mlp) + MultiAgentMLP( + (agent_networks): ModuleList( + (0-5): 6 x MLP( + (0): Linear(in_features=3, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=2, bias=True) + ) + ) + ) + We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent! + >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) """ def __init__( @@ -41,10 +147,10 @@ def __init__( n_agents: int, centralised: bool, share_params: bool, - device: Union[torch.device, str], + device: Optional[DEVICE_TYPING] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, - activation_class: Type[nn.Module] = nn.Tanh, + activation_class: Optional[Type[nn.Module]] = nn.Tanh, **kwargs, ): super().__init__() @@ -74,8 +180,9 @@ def __init__( def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: if len(inputs) > 1: - inputs = (torch.cat([*inputs], -1),) - inputs = inputs[0] + inputs = torch.cat([*inputs], -1) + else: + inputs = inputs[0] if inputs.shape[-2:] != (self.n_agents, self.n_agent_inputs): raise ValueError( @@ -136,9 +243,9 @@ class Mixer(nn.Module): Args: n_agents (int): number of agents, - device (str or torch.Device): torch device for the network needs_state (bool): whether the mixer takes a global state as input state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions) + device (str or torch.Device): torch device for the network Examples: Creating a VDN mixer @@ -232,9 +339,9 @@ class Mixer(nn.Module): def __init__( self, n_agents: int, - device, needs_state: bool, state_shape: Union[Tuple[int, ...], torch.Size], + device: DEVICE_TYPING, ): super().__init__() @@ -248,9 +355,9 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: Args: *inputs: The first input should be the value of the chosen action of shape (*B, self.n_agents, 1), - representing the local q value of each agent. - The second input (optional, used only in some mixers) - is the shared state of all agents of shape (*B, *self.state_shape). + representing the local q value of each agent. + The second input (optional, used only in some mixers) + is the shared state of all agents of shape (*B, *self.state_shape). Returns: The global value of the chosen actions obtained after mixing, with shape (*B, 1) @@ -307,11 +414,18 @@ def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): class VDNMixer(Mixer): - """Mixer from https://arxiv.org/abs/1706.05296 . + """Value-Decomposition Network mixer. 
+ + Mixes the local Q values of the agents into a global Q value by summing them together. + From the paper https://arxiv.org/abs/1706.05296 . + + Args: + n_agents (int): number of agents, + device (str or torch.Device): torch device for the network Examples: Creating a VDN mixer - >>> import torch + >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> from torchrl.modules.models.multiagent import VDNMixer @@ -355,7 +469,7 @@ class VDNMixer(Mixer): def __init__( self, n_agents: int, - device, + device: DEVICE_TYPING, ): super().__init__( needs_state=False, @@ -369,7 +483,17 @@ def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): class QMixer(Mixer): - """Mixer from https://arxiv.org/abs/1803.11485 . + """QMix mixer. + + Mixes the local Q values of the agents into a global Q value through a monotonic + hyper-network whose parameters are obtained from a global state. + From the paper https://arxiv.org/abs/1803.11485 . + + Args + n_agents (int): number of agents + mixing_embed_dim (int): the size of the mixing embedded dimension + state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions) + device (str or torch.Device): torch device for the network Examples: Creating a QMix mixer @@ -420,10 +544,10 @@ class QMixer(Mixer): def __init__( self, - state_shape, - mixing_embed_dim, + state_shape: Union[Tuple[int, ...], torch.Size], + mixing_embed_dim: int, n_agents: int, - device, + device: DEVICE_TYPING, ): super().__init__( needs_state=True, state_shape=state_shape, n_agents=n_agents, device=device From b0b4d2210a03f0cd931dc661f9ddbca09ed9255d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 10:04:04 +0100 Subject: [PATCH 07/27] docs Signed-off-by: Matteo Bettini --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index d8e0820c30f..3723c1b1981 100644 --- a/setup.py +++ b/setup.py @@ -235,6 +235,7 @@ def _main(argv): "checkpointing": [ "torchsnapshot", ], + "marl": ["vmas"], }, zip_safe=False, classifiers=[ From 073a44d1c231eff35ae880ebd81d9fb42a1ce711 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 11:02:12 +0100 Subject: [PATCH 08/27] tests mlp Signed-off-by: Matteo Bettini --- test/test_modules.py | 98 +++++++++++++++++++++++++++- torchrl/modules/models/multiagent.py | 2 +- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 2481ec09f69..8b8bb33b685 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -14,7 +14,14 @@ from tensordict import TensorDict from torch import nn from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec -from torchrl.modules import CEMPlanner, LSTMNet, SafeModule, TanhModule, ValueOperator +from torchrl.modules import ( + CEMPlanner, + LSTMNet, + MultiAgentMLP, + SafeModule, + TanhModule, + ValueOperator, +) from torchrl.modules.distributions.utils import safeatanh, safetanh from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear from torchrl.modules.models.model_based import ( @@ -200,7 +207,6 @@ def test_lstm_net( has_precond_hidden, double_prec_fixture, ): - torch.manual_seed(0) batch = 5 time_steps = 6 @@ -708,6 +714,94 @@ def test_multi_inputs(self, out_keys, has_spec): assert (data[out_key] >= min - eps).all() +class TestMultiAgent: + def _get_mock_input_td( + self, n_agents, n_agents_inputs, state_shape=(64, 64, 3), T=None, batch=(2,) + ): + if T is not None: + batch = 
batch + (T,) + obs = torch.randn(*batch, n_agents, n_agents_inputs) + state = torch.randn(*batch, *state_shape) + + td = TensorDict( + { + "agents": TensorDict( + {"observation": obs}, + [*batch, n_agents], + ), + "state": state, + }, + batch_size=batch, + ) + return td + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralised", [True, False]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + tuple, + ], + ) + def test_mlp( + self, + n_agents, + centralised, + share_params, + batch, + n_agent_inputs=6, + n_agent_outputs=2, + ): + torch.manual_seed(0) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralised=centralised, + share_params=share_params, + depth=2, + ) + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", "observation")) + + out = mlp(obs) + assert out.shape == (*batch, n_agents, n_agent_outputs) + for i in range(n_agents): + if centralised and share_params: + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + obs[..., 0, 0] += 1 + out2 = mlp(obs) + for i in range(n_agents): + if centralised: + # a modification to the input of agent 0 will impact all agents + assert not torch.allclose(out[..., i, :], out2[..., i, :]) + elif i > 0: + assert torch.allclose(out[..., i, :], out2[..., i, :]) + + obs = torch.randn(*batch, 1, n_agent_inputs).expand( + *batch, n_agents, n_agent_inputs + ) + out = mlp(obs) + for i in range(n_agents): + if share_params: + # same input same output + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + # same input different output + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + @pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") @pytest.mark.parametrize("use_vmap", [False, True]) @pytest.mark.parametrize("scale", range(10)) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 668e9704348..c480cbff196 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -192,7 +192,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: # If the model is centralized, agents have full observability if self.centralised: - inputs = inputs.view( + inputs = inputs.reshape( *inputs.shape[:-2], self.n_agents * self.n_agent_inputs ) From fb4c059eae4be4be72b2c5d0a7601249af4b34dc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 11:03:26 +0100 Subject: [PATCH 09/27] tests mlp Signed-off-by: Matteo Bettini --- test/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_modules.py b/test/test_modules.py index 8b8bb33b685..cb1c5ebdd27 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -746,7 +746,7 @@ def _get_mock_input_td( 10, 3, ), - tuple, + (), ], ) def test_mlp( From d8b28c5ce8c3fabd93b0a3ae4b9be30396fea8ee Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 11:31:00 +0100 Subject: [PATCH 10/27] tests mixers Signed-off-by: Matteo Bettini --- test/test_modules.py | 121 +++++++++++++++++++++++++++ torchrl/modules/models/multiagent.py | 2 +- 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/test/test_modules.py b/test/test_modules.py index cb1c5ebdd27..caa4cca1c9b 100644 --- 
a/test/test_modules.py +++ b/test/test_modules.py @@ -18,9 +18,11 @@ CEMPlanner, LSTMNet, MultiAgentMLP, + QMixer, SafeModule, TanhModule, ValueOperator, + VDNMixer, ) from torchrl.modules.distributions.utils import safeatanh, safetanh from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear @@ -801,6 +803,125 @@ def test_mlp( # same input different output assert not torch.allclose(out[..., i, :], out[..., j, :]) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + (), + ], + ) + def test_vdn(self, n_agents, batch): + torch.manual_seed(0) + mixer = VDNMixer(n_agents=n_agents, device="cpu") + + td = self._get_mock_input_td(n_agents, batch=batch, n_agents_inputs=1) + obs = td.get(("agents", "observation")) + assert obs.shape == (*batch, n_agents, 1) + out = mixer(obs) + assert out.shape == (*batch, 1) + assert torch.equal(obs.sum(-2), out) + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize( + "batch", + [ + (10,), + ( + 10, + 3, + ), + (), + ], + ) + @pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)]) + def test_qmix(self, n_agents, batch, state_shape): + torch.manual_seed(0) + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + assert obs.shape == (*batch, n_agents, 1) + assert state.shape == (*batch, *state_shape) + out = mixer(obs, state) + assert out.shape == (*batch, 1) + + @pytest.mark.parametrize("mixer", ["qmix", "vdn"]) + def test_mixer_malformed_input( + self, mixer, n_agents=3, batch=(32,), state_shape=(64, 64, 3) + ): + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=3, state_shape=state_shape + ) + if mixer == "qmix": + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + else: + mixer = VDNMixer(n_agents=n_agents, device="cpu") + obs = td.get(("agents", "observation")) + state = td.get("state") + + if mixer.needs_state: + with pytest.raises( + ValueError, + match="Mixer that needs state was passed more than 2 inputs", + ): + mixer(obs) + else: + with pytest.raises( + ValueError, + match="Mixer that doesn't need state was passed more than 1 input", + ): + mixer(obs, state) + + in_put = [obs, state] if mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + if mixer.needs_state: + state_diff = state.unsqueeze(-1) + with pytest.raises( + ValueError, + match="Mixer network expected state with ending shape", + ): + mixer(obs, state_diff) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + obs = obs.sum(-2) + in_put = [obs, state] if mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + + obs = td.get(("agents", "observation")) + state = td.get("state") + in_put = [obs, state] if mixer.needs_state else [obs] + mixer(*in_put) + @pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") @pytest.mark.parametrize("use_vmap", [False, True]) diff --git a/torchrl/modules/models/multiagent.py 
b/torchrl/modules/models/multiagent.py index c480cbff196..8df7e880525 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -370,7 +370,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: ) chosen_action_value = inputs[0] else: - if len(inputs) > 2: + if len(inputs) != 2: raise ValueError("Mixer that needs state was passed more than 2 inputs") chosen_action_value, state = inputs From 9a7b21629bd85ab449cc2f4350b930bdbc462143 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 12 Jul 2023 11:49:00 +0100 Subject: [PATCH 11/27] device dqn Signed-off-by: Matteo Bettini --- torchrl/objectives/dqn.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 9b6a98de48f..05e8ac0913f 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -282,13 +282,23 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a tensor containing the DQN loss. """ - td_copy = tensordict.clone(False) + if self.device is not None: + warnings.warn( + "The use of a device for the objective function will soon be deprecated", + category=DeprecationWarning, + ) + device = self.device + else: + device = tensordict.device + tddevice = tensordict.to(device) + + td_copy = tddevice.clone(False) self.value_network( td_copy, params=self.value_network_params, ) - action = tensordict.get(self.tensor_keys.action) + action = tddevice.get(self.tensor_keys.action) pred_val = td_copy.get(self.tensor_keys.action_value) if self.action_space == "categorical": From ed7e5c5c6499bd582f71eb4f5ce4fe0aa854e000 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 13 Jul 2023 09:05:22 +0100 Subject: [PATCH 12/27] remove setter Signed-off-by: Matteo Bettini --- torchrl/objectives/dqn.py | 4 ---- torchrl/objectives/multiagent/qmixer.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 05e8ac0913f..d740d45507e 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -231,10 +231,6 @@ def in_keys(self): self._set_in_keys() return self._in_keys - @in_keys.setter - def in_keys(self, values): - self._in_keys = values - def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index a711888a7af..e3f4154748b 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -278,10 +278,6 @@ def in_keys(self): self._set_in_keys() return self._in_keys - @in_keys.setter - def in_keys(self, values): - self._in_keys = values - def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator From 01f57a73f7a20d2a17ce403ca389d13186cd2c7e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 13 Jul 2023 16:52:39 +0100 Subject: [PATCH 13/27] fix test gpu Signed-off-by: Matteo Bettini --- test/test_cost.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 4cc1019c424..29c69ddb670 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -733,7 +733,6 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", - is_nn_module=False, observation_key=("agents", "observation"), action_key=("agents", "action"), action_value_key=("agents", 
"action_value"), @@ -747,9 +746,8 @@ def _create_mock_actor( else: raise ValueError(f"Wrong {action_spec_type}") - module = nn.Linear(obs_dim, action_dim) - if is_nn_module: - return module.to(device) + module = nn.Linear(obs_dim, action_dim).to(device) + module = TensorDictModule( module, in_keys=[observation_key], From 3f6a35d044cd45e2e49aabb176ee771538f2762c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:34:33 +0100 Subject: [PATCH 14/27] fix Signed-off-by: Matteo Bettini --- test/test_cost.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 29c69ddb670..fd48e8d4766 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -752,7 +752,7 @@ def _create_mock_actor( module, in_keys=[observation_key], out_keys=[action_value_key], - ) + ).to(device) value_module = QValueModule( action_value_key=action_value_key, out_keys=[ @@ -762,7 +762,7 @@ def _create_mock_actor( ], spec=action_spec, action_space=None, - ) + ).to(device) actor = SafeSequential(module, value_module) return actor @@ -785,7 +785,7 @@ def _create_mock_mixer( ), in_keys=[chosen_action_value_key, state_key], out_keys=[global_chosen_action_value_key], - ) + ).to(device) return qmixer From 3931f7f9a70da2bbf7dc2e0d4d5c715f4a79dad8 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:44:21 +0100 Subject: [PATCH 15/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_cost.py b/test/test_cost.py index fd48e8d4766..531800e1a5d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -971,6 +971,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: + print(target_value, "\n\n", target_value2) assert not (target_value == target_value2).any() # check that policy is updated after parameter update From 41129ded122182d603cc295494e69fe905d737d0 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:48:55 +0100 Subject: [PATCH 16/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 531800e1a5d..631be50a4ce 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -966,7 +966,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += torch.rand_like(p) target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) From b8fd015e11dc9fe1a308df5374688a79d380da99 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:49:23 +0100 Subject: [PATCH 17/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 631be50a4ce..1963e03d519 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -966,7 +966,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.rand_like(p) + p.data += 3 target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: 
assert_allclose_td(target_value, target_value2) From 66db371820269171eead03a693add3315f98c5ed Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:51:35 +0100 Subject: [PATCH 18/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1963e03d519..9849fb09801 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -889,7 +889,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += 3 target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -899,7 +899,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += 3 target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -956,7 +956,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += 3 target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -971,7 +971,6 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: - print(target_value, "\n\n", target_value2) assert not (target_value == target_value2).any() # check that policy is updated after parameter update From eb2e441e6b38c0fbb1f68507055b639c2eca0b03 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:54:01 +0100 Subject: [PATCH 19/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 9849fb09801..eaacd44d1c3 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -899,11 +899,13 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + p.data += torch.randn_like(p) target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: + for key, value in target_value.items(): + print(value == target_value2[key]) assert not (target_value == target_value2).any() # check that policy is updated after parameter update From c61bd079345cb4ba3459c1f160d294585422779e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:55:39 +0100 Subject: [PATCH 20/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index eaacd44d1c3..a964b0c1d52 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -889,7 +889,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = 
loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += 3 + p.data += torch.randn_like(p) target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) From b94d23328d2a5bd601341687b3f7287232a933f5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 08:58:35 +0100 Subject: [PATCH 21/27] temp Signed-off-by: Matteo Bettini --- test/test_cost.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index a964b0c1d52..048455d5644 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -904,8 +904,8 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: - for key, value in target_value.items(): - print(value == target_value2[key]) + for key in target_value.keys(True, True): + print(target_value[key] == target_value2[key]) assert not (target_value == target_value2).any() # check that policy is updated after parameter update From e4705dbeb553cdafae000e542a91e29c4b65aeca Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 09:00:24 +0100 Subject: [PATCH 22/27] fix Signed-off-by: Matteo Bettini --- test/test_cost.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 048455d5644..9849fb09801 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -889,7 +889,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += 3 target_value2 = loss_fn.target_local_value_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) @@ -899,13 +899,11 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): # Check param update effect on targets target_value = loss_fn.target_mixer_network_params.clone() for p in loss_fn.parameters(): - p.data += torch.randn_like(p) + p.data += 3 target_value2 = loss_fn.target_mixer_network_params.clone() if loss_fn.delay_value: assert_allclose_td(target_value, target_value2) else: - for key in target_value.keys(True, True): - print(target_value[key] == target_value2[key]) assert not (target_value == target_value2).any() # check that policy is updated after parameter update From 9b53b934c00755c6964d00aa8501fcaf3c7df816 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 10:08:49 +0100 Subject: [PATCH 23/27] amend Signed-off-by: Matteo Bettini --- torchrl/objectives/multiagent/qmixer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index e3f4154748b..4fa124b0b8b 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -40,9 +40,10 @@ class QMixerLoss(LossModule): """The QMixer loss class. - Mixes local agent q values into a global q value accroding to a mixing network and then + Mixes local agent q values into a global q value according to a mixing network and then uses DQN updates on the global value. - This loss is for multi-agent applications, therefore it expects the 'local_value', 'action_value' and 'action' keys + This loss is for multi-agent applications. 
+ Therefore, it expects the 'local_value', 'action_value' and 'action' keys to have an agent dimension (this is visible in the dafault AcceptedKeys). This dimension will be mixed by the mixer which will compute a 'global_value' key, used for a DQN objective. The premade mixers of type :class:`~torchrl.modules.models.multiagent.Mixer` will expect the multi-agent @@ -52,7 +53,7 @@ class QMixerLoss(LossModule): local_value_network (QValueActor or nn.Module): a local Q value operator. mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values and an optional state to the global Q value. It is suggested to provide a TensorDictModule - wrapping a mixer from `torchrl.modules.models.multiagent.Mixer`. + wrapping a mixer from :class:`~torchrl.modules.models.multiagent.Mixer`. Keyword Args: loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". @@ -239,9 +240,9 @@ def __init__( raise ValueError(self.ACTION_SPEC_ERROR) if action_space is None: warnings.warn( - "action_space was not specified. DQNLoss will default to 'one-hot'." + "action_space was not specified. QMixerLoss will default to 'one-hot'." "This behaviour will be deprecated soon and a space will have to be passed." - "Check the DQNLoss documentation to see how to pass the action space. " + "Check the QMixerLoss documentation to see how to pass the action space. " ) action_space = "one-hot" From addc6115cf965909710dafe3c28b082e67ca5698 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 10:11:49 +0100 Subject: [PATCH 24/27] amend Signed-off-by: Matteo Bettini --- torchrl/objectives/multiagent/qmixer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 4fa124b0b8b..e9eca7ce293 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -46,14 +46,14 @@ class QMixerLoss(LossModule): Therefore, it expects the 'local_value', 'action_value' and 'action' keys to have an agent dimension (this is visible in the dafault AcceptedKeys). This dimension will be mixed by the mixer which will compute a 'global_value' key, used for a DQN objective. - The premade mixers of type :class:`~torchrl.modules.models.multiagent.Mixer` will expect the multi-agent + The premade mixers of type :class:`torchrl.modules.models.multiagent.Mixer` will expect the multi-agent dimension to be the penultimate one. Args: local_value_network (QValueActor or nn.Module): a local Q value operator. mixer_network (TensorDictModule or nn.Module): a mixer network mapping the agents' local Q values and an optional state to the global Q value. It is suggested to provide a TensorDictModule - wrapping a mixer from :class:`~torchrl.modules.models.multiagent.Mixer`. + wrapping a mixer from :class:`torchrl.modules.models.multiagent.Mixer`. Keyword Args: loss_function (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". 
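
For readers following the QMixerLoss docstring changes above, here is a minimal wiring sketch in the spirit of the tests added by this series: a parameter-shared local Q network, a QValueModule extracting each agent's chosen action value, and a TensorDictModule-wrapped mixer producing the global value. This is a sketch, not code from the patches; it assumes the ("agents", ...) key layout used in the tests and that QValueModule and SafeSequential are importable from torchrl.modules.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import MultiAgentMLP, QValueModule, SafeSequential, VDNMixer
from torchrl.objectives import QMixerLoss

n_agents, n_obs, n_actions, batch = 3, 6, 4, 32

# Per-agent Q network: shared parameters, decentralised inputs.
qnet = SafeSequential(
    TensorDictModule(
        MultiAgentMLP(
            n_agent_inputs=n_obs,
            n_agent_outputs=n_actions,
            n_agents=n_agents,
            centralised=False,
            share_params=True,
            depth=2,
        ),
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action_value")],
    ),
    # Greedy action selection plus the per-agent chosen action value.
    QValueModule(
        action_value_key=("agents", "action_value"),
        out_keys=[
            ("agents", "action"),
            ("agents", "action_value"),
            ("agents", "chosen_action_value"),
        ],
        spec=OneHotDiscreteTensorSpec(n_actions, shape=torch.Size([n_agents, n_actions])),
        action_space=None,
    ),
)

# Mixer: sums the local chosen values into a single global value.
mixer = TensorDictModule(
    VDNMixer(n_agents=n_agents, device="cpu"),
    in_keys=[("agents", "chosen_action_value")],
    out_keys=["chosen_action_value"],
)

td = TensorDict(
    {
        "agents": TensorDict(
            {"observation": torch.randn(batch, n_agents, n_obs)},
            [batch, n_agents],
        )
    },
    [batch],
)
td = mixer(qnet(td))  # local Q values -> chosen local values -> global value
assert td["chosen_action_value"].shape == (batch, 1)

# The loss combines both modules; a full loss call additionally needs ("next", ...)
# entries (global reward and done flag plus next observations), as set up in the tests.
loss_fn = QMixerLoss(qnet, mixer, action_space="one-hot")
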
From 137aa0d44e5576d5b3e5f594b359e4ee40daeb3c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 10:14:48 +0100 Subject: [PATCH 25/27] amend Signed-off-by: Matteo Bettini --- torchrl/modules/models/multiagent.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 8df7e880525..9385bd5f32e 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -63,8 +63,7 @@ class MultiAgentMLP(nn.Module): >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 - >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) - + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, @@ -87,7 +86,6 @@ class MultiAgentMLP(nn.Module): ) ) >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) - Now let's instantiate a centralised network shared by all agents (e.g. a centalised value function) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, @@ -112,8 +110,7 @@ class MultiAgentMLP(nn.Module): We can see that the input to the first layer is n_agents * n_agent_inputs, this is because in the case the net acts as a centralised mlp (like a single huge agent) >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) - Outputs will be identical for all agents - + Outputs will be identical for all agents. Now we can do both examples just shown but with an independent set of parameters for each agent Let's show the centralised=False case. >>> mlp = MultiAgentMLP( @@ -288,8 +285,6 @@ class Mixer(nn.Module): batch_size=torch.Size([32]), device=None, is_shared=False) - - Creating a QMix mixer >>> import torch >>> from tensordict import TensorDict @@ -424,7 +419,6 @@ class VDNMixer(Mixer): device (str or torch.Device): torch device for the network Examples: - Creating a VDN mixer >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule @@ -496,7 +490,6 @@ class QMixer(Mixer): device (str or torch.Device): torch device for the network Examples: - Creating a QMix mixer >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule From b15a517a0c08164f72e24df70682d4c87f7a0f1a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 10:24:54 +0100 Subject: [PATCH 26/27] docs Signed-off-by: Matteo Bettini --- torchrl/modules/models/multiagent.py | 40 +++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 9385bd5f32e..43c41980aaf 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -19,7 +19,7 @@ class MultiAgentMLP(nn.Module): """Mult-agent MLP. This is an MLP that can be used in multi-agent contexts. - For example as a policy or as a value function. + For example, as a policy or as a value function. See `examples/multiagent` for examples. It expects inputs with shape (*B, n_agents, n_agent_inputs) @@ -29,7 +29,7 @@ class MultiAgentMLP(nn.Module): Otherwise, each agent will use a different MLP to process its input (heterogeneous policies). If `centralised` is True, each agent will use the inputs of all agents to compute its output - (n_agent_inputs * n_agents will be the nu,ber of inputs for one agent). 
+ (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input. Args: @@ -37,7 +37,7 @@ class MultiAgentMLP(nn.Module): n_agent_outputs (int): number of outputs for each agent. n_agents (int): number of agents. centralised (bool): If `centralised` is True, each agent will use the inputs of all agents to compute its output - (n_agent_inputs * n_agents will be the nu,ber of inputs for one agent). + (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input. share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process @@ -54,7 +54,7 @@ class MultiAgentMLP(nn.Module): default: 32. activation_class (Type[nn.Module]): activation class to be used. default: nn.Tanh. - **kwargs: for :class:`~torchrl.modules.models.MLP` can be passed to customize the MLPs. + **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. Examples: >>> from torchrl.modules import MultiAgentMLP @@ -235,14 +235,14 @@ class Mixer(nn.Module): It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), into a global value with shape (*B, 1). - Used with the :class:`~torchrl.objectives.QMixerLoss`. + Used with the :class:`torchrl.objectives.QMixerLoss`. See `examples/multiagent/qmix_vdn.py` for examples. Args: - n_agents (int): number of agents, - needs_state (bool): whether the mixer takes a global state as input - state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions) - device (str or torch.Device): torch device for the network + n_agents (int): number of agents. + needs_state (bool): whether the mixer takes a global state as input. + state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions). + device (str or torch.Device): torch device for the network. Examples: Creating a VDN mixer @@ -414,9 +414,14 @@ class VDNMixer(Mixer): Mixes the local Q values of the agents into a global Q value by summing them together. From the paper https://arxiv.org/abs/1706.05296 . + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + Args: - n_agents (int): number of agents, - device (str or torch.Device): torch device for the network + n_agents (int): number of agents. + device (str or torch.Device): torch device for the network. Examples: >>> import torch @@ -483,11 +488,16 @@ class QMixer(Mixer): hyper-network whose parameters are obtained from a global state. From the paper https://arxiv.org/abs/1803.11485 . + It transforms the local value of each agent's chosen action of shape (*B, self.n_agents, 1), + into a global value with shape (*B, 1). + Used with the :class:`torchrl.objectives.QMixerLoss`. + See `examples/multiagent/qmix_vdn.py` for examples. + Args - n_agents (int): number of agents - mixing_embed_dim (int): the size of the mixing embedded dimension - state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions) - device (str or torch.Device): torch device for the network + n_agents (int): number of agents. + mixing_embed_dim (int): the size of the mixing embedded dimension. 
+ state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions). + device (str or torch.Device): torch device for the network. Examples: >>> import torch From fd2147867b9c92b80e8bbe047e5f278ec05db37a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 14 Jul 2023 10:28:01 +0100 Subject: [PATCH 27/27] docs Signed-off-by: Matteo Bettini --- torchrl/modules/models/multiagent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 43c41980aaf..de565b336d2 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -493,10 +493,10 @@ class QMixer(Mixer): Used with the :class:`torchrl.objectives.QMixerLoss`. See `examples/multiagent/qmix_vdn.py` for examples. - Args - n_agents (int): number of agents. - mixing_embed_dim (int): the size of the mixing embedded dimension. + Args: state_shape (tuple or torch.Size): the shape of the state (excluding eventual leading batch dimensions). + mixing_embed_dim (int): the size of the mixing embedded dimension. + n_agents (int): number of agents. device (str or torch.Device): torch device for the network. Examples:
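
To close, a small sketch of the mixer shape contract exercised by the new tests: local chosen-action values of shape (*B, n_agents, 1) are mixed into a global value of shape (*B, 1). VDNMixer sums over the agent dimension, while QMixer conditions its monotonic mixing on the global state. This is an illustration under the constructor signatures shown above, not code from the patches.

import torch
from torchrl.modules import QMixer, VDNMixer

batch, n_agents, state_shape = (32,), 3, (64, 64, 3)
local_q = torch.randn(*batch, n_agents, 1)   # per-agent chosen action value
state = torch.randn(*batch, *state_shape)    # global state, used by QMixer only

vdn = VDNMixer(n_agents=n_agents, device="cpu")
assert torch.equal(vdn(local_q), local_q.sum(-2))  # VDN is a plain sum over agents

qmix = QMixer(
    n_agents=n_agents,
    state_shape=state_shape,
    mixing_embed_dim=32,
    device="cpu",
)
assert qmix(local_q, state).shape == (*batch, 1)  # state-conditioned monotonic mixing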