Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,69 @@ def test_to_numpy(self):
with pytest.raises(AssertionError):
c.to_numpy(td_fail, safe=True)

def test_unsqueeze(self):
    """Check unsqueeze/squeeze round-trips on stacked composite specs."""
    c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
    # c2 has an extra key; the stack only guarantees the shared keys.
    c2 = CompositeSpec(
        a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
        b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
        shape=(1, 3),
    )
    c = torch.stack([c1, c2], 1)
    # Try every legal unsqueeze position, negative and positive.
    for unsq in range(-2, 3):
        cu = c.unsqueeze(unsq)
        shape = list(c.shape)
        # Normalize a negative dim to its positive insertion index.
        new_unsq = unsq if unsq >= 0 else c.ndim + unsq + 1
        shape.insert(new_unsq, 1)
        assert cu.shape == torch.Size(shape)
        # Squeezing the same dim must restore the original spec exactly.
        cus = cu.squeeze(unsq)
        assert c.shape == cus.shape, unsq
        assert cus == c

    # squeeze() with no dim drops all singleton dims: (1, 2, 3) -> (2, 3).
    assert c.squeeze().shape == torch.Size([2, 3])

    # Specs with disjoint keys; the last one also has a different shape,
    # so the stacked spec is heterogeneous.
    specs = [
        CompositeSpec(
            {
                "observation_0": UnboundedContinuousTensorSpec(
                    shape=torch.Size([128, 128, 3]),
                    device="cpu",
                    dtype=torch.float32,
                )
            }
        ),
        CompositeSpec(
            {
                "observation_1": UnboundedContinuousTensorSpec(
                    shape=torch.Size([128, 128, 3]),
                    device="cpu",
                    dtype=torch.float32,
                )
            }
        ),
        CompositeSpec(
            {
                "observation_2": UnboundedContinuousTensorSpec(
                    shape=torch.Size([128, 128, 3]),
                    device="cpu",
                    dtype=torch.float32,
                )
            }
        ),
        CompositeSpec(
            {
                "observation_3": UnboundedContinuousTensorSpec(
                    shape=torch.Size([4]), device="cpu", dtype=torch.float32
                )
            }
        ),
    ]

    c = torch.stack(specs, dim=0)
    cu = c.unsqueeze(0)
    assert cu.shape == torch.Size([1, 4])
    cus = cu.squeeze(0)
    # Round-trip must be the identity even for heterogeneous stacks.
    assert cus == c


# MultiDiscreteTensorSpec: Pending resolution of https://github.com/pytorch/pytorch/issues/100080.
@pytest.mark.parametrize(
Expand Down
125 changes: 105 additions & 20 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,10 +934,10 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
return val.detach().cpu().numpy()

def __len__(self):
    """Length is not defined for a lazily stacked spec.

    Raising is preferred over the previous silent ``pass`` (which returned
    ``None`` and would break ``len()``'s int contract anyway).
    """
    raise NotImplementedError

def project(self, val: TensorDictBase) -> TensorDictBase:
    """Projection onto the spec is not implemented for lazy stacks.

    Raises:
        NotImplementedError: always; the previous ``pass`` silently
            returned ``None``, hiding the missing feature from callers.
    """
    raise NotImplementedError

def __repr__(self):
shape_str = "shape=" + str(self.shape)
Expand All @@ -952,10 +952,10 @@ def __repr__(self):
return string

def __iter__(self):
    """Iteration is not implemented; raise instead of yielding nothing."""
    raise NotImplementedError

def __setitem__(self, key, value):
    """Item assignment is not implemented; raise instead of no-op."""
    raise NotImplementedError

@property
def device(self) -> DEVICE_TYPING:
Expand All @@ -979,6 +979,12 @@ def set(self, name, spec):
)
self._specs[name] = spec

def is_in(self, val) -> bool:
    """Return ``True`` iff every slice of ``val`` along the stack dim
    is contained in the corresponding stacked spec.

    Uses ``all(...)`` so the check short-circuits on the first failing
    sub-spec, instead of accumulating a boolean across the whole loop.
    """
    return all(
        spec.is_in(subval)
        for spec, subval in zip(self._specs, val.unbind(self.dim))
    )


