diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 6d5107fcc64..9ce812019b6 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
-from tensordict.utils import NestedKey
+from tensordict.utils import expand_right, NestedKey
 
 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
@@ -1290,3 +1290,175 @@ def _step(
             device=self.device,
         )
         return tensordict.select().set("next", tensordict)
+
+
+class HeteroCountingEnv(EnvBase):
+    """A heterogeneous, counting Env."""
+
+    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
+        super().__init__(**kwargs)
+        self.n_agents = 3
+        self.max_steps = max_steps
+        self.start_val = start_val
+
+        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
+        count[:] = self.start_val
+
+        self.register_buffer("count", count)
+
+        agent_obs_specs = []
+        agent_action_specs = []
+        for agent_id in range(self.n_agents):
+            agent_obs_specs.append(self.get_agent_obs_spec(agent_id))
+            agent_action_specs.append(self.get_agent_action_spec(agent_id))
+        agent_obs_specs = torch.stack(agent_obs_specs, dim=0)
+        agent_action_specs = torch.stack(agent_action_specs, dim=0)
+
+        self.unbatched_observation_spec = CompositeSpec(
+            agents=agent_obs_specs,
+            state=UnboundedContinuousTensorSpec(
+                shape=(
+                    64,
+                    64,
+                    3,
+                )
+            ),
+        )
+
+        self.unbatched_action_spec = CompositeSpec(
+            agents=agent_action_specs,
+        )
+        self.unbatched_reward_spec = CompositeSpec(
+            {
+                "agents": CompositeSpec(
+                    {"reward": UnboundedContinuousTensorSpec(shape=(self.n_agents, 1))},
+                    shape=(self.n_agents,),
+                )
+            }
+        )
+        self.unbatched_done_spec = CompositeSpec(
+            {
+                "agents": CompositeSpec(
+                    {
+                        "done": DiscreteTensorSpec(
+                            n=2,
+                            shape=(self.n_agents, 1),
+                            dtype=torch.bool,
+                        ),
+                    },
+                    shape=(self.n_agents,),
+                )
+            }
+        )
+
+        self.action_spec = self.unbatched_action_spec.expand(
+            *self.batch_size, *self.unbatched_action_spec.shape
+        )
+        self.observation_spec = self.unbatched_observation_spec.expand(
+            *self.batch_size, *self.unbatched_observation_spec.shape
+        )
+        self.reward_spec = self.unbatched_reward_spec.expand(
+            *self.batch_size, *self.unbatched_reward_spec.shape
+        )
+        self.done_spec = self.unbatched_done_spec.expand(
+            *self.batch_size, *self.unbatched_done_spec.shape
+        )
+
+    def get_agent_obs_spec(self, i):
+        camera = BoundedTensorSpec(minimum=0, maximum=1, shape=(32, 32, 3))
+        vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
+        vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
+        lidar = BoundedTensorSpec(minimum=0, maximum=5, shape=(20,))
+
+        agent_0_obs = UnboundedContinuousTensorSpec(shape=(1,))
+        agent_1_obs = BoundedTensorSpec(minimum=0, maximum=3, shape=(1, 2))
+        agent_2_obs = UnboundedContinuousTensorSpec(shape=(1, 2, 3))
+
+        # Agents all have the same camera
+        # All have a vector entry, but with different shapes
+        # The first two have lidar, the last one does not
+        # All have a different key agent_i_obs with different n_dims
+        if i == 0:
+            return CompositeSpec(
+                {
+                    "camera": camera,
+                    "lidar": lidar,
+                    "vector": vector_3d,
+                    "agent_0_obs": agent_0_obs,
+                }
+            )
+        elif i == 1:
+            return CompositeSpec(
+                {
+                    "camera": camera,
+                    "lidar": lidar,
+                    "vector": vector_2d,
+                    "agent_1_obs": agent_1_obs,
+                }
+            )
+        elif i == 2:
+            return CompositeSpec(
+                {
+                    "camera": camera,
+                    "vector": vector_2d,
+                    "agent_2_obs": agent_2_obs,
+                }
+            )
+        else:
+            raise ValueError(f"Index {i} undefined for 3 agents")
+
+    def get_agent_action_spec(self, i):
+        force_3d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(3,))
+        force_2d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(2,))
+
+        # Some have 2d action and some 3d
+        # TODO Introduce composite heterogeneous actions
+        if i == 0:
+            ret = force_3d
+        elif i == 1:
+            ret = force_2d
+        elif i == 2:
+            ret = force_2d
+        else:
+            raise ValueError(f"Index {i} undefined for 3 agents")
+
+        return CompositeSpec({"action": ret})
+
+    def _reset(
+        self,
+        tensordict: TensorDictBase = None,
+        **kwargs,
+    ) -> TensorDictBase:
+        if tensordict is not None and "_reset" in tensordict.keys():
+            _reset = tensordict.get("_reset")
+            self.count[_reset] = self.start_val
+        else:
+            self.count[:] = self.start_val
+
+        reset_td = self.observation_spec.zero()
+        reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
+        reset_td.update(self.output_spec["_done_spec"].zero())
+
+        assert reset_td.batch_size == self.batch_size
+
+        return reset_td
+
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        td = self.observation_spec.zero()
+        self.count += 1
+        td.apply_(lambda x: x + expand_right(self.count, x.shape))
+        td.update(self.output_spec["_done_spec"].zero())
+        td.update(self.output_spec["_reward_spec"].zero())
+
+        assert td.batch_size == self.batch_size
+        td[self.done_key] = expand_right(
+            self.count > self.max_steps, self.done_spec.shape
+        )
+
+        return td.select().set("next", td)
+
+    def _set_seed(self, seed: Optional[int]):
+        torch.manual_seed(seed)
diff --git a/test/test_env.py b/test/test_env.py
index a52f140fcd7..4e1eb04b0a0 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -33,6 +33,7 @@
     DiscreteActionConvMockEnvNumpy,
     DiscreteActionVecMockEnv,
     DummyModelBasedEnvBase,
+    HeteroCountingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
     MockSerialEnv,
@@ -1417,7 +1418,6 @@ def test_batch_unlocked(device):
     env.step(td_expanded)
 
 
-@pytest.mark.parametrize("device", get_default_devices())
 def test_batch_unlocked_with_batch_size(device):
     env = MockBatchedUnLockedEnv(device, batch_size=torch.Size([2]))
     assert not env.batch_locked
@@ -1669,7 +1669,6 @@ def test_mp_collector(self, nproc):
 class TestNestedSpecs:
     @pytest.mark.parametrize("envclass", ["CountingEnv", "NestedCountingEnv"])
     def test_nested_env(self, envclass):
-
         if envclass == "CountingEnv":
             env = CountingEnv()
         elif envclass == "NestedCountingEnv":
@@ -1700,7 +1699,6 @@ def test_nested_env(self, envclass):
 
     @pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)])
     def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
-
         env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
 
         td_reset = env.reset()
@@ -1750,6 +1748,29 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
         )
 
 
+class TestHeteroEnvs:
+    @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
+    def test_reset(self, batch_size):
+        env = HeteroCountingEnv(batch_size=batch_size)
+        env.reset()
+
+    @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
+    def test_rand_step(self, batch_size):
+        env = HeteroCountingEnv(batch_size=batch_size)
+        td = env.reset()
+        assert (td["agents"][..., 0]["agent_0_obs"] == 0).all()
+        td = env.rand_step()
+        assert (td["next", "agents"][..., 0]["agent_0_obs"] == 1).all()
+        td = env.rand_step()
+        assert (td["next", "agents"][..., 1]["agent_1_obs"] == 2).all()
+
+    @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
+    def test_rollout_one(self, batch_size, rollout_steps=1):
+        env = HeteroCountingEnv(batch_size=batch_size)
+        td = env.rollout(rollout_steps)
+        td.get("agents")
+
+
 @pytest.mark.parametrize(
     "envclass",
     [
@@ -1768,6 +1789,7 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
         MockBatchedUnLockedEnv,
         MockSerialEnv,
         NestedCountingEnv,
+        HeteroCountingEnv,
     ],
 )
 def test_mocking_envs(envclass):
diff --git a/test/test_specs.py b/test/test_specs.py
index 10adac74bdc..6f6efdc1477 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -2147,8 +2147,10 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim):
 
     def test_to_numpy(self, shape, stack_dim):
         c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
-        c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+        c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
+
         c = torch.stack([c1, c2], stack_dim)
+
         torch.manual_seed(0)
 
         shape = list(shape)
@@ -2164,14 +2166,131 @@ def test_to_numpy(self, shape, stack_dim):
         with pytest.raises(AssertionError):
             c.to_numpy(val + 1, safe=True)
 
+    def test_malformed_stack(self, shape, stack_dim):
+        c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
+        c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+        with pytest.raises(RuntimeError, match="Dtypes differ"):
+            torch.stack([c1, c2], stack_dim)
+
+        c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+        c2 = UnboundedContinuousTensorSpec(shape=shape, dtype=torch.float32)
+        c3 = UnboundedDiscreteTensorSpec(shape=shape, dtype=torch.float32)
+        with pytest.raises(
+            RuntimeError,
+            match="Stacking specs cannot occur: Found more than one type of specs in the list.",
+        ):
+            torch.stack([c1, c2], stack_dim)
+            torch.stack([c3, c2], stack_dim)
+
+        c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+        c2 = BoundedTensorSpec(-1, 1, shape=shape + (3,), dtype=torch.float32)
+        with pytest.raises(RuntimeError, match="Ndims differ"):
+            torch.stack([c1, c2], stack_dim)
 
-class TestStackComposite:
+
+class TestDenseStackedCompositeSpecs:
     def test_stack(self):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
         c2 = c1.clone()
         c = torch.stack([c1, c2], 0)
         assert isinstance(c, CompositeSpec)
 
+
+class TestLazyStackedCompositeSpecs:
+    def _get_het_specs(self, stack_dim: int = 0, batch_size=()):
+        specs = []
+        for i in range(3):
+            specs.append(self._get_single_spec(i, batch_size=batch_size))
+        return torch.stack(specs, dim=stack_dim)
+
+    def _get_single_spec(self, i, batch_size=()):
+        camera = BoundedTensorSpec(minimum=0, maximum=1, shape=(*batch_size, 32, 32, 3))
+        vector_3d = UnboundedContinuousTensorSpec(
+            shape=(
+                *batch_size,
+                3,
+            )
+        )
+        vector_2d = UnboundedContinuousTensorSpec(
+            shape=(
+                *batch_size,
+                2,
+            )
+        )
+        lidar = BoundedTensorSpec(
+            minimum=0,
+            maximum=5,
+            shape=(
+                *batch_size,
+                20,
+            ),
+        )
+
+        agent_0_obs = CompositeSpec(
+            {
+                "agent_0_obs_0": UnboundedContinuousTensorSpec(
+                    shape=(
+                        *batch_size,
+                        3,
+                        1,
+                    )
+                )
+            },
+            shape=(*batch_size, 3),
+        )
+        agent_1_obs = CompositeSpec(
+            {
+                "agent_1_obs_0": BoundedTensorSpec(
+                    minimum=0, maximum=3, shape=(*batch_size, 3, 1, 2)
+                )
+            },
+            shape=(*batch_size, 3),
+        )
+        agent_2_obs = CompositeSpec(
+            {
+                "agent_1_obs_0": UnboundedContinuousTensorSpec(
+                    shape=(*batch_size, 3, 1, 2, 3)
+                )
+            },
+            shape=(*batch_size, 3),
+        )
+
+        # Agents all have the same camera
+        # All have a vector entry, but with different shapes
+        # The first two have lidar, the last one does not
+        # All have a different key agent_i_obs with different n_dims
+        if i == 0:
+            return CompositeSpec(
+                {
+                    "camera": camera,
+                    "lidar": lidar,
+                    "vector": vector_3d,
"agent_0_obs": agent_0_obs, + }, + shape=batch_size, + ) + elif i == 1: + return CompositeSpec( + { + "camera": camera, + "lidar": lidar, + "vector": vector_2d, + "agent_1_obs": agent_1_obs, + }, + shape=batch_size, + ) + elif i == 2: + return CompositeSpec( + { + "camera": camera, + "vector": vector_2d, + "agent_2_obs": agent_2_obs, + }, + shape=batch_size, + ) + else: + raise AssertionError() + def test_stack_index(self): c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) c2 = CompositeSpec( @@ -2428,6 +2547,253 @@ def test_to_numpy(self): with pytest.raises(AssertionError): c.to_numpy(td_fail, safe=True) + def test_unsqueeze(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 1) + for unsq in range(-2, 3): + cu = c.unsqueeze(unsq) + shape = list(c.shape) + new_unsq = unsq if unsq >= 0 else c.ndim + unsq + 1 + shape.insert(new_unsq, 1) + assert cu.shape == torch.Size(shape) + cus = cu.squeeze(unsq) + assert c.shape == cus.shape, unsq + assert cus == c + + assert c.squeeze().shape == torch.Size([2, 3]) + + c = self._get_het_specs() + cu = c.unsqueeze(0) + assert cu.shape == torch.Size([1, 3]) + cus = cu.squeeze(0) + assert cus == c + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 2)]) + def test_len(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + assert len(c) == c.shape[0] + assert len(c) == len(c.rand()) + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 2)]) + def test_eq(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + c2 = self._get_het_specs(batch_size=batch_size) + + assert c == c2 and not c != c2 + assert c == c.clone() and not c != c.clone() + + del c2["camera"] + assert not c == c2 and c != c2 + + c2 = self._get_het_specs(batch_size=batch_size) + del c2[0]["lidar"] + + assert not c == c2 and c != c2 + + c2 = self._get_het_specs(batch_size=batch_size) + c2[0]["lidar"].space.minimum += 1 + assert not c == c2 and c != c2 + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 2)]) + @pytest.mark.parametrize("include_nested", [True, False]) + @pytest.mark.parametrize("leaves_only", [True, False]) + def test_del(self, batch_size, include_nested, leaves_only): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + + keys = list(c.keys(include_nested=include_nested, leaves_only=leaves_only)) + for k in keys: + del c[k] + del td_c[k] + assert len(c.keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + assert ( + len(td_c.keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + ) + + keys = list(c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) + for k in keys: + del c[k] + del td_c[k] + assert ( + len(c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) == 0 + ) + assert ( + len(td_c[0].keys(include_nested=include_nested, leaves_only=leaves_only)) + == 0 + ) + with pytest.raises(KeyError): + del c["agent_1_obs_0"] + with pytest.raises(KeyError): + del td_c["agent_1_obs_0"] + + del c[("agent_1_obs", "agent_1_obs_0")] + del td_c[("agent_1_obs", "agent_1_obs_0")] + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 2)]) + def test_is_in(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + assert c.is_in(td_c) + + del td_c["camera"] + with pytest.raises(KeyError): + assert not c.is_in(td_c) + + td_c = c.rand() + del 
td_c[("agent_1_obs", "agent_1_obs_0")] + with pytest.raises(KeyError): + assert not c.is_in(td_c) + + td_c = c.rand() + td_c["camera"] += 1 + assert not c.is_in(td_c) + + td_c = c.rand() + td_c[1]["agent_1_obs", "agent_1_obs_0"] += 4 + assert not c.is_in(td_c) + + td_c = c.rand() + td_c[0]["agent_0_obs", "agent_0_obs_0"] += 1 + assert c.is_in(td_c) + + def test_type_check(self): + c = self._get_het_specs() + td_c = c.rand() + + c.type_check(td_c) + c.type_check(td_c["camera"], "camera") + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 2)]) + def test_project(self, batch_size): + c = self._get_het_specs(batch_size=batch_size) + td_c = c.rand() + assert c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + del td_c["camera"] + with pytest.raises(KeyError): + c.is_in(td_c) + + td_c = c.rand() + del td_c[("agent_1_obs", "agent_1_obs_0")] + with pytest.raises(KeyError): + c.is_in(td_c) + + td_c = c.rand() + td_c["camera"] += 1 + assert not c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + td_c = c.rand() + td_c[1]["agent_1_obs", "agent_1_obs_0"] += 4 + assert not c.is_in(td_c) + val = c.project(td_c) + assert c.is_in(val) + + td_c = c.rand() + td_c[0]["agent_0_obs", "agent_0_obs_0"] += 1 + assert c.is_in(td_c) + + def test_repr(self): + c = self._get_het_specs() + + expected = f"""LazyStackedCompositeSpec( + fields={{ + camera: BoundedTensorSpec( + shape=torch.Size([3, 32, 32, 3]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + vector: LazyStackedUnboundedContinuousTensorSpec( + shape=torch.Size([3, -1]), space=None, device=cpu, dtype=torch.float32, domain=continuous)}}, + lazy_fields={{ + 0 -> + lidar: BoundedTensorSpec( + shape=torch.Size([20]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + agent_0_obs: CompositeSpec( + agent_0_obs_0: UnboundedContinuousTensorSpec( + shape=torch.Size([3, 1]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3])), + 1 -> + lidar: BoundedTensorSpec( + shape=torch.Size([20]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + agent_1_obs: CompositeSpec( + agent_1_obs_0: BoundedTensorSpec( + shape=torch.Size([3, 1, 2]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3])), + 2 -> + agent_2_obs: CompositeSpec( + agent_1_obs_0: UnboundedContinuousTensorSpec( + shape=torch.Size([3, 1, 2, 3]), + space=None, + device=cpu, + dtype=torch.float32, + domain=continuous), device=cpu, shape=torch.Size([3]))}}, + device=cpu, + shape={torch.Size((3,))}, + stack_dim={c.stack_dim})""" + assert expected == repr(c) + + c = c[0:2] + del c["agent_0_obs"] + del 
c["agent_1_obs"] + expected = f"""LazyStackedCompositeSpec( + fields={{ + camera: BoundedTensorSpec( + shape=torch.Size([2, 32, 32, 3]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + lidar: BoundedTensorSpec( + shape=torch.Size([2, 20]), + space=ContinuousBox( + minimum=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True), + maximum=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + vector: LazyStackedUnboundedContinuousTensorSpec( + shape=torch.Size([2, -1]), space=None, device=cpu, dtype=torch.float32, domain=continuous)}}, + lazy_fields={{ + }}, + device=cpu, + shape={torch.Size((2,))}, + stack_dim={c.stack_dim})""" + assert expected == repr(c) + # MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080. @pytest.mark.parametrize( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4d69949b964..72560ec5666 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -32,8 +32,8 @@ import numpy as np import torch from tensordict import unravel_key -from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import _getitem_batch_size +from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -391,7 +391,7 @@ def __repr__(self): f"\nmaximum=Tensor(shape={self.maximum.shape}, device={self.maximum.device}, dtype={self.maximum.dtype}, contiguous={self.maximum.is_contiguous()})", " " * 4, ) - return f"{self.__class__.__name__}({min_str}, {max_str})" + return f"{self.__class__.__name__}({min_str},{max_str})" def __eq__(self, other): return ( @@ -839,46 +839,12 @@ def __getitem__(self, item): return out return torch.stack(list(out), 0) - @property - def shape(self): - shape = list(self._specs[0].shape) - dim = self.dim - if dim < 0: - dim = len(shape) + dim + 1 - shape.insert(dim, len(self._specs)) - return torch.Size(shape) - def clone(self) -> T: return torch.stack([spec.clone() for spec in self._specs], 0) - def expand(self, *shape): - if len(shape) == 1 and not isinstance(shape[0], (int,)): - return self.expand(*shape[0]) - expand_shape = shape[: -len(self.shape)] - existing_shape = self.shape - shape_check = shape[-len(self.shape) :] - for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)): - if size1 != size2 and size1 != 1: - raise RuntimeError( - f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}" - ) - elif size1 != size2 and size1 == 1 and _i == self.dim: - # if we're expanding along the stack dim we just need to clone the existing spec - return torch.stack( - [self._specs[0].clone() for _ in range(size2)], self.dim - ).expand(*shape) - if _i != len(self.shape) - 1: - raise RuntimeError( - f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}." 
-                )
-        # remove the stack dim from the expanded shape, which we know to match
-        unstack_shape = list(expand_shape) + [
-            s for i, s in enumerate(shape_check) if i != self.dim
-        ]
-        return torch.stack(
-            [spec.expand(unstack_shape) for spec in self._specs],
-            self.dim + len(expand_shape),
-        )
+    @property
+    def stack_dim(self):
+        return self.dim
 
     def zero(self, shape=None) -> TensorDictBase:
         if shape is not None:
@@ -914,11 +880,20 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
 
     @property
     def space(self):
-        return self._specs[0].space
+        raise NotImplementedError
 
     def __eq__(self, other):
-        # requires unbind to be implemented
-        pass
+        if not isinstance(other, LazyStackedTensorSpec):
+            return False
+        if len(self._specs) != len(other._specs):
+            return False
+        for _spec1, _spec2 in zip(self._specs, other._specs):
+            if _spec1 != _spec2:
+                return False
+        return True
+
+    def __len__(self):
+        return self.shape[0]
 
     def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
         if safe is None:
@@ -933,29 +908,23 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
             spec.assert_is_in(v)
         return val.detach().cpu().numpy()
 
-    def __len__(self):
-        pass
-
-    def project(self, val: TensorDictBase) -> TensorDictBase:
-        pass
+    def _project(self, val: TensorDictBase) -> TensorDictBase:
+        raise NotImplementedError
 
     def __repr__(self):
         shape_str = "shape=" + str(self.shape)
-        space_str = "space=" + str(self._specs[0].space)
         device_str = "device=" + str(self.device)
         dtype_str = "dtype=" + str(self.dtype)
         domain_str = "domain=" + str(self._specs[0].domain)
-        sub_string = ", ".join(
-            [shape_str, space_str, device_str, dtype_str, domain_str]
-        )
-        string = f"{self.__class__.__name__}(\n    {sub_string})"
+        sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str])
+        string = f"LazyStacked{self._specs[0].__class__.__name__}(\n    {sub_string})"
         return string
 
     def __iter__(self):
-        pass
+        raise NotImplementedError
 
     def __setitem__(self, key, value):
-        pass
+        raise NotImplementedError
 
     @property
     def device(self) -> DEVICE_TYPING:
@@ -979,6 +948,61 @@ def set(self, name, spec):
             )
         self._specs[name] = spec
 
+    def is_in(self, val) -> bool:
+        raise NotImplementedError
+
+    @property
+    def shape(self):
+        first_shape = self._specs[0].shape
+        shape = []
+        for i in range(len(first_shape)):
+            homo_dim = True
+            for spec in self._specs:
+                if spec.shape[i] != first_shape[i]:
+                    homo_dim = False
+                    break
+            shape.append(first_shape[i] if homo_dim else -1)
+
+        dim = self.dim
+        if dim < 0:
+            dim = len(shape) + dim + 1
+        shape.insert(dim, len(self._specs))
+        return torch.Size(shape)
+
+    def expand(self, *shape):
+        if len(shape) == 1 and not isinstance(shape[0], (int,)):
+            return self.expand(*shape[0])
+        expand_shape = shape[: -len(self.shape)]
+        existing_shape = self.shape
+        shape_check = shape[-len(self.shape) :]
+        for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)):
+            if size1 != size2 and size1 != 1:
+                raise RuntimeError(
+                    f"Expanding a non-singleton dimension: existing shape={size1} vs expand={size2}"
+                )
+            elif size1 != size2 and size1 == 1 and _i == self.dim:
+                # if we're expanding along the stack dim we just need to clone the existing spec
+                return torch.stack(
+                    [self._specs[0].clone() for _ in range(size2)], self.dim
+                ).expand(*shape)
+            if _i != len(self.shape) - 1:
+                raise RuntimeError(
+                    f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}."
+                )
+        # remove the stack dim from the expanded shape, which we know to match
+        shape_check = [s for i, s in enumerate(shape_check) if i != self.dim]
+        specs = []
+        for spec in self._specs:
+            spec_shape = []
+            for dim_check, spec_dim in zip(shape_check, spec.shape):
+                spec_shape.append(dim_check if dim_check != -1 else spec_dim)
+            unstack_shape = list(expand_shape) + list(spec_shape)
+            specs.append(spec.expand(unstack_shape))
+        return torch.stack(
+            specs,
+            self.dim + len(expand_shape),
+        )
 
 
 @dataclass(repr=False)
 class OneHotDiscreteTensorSpec(TensorSpec):
@@ -1031,7 +1055,6 @@ def __init__(
         dtype: Optional[Union[str, torch.dtype]] = torch.long,
         use_register: bool = False,
     ):
-
         dtype, device = _default_dtype_and_device(dtype, device)
         self.use_register = use_register
         space = DiscreteBox(n)
@@ -2756,14 +2779,14 @@ def type_check(
                 value = {selected_keys: value}
                 selected_keys = [selected_keys]
 
-        for _key in self:
+        for _key in self.keys():
             if self[_key] is not None and (
                 selected_keys is None or _key in selected_keys
             ):
                 self._specs[_key].type_check(value[_key], _key)
 
     def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
-        for (key, item) in self._specs.items():
+        for key, item in self._specs.items():
             if item is None:
                 continue
             if not item.is_in(val.get(key)):
@@ -3116,10 +3139,17 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
     """
 
     def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
-        pass
+        raise NotImplementedError
 
     def __eq__(self, other):
-        pass
+        if not isinstance(other, LazyStackedCompositeSpec):
+            return False
+        if len(self._specs) != len(other._specs):
+            return False
+        for _spec1, _spec2 in zip(self._specs, other._specs):
+            if _spec1 != _spec2:
+                return False
+        return True
 
     def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
         if safe is None:
@@ -3135,14 +3165,22 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
         return {key: self[key].to_numpy(val) for key, val in val.items()}
 
     def __len__(self):
-        pass
+        return self.shape[0]
 
-    def values(self):
-        for key in self.keys():
+    def values(
+        self,
+        include_nested: bool = False,
+        leaves_only: bool = False,
+    ):
+        for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
             yield self[key]
 
-    def items(self):
-        for key in self.keys():
+    def items(
+        self,
+        include_nested: bool = False,
+        leaves_only: bool = False,
+    ):
+        for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
             yield key, self[key]
 
     def keys(
@@ -3150,47 +3188,111 @@ def keys(
         include_nested: bool = False,
         leaves_only: bool = False,
     ) -> KeysView:
-        return self._specs[0].keys(
+        keys = self._specs[0].keys(
             include_nested=include_nested, leaves_only=leaves_only
         )
+        keys = set(keys)
+        for spec in self._specs[1:]:
+            keys = keys.intersection(spec.keys(include_nested, leaves_only))
+        return sorted(keys, key=str)
 
     def project(self, val: TensorDictBase) -> TensorDictBase:
-        pass
-
-    def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
-        pass
+        vals = []
+        for spec, subval in zip(self._specs, val.unbind(self.dim)):
+            if not spec.is_in(subval):
+                vals.append(spec.project(subval))
+            else:
+                vals.append(subval)
+        res = torch.stack(vals, dim=self.dim)
+        if not isinstance(val, LazyStackedTensorDict):
+            res = res.to_tensordict()
+        return res
 
     def type_check(
         self,
         value: Union[torch.Tensor, TensorDictBase],
-        selected_keys: Union[str, Optional[Sequence[str]]] = None,
+        selected_keys: Union[NestedKey, Optional[Sequence[NestedKey]]] = None,
     ):
-        pass
+        if selected_keys is None:
+            if isinstance(value, torch.Tensor):
+                raise ValueError(
+                    "value must be of type TensorDictBase when key is None"
+                )
+            for spec, subvalue in zip(self._specs, value.unbind(self.dim)):
+                spec.type_check(subvalue)
+        else:
+            if isinstance(value, torch.Tensor) and isinstance(selected_keys, str):
+                value = {selected_keys: value}
+                selected_keys = [selected_keys]
+            for _key in self.keys():
+                if self[_key] is not None and _key in selected_keys:
+                    self[_key].type_check(value[_key], _key)
 
     def __repr__(self) -> str:
         sub_str = ",\n".join(
             [indent(f"{k}: {repr(item)}", 4 * " ") for k, item in self.items()]
         )
-        device_str = f"device={self._specs[0].device}"
-        shape_str = f"shape={self.shape}"
-        sub_str = ", ".join([sub_str, device_str, shape_str])
-        return (
-            f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})"
+        sub_str = indent(f"fields={{\n{', '.join([sub_str])}}}", 4 * " ")
+        lazy_key_str = self.repr_lazy_keys()
+        device_str = indent(f"device={self._specs[0].device}", 4 * " ")
+        shape_str = indent(f"shape={self.shape}", 4 * " ")
+        stack_dim = indent(f"stack_dim={self.dim}", 4 * " ")
+        string = ",\n".join([sub_str, lazy_key_str, device_str, shape_str, stack_dim])
+        return f"LazyStackedCompositeSpec(\n{string})"
+
+    def repr_lazy_keys(self):
+        keys = set(self.keys())
+        lazy_keys = [
+            ",\n".join(
+                [
+                    indent(f"{k}: {repr(spec[k])}", 4 * " ")
+                    for k in spec.keys()
+                    if k not in keys
+                ]
+            )
+            for spec in self._specs
+        ]
+        lazy_key_str = ",\n".join(
+            [
+                indent(f"{i} ->\n{line}", 4 * " ")
+                for i, line in enumerate(lazy_keys)
+                if line != ""
+            ]
         )
+        return indent(f"lazy_fields={{\n{lazy_key_str}}}", 4 * " ")
+
+    def is_in(self, val) -> bool:
+        for spec, subval in zip(self._specs, val.unbind(self.dim)):
+            if not spec.is_in(subval):
+                return False
+        return True
 
     def encode(
         self, vals: Dict[str, Any], ignore_device: bool = False
     ) -> Dict[str, torch.Tensor]:
-        pass
+        raise NotImplementedError
 
-    def __delitem__(self, key):
-        pass
+    def __delitem__(self, key: NestedKey):
+        """Deletes a key only if present in all stacked specs."""
+        at_least_one_deletion = False
+        for spec in self._specs:
+            try:
+                del spec[key]
+                at_least_one_deletion = True
+            except KeyError:
+                continue
+        if not at_least_one_deletion:
+            raise KeyError(
+                f"Key {key} must be present in at least one of the stacked specs"
+            )
+        return self
 
     def __iter__(self):
-        pass
+        raise NotImplementedError
 
     def __setitem__(self, key, value):
-        pass
+        raise NotImplementedError
 
     @property
     def device(self) -> DEVICE_TYPING:
@@ -3214,6 +3316,100 @@ def set(self, name, spec):
             )
         self._specs[name] = spec
 
+    def unsqueeze(self, dim: int):
+        if dim < 0:
+            new_dim = dim + len(self.shape) + 1
+        else:
+            new_dim = dim
+        if new_dim > len(self.shape) or new_dim < 0:
+            raise ValueError(f"Cannot unsqueeze along dim {dim}.")
+        if new_dim > self.dim:
+            # unsqueeze 2, stack is on 1 => unsqueeze 1, stack along 1
+            new_stack_dim = self.dim
+            new_dim = new_dim - 1
+        else:
+            # unsqueeze 0, stack is on 1 => unsqueeze 0, stack on 1
+            new_stack_dim = self.dim + 1
+        return LazyStackedCompositeSpec(
+            *[spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim
+        )
+
+    def squeeze(self, dim: int = None):
+        if dim is None:
+            size = self.shape
+            if len(size) == 1 or size.count(1) == 0:
+                return self
+            first_singleton_dim = size.index(1)
+
+            squeezed_dict = self.squeeze(first_singleton_dim)
+            return squeezed_dict.squeeze(dim=None)
+
+        if dim < 0:
+            new_dim = self.ndim + dim
+        else:
+            new_dim = dim
+
+        if self.shape and (new_dim >= self.ndim or new_dim < 0):
+            raise RuntimeError(
+                f"squeezing is allowed for dims comprised between 0 and "
+                f"spec.ndim only. Got dim={dim} and shape"
+                f"={self.shape}."
+            )
+
+        if new_dim >= self.ndim or self.shape[new_dim] != 1:
+            return self
+
+        if new_dim == self.dim:
+            return self._specs[0]
+        if new_dim > self.dim:
+            # squeeze 2, stack is on 1 => squeeze 1, stack along 1
+            new_stack_dim = self.dim
+            new_dim = new_dim - 1
+        else:
+            # squeeze 0, stack is on 1 => squeeze 0, stack on 1
+            new_stack_dim = self.dim - 1
+        return LazyStackedCompositeSpec(
+            *[spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim
+        )
+
+    @property
+    def shape(self):
+        shape = list(self._specs[0].shape)
+        dim = self.dim
+        if dim < 0:
+            dim = len(shape) + dim + 1
+        shape.insert(dim, len(self._specs))
+        return torch.Size(shape)
+
+    def expand(self, *shape):
+        if len(shape) == 1 and not isinstance(shape[0], (int,)):
+            return self.expand(*shape[0])
+        expand_shape = shape[: -len(self.shape)]
+        existing_shape = self.shape
+        shape_check = shape[-len(self.shape) :]
+        for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)):
+            if size1 != size2 and size1 != 1:
+                raise RuntimeError(
+                    f"Expanding a non-singleton dimension: existing shape={size1} vs expand={size2}"
+                )
+            elif size1 != size2 and size1 == 1 and _i == self.dim:
+                # if we're expanding along the stack dim we just need to clone the existing spec
+                return torch.stack(
+                    [self._specs[0].clone() for _ in range(size2)], self.dim
+                ).expand(*shape)
+            if _i != len(self.shape) - 1:
+                raise RuntimeError(
+                    f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}."
+                )
+        # remove the stack dim from the expanded shape, which we know to match
+        unstack_shape = list(expand_shape) + [
+            s for i, s in enumerate(shape_check) if i != self.dim
+        ]
+        return torch.stack(
+            [spec.expand(unstack_shape) for spec in self._specs],
+            self.dim + len(expand_shape),
+        )
+
 
 # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
 @TensorSpec.implements_for_spec(torch.stack)
@@ -3228,14 +3424,19 @@ def _stack_specs(list_of_spec, dim, out=None):
     spec0 = list_of_spec[0]
     if isinstance(spec0, TensorSpec):
         device = spec0.device
+        all_equal = True
         for spec in list_of_spec[1:]:
-            if not isinstance(spec, TensorSpec):
+            if not isinstance(spec, spec0.__class__):
                 raise RuntimeError(
                     "Stacking specs cannot occur: Found more than one type of specs in the list."
                 )
            if device != spec.device:
                 raise RuntimeError(f"Devices differ, got {device} and {spec.device}")
+            if spec.dtype != spec0.dtype:
+                raise RuntimeError(f"Dtypes differ, got {spec0.dtype} and {spec.dtype}")
+            if spec.ndim != spec0.ndim:
+                raise RuntimeError(f"Ndims differ, got {spec0.ndim} and {spec.ndim}")
             all_equal = all_equal and spec == spec0
         if all_equal:
             shape = list(spec0.shape)
@@ -3269,6 +3470,8 @@ def _stack_composite_specs(list_of_spec, dim, out=None):
                 )
             if device != spec.device:
                 raise RuntimeError(f"Devices differ, got {device} and {spec.device}")
+            if spec.shape != spec0.shape:
+                raise RuntimeError(f"Shapes differ, got {spec.shape} and {spec0.shape}")
             all_equal = all_equal and spec == spec0
         if all_equal:
             shape = list(spec0.shape)
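Usage sketch (reviewer note, not part of the patch): the hunks above make torch.stack over CompositeSpecs with different keys or trailing shapes return a LazyStackedCompositeSpec instead of requiring identical specs. The snippet below is illustrative only; the spec names, keys and shapes are made up (they mirror the test helpers), and it assumes the torchrl imports used in the tests above.

    import torch
    from torchrl.data.tensor_specs import (
        BoundedTensorSpec,
        CompositeSpec,
        UnboundedContinuousTensorSpec,
    )

    # Two agents share "camera" and "vector", but only the first has "lidar",
    # and their "vector" entries have different trailing shapes.
    spec0 = CompositeSpec(
        camera=BoundedTensorSpec(minimum=0, maximum=1, shape=(32, 32, 3)),
        lidar=BoundedTensorSpec(minimum=0, maximum=5, shape=(20,)),
        vector=UnboundedContinuousTensorSpec(shape=(3,)),
    )
    spec1 = CompositeSpec(
        camera=BoundedTensorSpec(minimum=0, maximum=1, shape=(32, 32, 3)),
        vector=UnboundedContinuousTensorSpec(shape=(2,)),
    )

    stacked = torch.stack([spec0, spec1], dim=0)  # LazyStackedCompositeSpec
    print(stacked.shape)         # torch.Size([2]); heterogeneous dims stay per-spec
    print(list(stacked.keys()))  # only the keys shared by every stacked spec
    sample = stacked.rand()      # stacked tensordict; sample[0] should still carry "lidar"
    assert stacked.is_in(sample)

Stacking leaf specs of the same class but with mismatched dtypes or ndims now raises ("Dtypes differ" / "Ndims differ"), as exercised by test_malformed_stack above.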