Skip to content

Commit 4bf6c5f

Browse files
authored
[Feature] Various improvements to LazyStacked specs (#965)
1 parent c1acefd commit 4bf6c5f

File tree

2 files changed: +58 additions, −35 deletions

test/test_specs.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,18 +2045,36 @@ def test_stack_unboundeddiscrete_rand(self, shape, stack_dim):
20452045
shape = (*shape,)
20462046
c1 = UnboundedDiscreteTensorSpec(shape=shape)
20472047
c2 = c1.clone()
2048-
c = torch.stack([c1, c2], 0)
2048+
c = torch.stack([c1, c2], stack_dim)
20492049
r = c.rand()
20502050
assert r.shape == c.shape
20512051

20522052
def test_stack_unboundeddiscrete_zero(self, shape, stack_dim):
20532053
shape = (*shape,)
20542054
c1 = UnboundedDiscreteTensorSpec(shape=shape)
20552055
c2 = c1.clone()
2056-
c = torch.stack([c1, c2], 0)
2056+
c = torch.stack([c1, c2], stack_dim)
20572057
r = c.zero()
20582058
assert r.shape == c.shape
20592059

2060+
def test_to_numpy(self, shape, stack_dim):
2061+
c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
2062+
c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
2063+
c = torch.stack([c1, c2], stack_dim)
2064+
2065+
shape = list(shape)
2066+
shape.insert(stack_dim, 2)
2067+
shape = tuple(shape)
2068+
2069+
val = 2 * torch.rand(torch.Size(shape)) - 1
2070+
2071+
val_np = c.to_numpy(val)
2072+
assert isinstance(val_np, np.ndarray)
2073+
assert (val.numpy() == val_np).all()
2074+
2075+
with pytest.raises(AssertionError):
2076+
c.to_numpy(val + 1)
2077+
20602078

20612079
class TestStackComposite:
20622080
def test_stack(self):
@@ -2303,6 +2321,24 @@ def test_clone(self):
23032321
assert cclone[0] is not c[0]
23042322
assert cclone[0] == c[0]
23052323

2324+
def test_to_numpy(self):
2325+
c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
2326+
c2 = CompositeSpec(
2327+
a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
2328+
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
2329+
shape=(1, 3),
2330+
)
2331+
c = torch.stack([c1, c2], 0)
2332+
for _ in range(100):
2333+
r = c.rand()
2334+
for key, value in c.to_numpy(r).items():
2335+
spec = c[key]
2336+
assert (spec.to_numpy(r[key]) == value).all()
2337+
2338+
td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3])
2339+
with pytest.raises(AssertionError):
2340+
c.to_numpy(td_fail)
2341+
23062342

23072343
if __name__ == "__main__":
23082344
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/tensor_specs.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -668,38 +668,23 @@ def __eq__(self, other):
668668
# requires unbind to be implemented
669669
pass
670670

671-
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
672-
pass
671+
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict:
672+
if safe:
673+
if val.shape[self.dim] != len(self._specs):
674+
raise ValueError(
675+
"Size of LazyStackedTensorSpec and val differ along the stacking "
676+
"dimension"
677+
)
678+
for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
679+
spec.assert_is_in(v)
680+
return val.detach().cpu().numpy()
673681

674682
def __len__(self):
675683
pass
676684

677-
def values(self) -> ValuesView:
678-
pass
679-
680-
def items(self) -> ItemsView:
681-
pass
682-
683-
def keys(
684-
self,
685-
include_nested: bool = False,
686-
leaves_only: bool = False,
687-
) -> KeysView:
688-
pass
689-
690685
def project(self, val: TensorDictBase) -> TensorDictBase:
691686
pass
692687

693-
def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
694-
pass
695-
696-
def type_check(
697-
self,
698-
value: Union[torch.Tensor, TensorDictBase],
699-
selected_keys: Union[str, Optional[Sequence[str]]] = None,
700-
):
701-
pass
702-
703688
def __repr__(self):
704689
shape_str = "shape=" + str(self.shape)
705690
space_str = "space=" + str(self._specs[0].space)
@@ -712,12 +697,6 @@ def __repr__(self):
712697
string = f"{self.__class__.__name__}(\n {sub_string})"
713698
return string
714699

715-
def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]:
716-
pass
717-
718-
def __delitem__(self, key):
719-
pass
720-
721700
def __iter__(self):
722701
pass
723702

@@ -726,7 +705,7 @@ def __setitem__(self, key, value):
726705

727706
@property
728707
def device(self) -> DEVICE_TYPING:
729-
pass
708+
return self._specs[0].device
730709

731710
@property
732711
def ndim(self):
@@ -2591,7 +2570,15 @@ def __eq__(self, other):
25912570
pass
25922571

25932572
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
2594-
pass
2573+
if safe:
2574+
if val.shape[self.dim] != len(self._specs):
2575+
raise ValueError(
2576+
"Size of LazyStackedCompositeSpec and val differ along the "
2577+
"stacking dimension"
2578+
)
2579+
for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
2580+
spec.assert_is_in(v)
2581+
return {key: self[key].to_numpy(val) for key, val in val.items()}
25952582

25962583
def __len__(self):
25972584
pass

Comments (0)