@dataclass(repr=False)
class OneHotDiscreteTensorSpec(TensorSpec):
Expand Down Expand Up @@ -3119,7 +3125,14 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N
pass

def __eq__(self, other):
    """Two lazily stacked composite specs are equal iff they stack the
    same number of specs and every pair of sub-specs compares equal.

    NOTE(review): defining ``__eq__`` without ``__hash__`` makes the class
    unhashable; confirm instances are never used as dict keys / set members.
    """
    if not isinstance(other, LazyStackedCompositeSpec):
        return False
    if len(self._specs) != len(other._specs):
        return False
    # all() short-circuits on the first mismatching pair.
    return all(s1 == s2 for s1, s2 in zip(self._specs, other._specs))

def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
if safe is None:
Expand All @@ -3135,37 +3148,46 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
return {key: self[key].to_numpy(val) for key, val in val.items()}

def __len__(self):
    """Length is not defined for a lazily stacked composite spec."""
    raise NotImplementedError

def values(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
):
    """Yield the stacked sub-spec for every key shared across the stack.

    Args:
        include_nested: forwarded to :meth:`keys`; also traverse nested keys.
        leaves_only: forwarded to :meth:`keys`; skip intermediate nodes.
    """
    for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
        yield self[key]

def items(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
):
    """Yield ``(key, stacked_sub_spec)`` pairs for every shared key.

    Args:
        include_nested: forwarded to :meth:`keys`; also traverse nested keys.
        leaves_only: forwarded to :meth:`keys`; skip intermediate nodes.
    """
    for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
        yield key, self[key]

def keys(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
) -> List:
    """Return the sorted list of keys present in *every* stacked spec.

    Only the intersection is returned: a key missing from any one of the
    stacked specs cannot be indexed on the stack, so it is excluded.

    Note: the previous ``KeysView`` annotation was wrong — a sorted list
    is returned, not a view.
    """
    common = set(
        self._specs[0].keys(include_nested=include_nested, leaves_only=leaves_only)
    )
    for spec in self._specs[1:]:
        common &= set(spec.keys(include_nested, leaves_only))
    # ``key=str`` lets plain strings and nested key-tuples sort together.
    return sorted(common, key=str)

def project(self, val: TensorDictBase) -> TensorDictBase:
    # NOTE(review): this stub still silently returns None, while the sibling
    # stubs on this class raise NotImplementedError — confirm whether this
    # one was simply missed in the same cleanup.
    pass

def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
    """Dict-typed containment check is not implemented on this overload;
    raise instead of silently returning ``None`` (which is falsy and would
    masquerade as "not contained")."""
    raise NotImplementedError

def type_check(
    self,
    value: Union[torch.Tensor, TensorDictBase],
    selected_keys: Union[str, Optional[Sequence[str]]] = None,
):
    """Type checking is not implemented for lazy stacks; raise rather than
    silently succeeding as the old ``pass`` did."""
    raise NotImplementedError

def __repr__(self) -> str:
sub_str = ",\n".join(
Expand All @@ -3178,19 +3200,25 @@ def __repr__(self) -> str:
f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})"
)

def is_in(self, val) -> bool:
    """Return ``True`` iff each slice of ``val`` along the stack dimension
    satisfies the matching stacked spec.

    ``all(...)`` short-circuits on the first failing sub-spec; the original
    AND-accumulation always visited every slice.
    """
    return all(
        spec.is_in(subval)
        for spec, subval in zip(self._specs, val.unbind(self.dim))
    )

def encode(
    self, vals: Dict[str, Any], ignore_device: bool = False
) -> Dict[str, torch.Tensor]:
    """Encoding is not implemented for lazy stacks; raise instead of
    returning ``None`` from a method annotated to return a dict."""
    raise NotImplementedError

def __delitem__(self, key):
    """Key deletion is not implemented; raise instead of no-op."""
    raise NotImplementedError

