diff --git a/test/test_specs.py b/test/test_specs.py index 7472a2b8e08..aaf30a0906e 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -2045,7 +2045,7 @@ def test_stack_unboundeddiscrete_rand(self, shape, stack_dim): shape = (*shape,) c1 = UnboundedDiscreteTensorSpec(shape=shape) c2 = c1.clone() - c = torch.stack([c1, c2], 0) + c = torch.stack([c1, c2], stack_dim) r = c.rand() assert r.shape == c.shape @@ -2053,10 +2053,28 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): shape = (*shape,) c1 = UnboundedDiscreteTensorSpec(shape=shape) c2 = c1.clone() - c = torch.stack([c1, c2], 0) + c = torch.stack([c1, c2], stack_dim) r = c.zero() assert r.shape == c.shape + def test_to_numpy(self, shape, stack_dim): + c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) + c = torch.stack([c1, c2], stack_dim) + + shape = list(shape) + shape.insert(stack_dim, 2) + shape = tuple(shape) + + val = 2 * torch.rand(torch.Size(shape)) - 1 + + val_np = c.to_numpy(val) + assert isinstance(val_np, np.ndarray) + assert (val.numpy() == val_np).all() + + with pytest.raises(AssertionError): + c.to_numpy(val + 1) + class TestStackComposite: def test_stack(self): @@ -2303,6 +2321,24 @@ def test_clone(self): assert cclone[0] is not c[0] assert cclone[0] == c[0] + def test_to_numpy(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 0) + for _ in range(100): + r = c.rand() + for key, value in c.to_numpy(r).items(): + spec = c[key] + assert (spec.to_numpy(r[key]) == value).all() + + td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3]) + with pytest.raises(AssertionError): + c.to_numpy(td_fail) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0f36db422be..54ea17fb4e2 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -668,38 +668,23 @@ def __eq__(self, other): # requires unbind to be implemented pass - def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - pass + def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict: + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedTensorSpec and val differ along the stacking " + "dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return val.detach().cpu().numpy() def __len__(self): pass - def values(self) -> ValuesView: - pass - - def items(self) -> ItemsView: - pass - - def keys( - self, - include_nested: bool = False, - leaves_only: bool = False, - ) -> KeysView: - pass - def project(self, val: TensorDictBase) -> TensorDictBase: pass - def is_in(self, val: Union[dict, TensorDictBase]) -> bool: - pass - - def type_check( - self, - value: Union[torch.Tensor, TensorDictBase], - selected_keys: Union[str, Optional[Sequence[str]]] = None, - ): - pass - def __repr__(self): shape_str = "shape=" + str(self.shape) space_str = "space=" + str(self._specs[0].space) @@ -712,12 +697,6 @@ def __repr__(self): string = f"{self.__class__.__name__}(\n {sub_string})" return string - def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: - pass - - def __delitem__(self, key): - pass - def __iter__(self): pass @@ -726,7 +705,7 @@ def __setitem__(self, key, value): @property def device(self) -> DEVICE_TYPING: - pass + return self._specs[0].device @property def ndim(self): @@ -2591,7 +2570,15 @@ def __eq__(self, other): pass def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: - pass + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedCompositeSpec and val differ along the " + "stacking dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return {key: self[key].to_numpy(val) for key, val in val.items()} def __len__(self): pass