Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2045,18 +2045,36 @@ def test_stack_unboundeddiscrete_rand(self, shape, stack_dim):
shape = (*shape,)
c1 = UnboundedDiscreteTensorSpec(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
c = torch.stack([c1, c2], stack_dim)
r = c.rand()
assert r.shape == c.shape

def test_stack_unboundeddiscrete_zero(self, shape, stack_dim):
    """zero() on a stack of UnboundedDiscreteTensorSpec matches the stacked shape."""
    shape = (*shape,)
    c1 = UnboundedDiscreteTensorSpec(shape=shape)
    c2 = c1.clone()
    # Stack along the parametrized dim (not a hard-coded 0) so every
    # stacking dimension is exercised.
    c = torch.stack([c1, c2], stack_dim)
    r = c.zero()
    assert r.shape == c.shape

def test_to_numpy(self, shape, stack_dim):
    """to_numpy round-trips an in-bounds sample and rejects out-of-bounds values."""
    spec_f64 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
    spec_f32 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
    stacked = torch.stack([spec_f64, spec_f32], stack_dim)

    # The stacked spec gains a size-2 axis at stack_dim.
    stacked_shape = list(shape)
    stacked_shape.insert(stack_dim, 2)
    stacked_shape = tuple(stacked_shape)

    # Uniform sample in [-1, 1), i.e. inside the bounded domain.
    sample = 2 * torch.rand(torch.Size(stacked_shape)) - 1

    sample_np = stacked.to_numpy(sample)
    assert isinstance(sample_np, np.ndarray)
    assert (sample.numpy() == sample_np).all()

    # Shifting by +1 leaves the [-1, 1] domain, so the safe check must fire.
    with pytest.raises(AssertionError):
        stacked.to_numpy(sample + 1)


class TestStackComposite:
def test_stack(self):
Expand Down Expand Up @@ -2303,6 +2321,24 @@ def test_clone(self):
assert cclone[0] is not c[0]
assert cclone[0] == c[0]

def test_to_numpy(self):
    """to_numpy on stacked CompositeSpecs converts each leaf and enforces membership."""
    spec_a = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
    spec_b = CompositeSpec(
        a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
        b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
        shape=(1, 3),
    )
    stacked = torch.stack([spec_a, spec_b], 0)
    for _ in range(100):
        sample = stacked.rand()
        converted = stacked.to_numpy(sample)
        for key, value in converted.items():
            # Each entry must equal the per-key sub-spec conversion.
            assert (stacked[key].to_numpy(sample[key]) == value).all()

    # rand() + 1 lies outside [-1, 1], so the safe check must fail.
    td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3])
    with pytest.raises(AssertionError):
        stacked.to_numpy(td_fail)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
53 changes: 20 additions & 33 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,38 +668,23 @@ def __eq__(self, other):
# requires unbind to be implemented
pass

def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
    """Convert a stacked tensor sample to a numpy array.

    Args:
        val: tensor stacked along ``self.dim``, one slice per sub-spec.
        safe: if ``True``, check the stacking size and run each sub-spec's
            ``assert_is_in`` on its slice before converting.

    Returns:
        ``val`` detached, moved to CPU and converted to ``np.ndarray``
        (annotation fixed from the erroneous ``dict``).

    Raises:
        ValueError: if ``val``'s size along the stacking dim differs from
            the number of stacked specs.
        AssertionError: if any slice is not contained in its sub-spec.
    """
    if safe:
        if val.shape[self.dim] != len(self._specs):
            raise ValueError(
                "Size of LazyStackedTensorSpec and val differ along the stacking "
                "dimension"
            )
        for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
            spec.assert_is_in(v)
    return val.detach().cpu().numpy()

def __len__(self):
pass

def values(self) -> ValuesView:
pass

def items(self) -> ItemsView:
pass

def keys(
self,
include_nested: bool = False,
leaves_only: bool = False,
) -> KeysView:
pass

def project(self, val: TensorDictBase) -> TensorDictBase:
pass

def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
pass

def type_check(
self,
value: Union[torch.Tensor, TensorDictBase],
selected_keys: Union[str, Optional[Sequence[str]]] = None,
):
pass

def __repr__(self):
shape_str = "shape=" + str(self.shape)
space_str = "space=" + str(self._specs[0].space)
Expand All @@ -712,12 +697,6 @@ def __repr__(self):
string = f"{self.__class__.__name__}(\n {sub_string})"
return string

def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]:
pass

def __delitem__(self, key):
pass

def __iter__(self):
pass

Expand All @@ -726,7 +705,7 @@ def __setitem__(self, key, value):

@property
def device(self) -> DEVICE_TYPING:
    # All stacked specs live on the same device; report the first one's.
    return self._specs[0].device

@property
def ndim(self):
Expand Down Expand Up @@ -2591,7 +2570,15 @@ def __eq__(self, other):
pass

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
    """Convert a stacked TensorDict sample to a dict of numpy arrays.

    Args:
        val: TensorDict stacked along ``self.dim``, one slice per sub-spec.
        safe: if ``True``, check the stacking size and run each sub-spec's
            ``assert_is_in`` on its slice before converting.

    Returns:
        dict mapping each key of ``val`` to its numpy conversion, delegated
        to the corresponding stacked sub-spec ``self[key]``.

    Raises:
        ValueError: if ``val``'s batch size along the stacking dim differs
            from the number of stacked specs.
        AssertionError: if any slice is not contained in its sub-spec.
    """
    if safe:
        if val.shape[self.dim] != len(self._specs):
            raise ValueError(
                "Size of LazyStackedCompositeSpec and val differ along the "
                "stacking dimension"
            )
        for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
            spec.assert_is_in(v)
    # Loop variable renamed so it no longer shadows the ``val`` argument.
    return {key: self[key].to_numpy(value) for key, value in val.items()}

def __len__(self):
pass
Expand Down