def __iter__(self):
    """Iteration is not implemented; raise instead of yielding nothing."""
    raise NotImplementedError

def __setitem__(self, key, value):
    """Item assignment is not implemented; raise instead of no-op."""
    raise NotImplementedError

@property
def device(self) -> DEVICE_TYPING:
Expand All @@ -3214,6 +3242,63 @@ def set(self, name, spec):
)
self._specs[name] = spec

def unsqueeze(self, dim: int):
    """Return a new lazily stacked spec with a singleton dim inserted at ``dim``.

    Args:
        dim: position of the new axis; negative values count from the end
            (``-1`` inserts before the last position, as in
            ``torch.Tensor.unsqueeze``).

    Returns:
        A new :class:`LazyStackedCompositeSpec`; each stored sub-spec is
        unsqueezed at the index translated to per-spec coordinates.

    Raises:
        ValueError: if ``dim`` is outside ``[-(ndim + 1), ndim]``.
    """
    if dim < 0:
        new_dim = dim + len(self.shape) + 1
    else:
        new_dim = dim
    if new_dim > len(self.shape) or new_dim < 0:
        raise ValueError(f"Cannot unsqueeze along dim {dim}.")
    # The original code assigned new_stack_dim here and then unconditionally
    # overwrote it in both branches below — that dead store is removed.
    if new_dim > self.dim:
        # New axis lands after the stack axis: the stack stays where it is,
        # but per-spec indices shift left by one because each stored spec
        # does not carry the stack dimension.
        new_stack_dim = self.dim
        new_dim = new_dim - 1
    else:
        # New axis lands at or before the stack axis, pushing the stack right;
        # per-spec indices before the stack are unchanged.
        new_stack_dim = self.dim + 1
    return LazyStackedCompositeSpec(
        *[spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim
    )

def squeeze(self, dim: Optional[int] = None):
    """Return a spec with singleton dimension(s) removed.

    Args:
        dim: the dimension to squeeze. ``None`` (default) removes *all*
            singleton dimensions, mirroring ``torch.Tensor.squeeze``.
            (The previous annotation ``dim: int = None`` was wrong —
            ``None`` is a legal value.)

    Returns:
        ``self`` unchanged when there is nothing to squeeze; the single
        stored sub-spec when the stack axis itself is squeezed; otherwise a
        new :class:`LazyStackedCompositeSpec` with the axis removed.

    Raises:
        RuntimeError: if ``dim`` is out of range for a non-empty shape.
    """
    if dim is None:
        # Squeeze the first singleton dim, then recurse until none remain.
        size = self.shape
        if len(size) == 1 or size.count(1) == 0:
            return self
        first_singleton_dim = size.index(1)

        squeezed_dict = self.squeeze(first_singleton_dim)
        return squeezed_dict.squeeze(dim=None)

    if dim < 0:
        new_dim = self.ndim + dim
    else:
        new_dim = dim

    if self.shape and (new_dim >= self.ndim or new_dim < 0):
        raise RuntimeError(
            f"squeezing is allowed for dims comprised between 0 and "
            f"spec.ndim only. Got dim={dim} and shape"
            f"={self.shape}."
        )

    if new_dim >= self.ndim or self.shape[new_dim] != 1:
        # Non-singleton dim (or empty shape): no-op, like torch.squeeze.
        return self

    if new_dim == self.dim:
        # Squeezing the stack axis itself: a size-1 stack collapses to its
        # single stored spec.
        return self._specs[0]
    if new_dim > self.dim:
        # Squeeze after the stack axis: stack stays, per-spec index shifts
        # left by one (stored specs lack the stack dimension).
        new_stack_dim = self.dim
        new_dim = new_dim - 1
    else:
        # Squeeze before the stack axis: the stack axis moves left.
        new_stack_dim = self.dim - 1
    return LazyStackedCompositeSpec(
        *[spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim
    )


# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
@TensorSpec.implements_for_spec(torch.stack)
Expand Down