diff --git a/test/test_specs.py b/test/test_specs.py index 10adac74bdc..3c2461251ce 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -2428,6 +2428,69 @@ def test_to_numpy(self): with pytest.raises(AssertionError): c.to_numpy(td_fail, safe=True) + def test_unsqueeze(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 1) + for unsq in range(-2, 3): + cu = c.unsqueeze(unsq) + shape = list(c.shape) + new_unsq = unsq if unsq >= 0 else c.ndim + unsq + 1 + shape.insert(new_unsq, 1) + assert cu.shape == torch.Size(shape) + cus = cu.squeeze(unsq) + assert c.shape == cus.shape, unsq + assert cus == c + + assert c.squeeze().shape == torch.Size([2, 3]) + + specs = [ + CompositeSpec( + { + "observation_0": UnboundedContinuousTensorSpec( + shape=torch.Size([128, 128, 3]), + device="cpu", + dtype=torch.float32, + ) + } + ), + CompositeSpec( + { + "observation_1": UnboundedContinuousTensorSpec( + shape=torch.Size([128, 128, 3]), + device="cpu", + dtype=torch.float32, + ) + } + ), + CompositeSpec( + { + "observation_2": UnboundedContinuousTensorSpec( + shape=torch.Size([128, 128, 3]), + device="cpu", + dtype=torch.float32, + ) + } + ), + CompositeSpec( + { + "observation_3": UnboundedContinuousTensorSpec( + shape=torch.Size([4]), device="cpu", dtype=torch.float32 + ) + } + ), + ] + + c = torch.stack(specs, dim=0) + cu = c.unsqueeze(0) + assert cu.shape == torch.Size([1, 4]) + cus = cu.squeeze(0) + assert cus == c + # MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080. @pytest.mark.parametrize( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 4d69949b964..c67b68fd305 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -934,10 +934,10 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: return val.detach().cpu().numpy() def __len__(self): - pass + raise NotImplementedError def project(self, val: TensorDictBase) -> TensorDictBase: - pass + raise NotImplementedError def __repr__(self): shape_str = "shape=" + str(self.shape) @@ -952,10 +952,10 @@ def __repr__(self): return string def __iter__(self): - pass + raise NotImplementedError def __setitem__(self, key, value): - pass + raise NotImplementedError @property def device(self) -> DEVICE_TYPING: @@ -979,6 +979,12 @@ def set(self, name, spec): ) self._specs[name] = spec + def is_in(self, val) -> bool: + isin = True + for spec, subval in zip(self._specs, val.unbind(self.dim)): + isin = isin and spec.is_in(subval) + return isin + @dataclass(repr=False) class OneHotDiscreteTensorSpec(TensorSpec): @@ -3119,7 +3125,14 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N pass def __eq__(self, other): - pass + if not isinstance(other, LazyStackedCompositeSpec): + return False + if len(self._specs) != len(other._specs): + return False + for _spec1, _spec2 in zip(self._specs, other._specs): + if _spec1 != _spec2: + return False + return True def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: if safe is None: @@ -3135,14 +3148,22 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: return {key: self[key].to_numpy(val) for key, val in val.items()} def __len__(self): - pass + raise NotImplementedError - def values(self): - for key in self.keys(): + def values( + self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys(include_nested=include_nested, leaves_only=leaves_only): yield self[key] - def items(self): - for key in self.keys(): + def items( + self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys(include_nested=include_nested, leaves_only=leaves_only): yield key, self[key] def keys( @@ -3150,22 +3171,23 @@ def keys( include_nested: bool = False, leaves_only: bool = False, ) -> KeysView: - return self._specs[0].keys( + keys = self._specs[0].keys( include_nested=include_nested, leaves_only=leaves_only ) + keys = set(keys) + for spec in self._specs[1:]: + keys = keys.intersection(spec.keys(include_nested, leaves_only)) + return sorted(keys, key=str) def project(self, val: TensorDictBase) -> TensorDictBase: - pass - - def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - pass + raise NotImplementedError def type_check( self, value: Union[torch.Tensor, TensorDictBase], selected_keys: Union[str, Optional[Sequence[str]]] = None, ): - pass + raise NotImplementedError def __repr__(self) -> str: sub_str = ",\n".join( @@ -3178,19 +3200,25 @@ def __repr__(self) -> str: f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" ) + def is_in(self, val) -> bool: + isin = True + for spec, subval in zip(self._specs, val.unbind(self.dim)): + isin = isin and spec.is_in(subval) + return isin + def encode( self, vals: Dict[str, Any], ignore_device: bool = False ) -> Dict[str, torch.Tensor]: - pass + raise NotImplementedError def __delitem__(self, key): - pass + raise NotImplementedError def __iter__(self): - pass + raise NotImplementedError def __setitem__(self, key, value): - pass + raise NotImplementedError @property def device(self) -> DEVICE_TYPING: @@ -3214,6 +3242,63 @@ def set(self, name, spec): ) self._specs[name] = spec + def unsqueeze(self, dim: int): + if dim < 0: + new_dim = dim + len(self.shape) + 1 + else: + new_dim = dim + if new_dim > len(self.shape) or new_dim < 0: + raise ValueError(f"Cannot unsqueeze along dim {dim}.") + new_stack_dim = self.dim if self.dim < new_dim else self.dim + 1 + if new_dim > self.dim: + # unsqueeze 2, stack is on 1 => unsqueeze 1, stack along 1 + new_stack_dim = self.dim + new_dim = new_dim - 1 + else: + # unsqueeze 0, stack is on 1 => unsqueeze 0, stack on 1 + new_stack_dim = self.dim + 1 + return LazyStackedCompositeSpec( + *[spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim + ) + + def squeeze(self, dim: int=None): + if dim is None: + size = self.shape + if len(size) == 1 or size.count(1) == 0: + return self + first_singleton_dim = size.index(1) + + squeezed_dict = self.squeeze(first_singleton_dim) + return squeezed_dict.squeeze(dim=None) + + if dim < 0: + new_dim = self.ndim + dim + else: + new_dim = dim + + if self.shape and (new_dim >= self.ndim or new_dim < 0): + raise RuntimeError( + f"squeezing is allowed for dims comprised between 0 and " + f"spec.ndim only. Got dim={dim} and shape" + f"={self.shape}." + ) + + if new_dim >= self.ndim or self.shape[new_dim] != 1: + return self + + if new_dim == self.dim: + return self._specs[0] + if new_dim > self.dim: + # squeeze 2, stack is on 1 => squeeze 1, stack along 1 + new_stack_dim = self.dim + new_dim = new_dim - 1 + else: + # squeeze 0, stack is on 1 => squeeze 0, stack on 1 + new_stack_dim = self.dim - 1 + return LazyStackedCompositeSpec( + *[spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim + ) + # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: @TensorSpec.implements_for_spec(torch.stack)