diff --git a/test/test_specs.py b/test/test_specs.py index fbc790fa872..aaf30a0906e 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -10,13 +10,14 @@ import torchrl.data.tensor_specs from _utils_internal import get_available_devices, set_global_var from scipy.stats import chisquare -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, BinaryDiscreteTensorSpec, BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, + LazyStackedCompositeSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -1703,6 +1704,642 @@ def test_unboundeddiscrete( assert spec is not spec.clone() +@pytest.mark.parametrize( + "shape,stack_dim", + [[(), 0], [(2,), 0], [(2,), 1], [(2, 3), 0], [(2, 3), 1], [(2, 3), 2]], +) +class TestStack: + def test_stack_binarydiscrete(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, BinaryDiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_binarydiscrete_expand(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_binarydiscrete_rand(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_binarydiscrete_zero(self, shape, stack_dim): + n = 5 + shape = (*shape, n) + c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_bounded(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, BoundedTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_bounded_expand(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_bounded_rand(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_bounded_zero(self, shape, stack_dim): + mini = -1 + maxi = 1 + shape = (*shape,) + c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_discrete(self, shape, stack_dim): + n = 4 + shape = 
(*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, DiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_discrete_expand(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_discrete_rand(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_discrete_zero(self, shape, stack_dim): + n = 4 + shape = (*shape,) + c1 = DiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_multidiscrete(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, MultiDiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_multidiscrete_expand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_multidiscrete_rand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_multidiscrete_zero(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 2) + c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_multionehot(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, MultiOneHotDiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_multionehot_expand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_multionehot_rand(self, shape, stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_multionehot_zero(self, shape, 
stack_dim): + nvec = [4, 5] + shape = (*shape, 9) + c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_onehot(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, OneHotDiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_onehot_expand(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_onehot_rand(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_onehot_zero(self, shape, stack_dim): + n = 5 + shape = (*shape, 5) + c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_unboundedcont(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, UnboundedContinuousTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_unboundedcont_expand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_unboundedcont_rand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedContinuousTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.rand() + assert r.shape == c.shape + + def test_stack_unboundedcont_zero(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], 0) + r = c.zero() + assert r.shape == c.shape + + def test_stack_unboundeddiscrete(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, UnboundedDiscreteTensorSpec) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + assert c.shape == torch.Size(shape) + + def test_stack_unboundeddiscrete_expand(self, shape, stack_dim): + shape = (*shape,) + c1 = UnboundedDiscreteTensorSpec(shape=shape) + c2 = c1.clone() + c = torch.stack([c1, c2], stack_dim) + shape = list(shape) + if stack_dim < 0: + stack_dim = len(shape) + stack_dim + 1 + shape.insert(stack_dim, 2) + cexpand = c.expand(3, 2, *shape) + assert cexpand.shape == torch.Size([3, 2, *shape]) + + def test_stack_unboundeddiscrete_rand(self, shape, stack_dim): + shape = (*shape,) + c1 = 
UnboundedDiscreteTensorSpec(shape=shape)
+        c2 = c1.clone()
+        c = torch.stack([c1, c2], stack_dim)
+        r = c.rand()
+        assert r.shape == c.shape
+
+    def test_stack_unboundeddiscrete_zero(self, shape, stack_dim):
+        shape = (*shape,)
+        c1 = UnboundedDiscreteTensorSpec(shape=shape)
+        c2 = c1.clone()
+        c = torch.stack([c1, c2], stack_dim)
+        r = c.zero()
+        assert r.shape == c.shape
+
+    def test_to_numpy(self, shape, stack_dim):
+        c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
+        c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+        c = torch.stack([c1, c2], stack_dim)
+
+        shape = list(shape)
+        shape.insert(stack_dim, 2)
+        shape = tuple(shape)
+
+        val = 2 * torch.rand(torch.Size(shape)) - 1
+
+        val_np = c.to_numpy(val)
+        assert isinstance(val_np, np.ndarray)
+        assert (val.numpy() == val_np).all()
+
+        with pytest.raises(AssertionError):
+            c.to_numpy(val + 1)
+
+
+class TestStackComposite:
+    def test_stack(self):
+        c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
+        c2 = c1.clone()
+        c = torch.stack([c1, c2], 0)
+        assert isinstance(c, CompositeSpec)
+
+    def test_stack_index(self):
+        c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
+        )
+        c = torch.stack([c1, c2], 0)
+        assert c.shape == torch.Size([2])
+        assert c[0] is c1
+        assert c[1] is c2
+        assert c[..., 0] is c1
+        assert c[..., 1] is c2
+        assert c[0, ...] is c1
+        assert c[1, ...] is c2
+        assert isinstance(c[:], LazyStackedCompositeSpec)
+
+    @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
+    def test_stack_index_multdim(self, stack_dim):
+        c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
+        c = torch.stack([c1, c2], stack_dim)
+        if stack_dim in (0, -3):
+            assert isinstance(c[:], LazyStackedCompositeSpec)
+            assert c.shape == torch.Size([2, 1, 3])
+            assert c[0] is c1
+            assert c[1] is c2
+            with pytest.raises(
+                IndexError,
+                match="only permitted if the stack dimension is the last dimension",
+            ):
+                assert c[..., 0] is c1
+            with pytest.raises(
+                IndexError,
+                match="only permitted if the stack dimension is the last dimension",
+            ):
+                assert c[..., 1] is c2
+            assert c[0, ...] is c1
+            assert c[1, ...] is c2
+        elif stack_dim in (1, -2):
+            assert isinstance(c[:, :], LazyStackedCompositeSpec)
+            assert c.shape == torch.Size([1, 2, 3])
+            assert c[:, 0] is c1
+            assert c[:, 1] is c2
+            with pytest.raises(
+                IndexError, match="along dimension 0 when the stack dimension is 1."
+            ):
+                assert c[0] is c1
+            with pytest.raises(
+                IndexError, match="along dimension 0 when the stack dimension is 1."
+            ):
+                assert c[1] is c1
+            with pytest.raises(
+                IndexError,
+                match="only permitted if the stack dimension is the last dimension",
+            ):
+                assert c[..., 0] is c1
+            with pytest.raises(
+                IndexError,
+                match="only permitted if the stack dimension is the last dimension",
+            ):
+                assert c[..., 1] is c2
+            assert c[..., 0, :] is c1
+            assert c[..., 1, :] is c2
+            assert c[:, 0, ...] is c1
+            assert c[:, 1, ...] is c2
+        elif stack_dim in (2, -1):
+            assert isinstance(c[:, :, :], LazyStackedCompositeSpec)
+            with pytest.raises(
+                IndexError, match="along dimension 0 when the stack dimension is 2."
+            ):
+                assert c[0] is c1
+            with pytest.raises(
+                IndexError, match="along dimension 0 when the stack dimension is 2."
+ ): + assert c[1] is c1 + assert c.shape == torch.Size([1, 3, 2]) + assert c[:, :, 0] is c1 + assert c[:, :, 1] is c2 + assert c[..., 0] is c1 + assert c[..., 1] is c2 + assert c[:, :, 0, ...] is c1 + assert c[:, :, 1, ...] is c2 + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_expand_multi(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + if stack_dim in (0, -3): + c_expand = c.expand([4, 2, 1, 3]) + assert c_expand.shape == torch.Size([4, 2, 1, 3]) + assert c_expand.dim == 1 + elif stack_dim in (1, -2): + c_expand = c.expand([4, 1, 2, 3]) + assert c_expand.shape == torch.Size([4, 1, 2, 3]) + assert c_expand.dim == 2 + elif stack_dim in (2, -1): + c_expand = c.expand( + [ + 4, + 1, + 3, + 2, + ] + ) + assert c_expand.shape == torch.Size([4, 1, 3, 2]) + assert c_expand.dim == 3 + else: + raise NotImplementedError + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_rand(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + r = c.rand() + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([2, 1, 3]) + assert r["a"].shape == torch.Size([2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([1, 2, 3]) + assert r["a"].shape == torch.Size([1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([1, 3, 2]) + assert r["a"].shape == torch.Size([1, 3, 2]) # access tensor + assert (r["a"] != 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_rand_shape(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + shape = [5, 6] + r = c.rand(shape) + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([*shape, 2, 1, 3]) + assert r["a"].shape == torch.Size([*shape, 2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([*shape, 1, 2, 3]) + assert r["a"].shape == torch.Size([*shape, 1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([*shape, 1, 3, 2]) + assert r["a"].shape == torch.Size([*shape, 1, 3, 2]) # access tensor + assert (r["a"] != 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_zero(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + r = c.zero() + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([2, 1, 3]) + assert r["a"].shape == torch.Size([2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([1, 2, 3]) + assert r["a"].shape == torch.Size([1, 2, 3]) # 
access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([1, 3, 2]) + assert r["a"].shape == torch.Size([1, 3, 2]) # access tensor + assert (r["a"] == 0).all() + + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_stack_zero_shape(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + shape = [5, 6] + r = c.zero(shape) + assert isinstance(r, LazyStackedTensorDict) + if stack_dim in (0, -3): + assert r.shape == torch.Size([*shape, 2, 1, 3]) + assert r["a"].shape == torch.Size([*shape, 2, 1, 3]) # access tensor + elif stack_dim in (1, -2): + assert r.shape == torch.Size([*shape, 1, 2, 3]) + assert r["a"].shape == torch.Size([*shape, 1, 2, 3]) # access tensor + elif stack_dim in (2, -1): + assert r.shape == torch.Size([*shape, 1, 3, 2]) + assert r["a"].shape == torch.Size([*shape, 1, 3, 2]) # access tensor + assert (r["a"] == 0).all() + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") + @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) + def test_to(self, stack_dim): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], stack_dim) + assert isinstance(c, LazyStackedCompositeSpec) + cdevice = c.to("cuda:0") + assert cdevice.device != c.device + assert cdevice.device == torch.device("cuda:0") + if stack_dim < 0: + stack_dim += 3 + index = (slice(None),) * stack_dim + (0,) + assert cdevice[index].device == torch.device("cuda:0") + + def test_clone(self): + c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=UnboundedContinuousTensorSpec(shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 0) + cclone = c.clone() + assert cclone[0] is not c[0] + assert cclone[0] == c[0] + + def test_to_numpy(self): + c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = CompositeSpec( + a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + b=UnboundedDiscreteTensorSpec(shape=(1, 3)), + shape=(1, 3), + ) + c = torch.stack([c1, c2], 0) + for _ in range(100): + r = c.rand() + for key, value in c.to_numpy(r).items(): + spec = c[key] + assert (spec.to_numpy(r[key]) == value).all() + + td_fail = TensorDict({"a": torch.rand((2, 1, 3)) + 1}, [2, 1, 3]) + with pytest.raises(AssertionError): + c.to_numpy(td_fail) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 4a0ac554218..788a2cce27d 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -21,6 +21,8 @@ CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, + LazyStackedCompositeSpec, + LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c075ed2b71d..54ea17fb4e2 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -9,16 +9,20 @@ import warnings from copy import deepcopy from dataclasses import dataclass +from functools 
import wraps from textwrap import indent from typing import ( Any, + Callable, Dict, + Generic, ItemsView, KeysView, List, Optional, Sequence, Tuple, + TypeVar, Union, ValuesView, ) @@ -233,6 +237,19 @@ class TensorSpec: dtype: torch.dtype = torch.float domain: str = "" + SPEC_HANDLED_FUNCTIONS = {} + + @classmethod + def implements_for_spec(cls, torch_function: Callable) -> Callable: + """Register a torch function override for TensorSpec.""" + + @wraps(torch_function) + def decorator(func): + cls.SPEC_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: """Encodes a value given the specified spec, and return the corresponding tensor. @@ -329,6 +346,24 @@ def expand(self, *shape): """ raise NotImplementedError + def squeeze(self, dim: int | None = None): + """Returns a new Spec with all the dimensions of size ``1`` removed. + + When ``dim`` is given, a squeeze operation is done only in that dimension. + + Args: + dim (int or None): the dimension to apply the squeeze operation to + + """ + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + def _project(self, val: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -436,6 +471,260 @@ def __repr__(self): string = f"{self.__class__.__name__}(\n {sub_string})" return string + @classmethod + def __torch_function__( + cls, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ) -> Callable: + if kwargs is None: + kwargs = {} + if func not in cls.SPEC_HANDLED_FUNCTIONS or not all( + issubclass(t, (TensorSpec,)) for t in types + ): + return NotImplemented( + f"func {func} for spec {cls} with handles {cls.SPEC_HANDLED_FUNCTIONS}" + ) + return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs) + + +T = TypeVar("T") + + +class _LazyStackedMixin(Generic[T]): + def __init__(self, *specs: tuple[T, ...], dim: int) -> None: + self._specs = specs + self.dim = dim + if self.dim < 0: + self.dim = len(self.shape) + self.dim + + def __getitem__(self, item): + is_key = isinstance(item, str) or ( + isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) + ) + if is_key: + return torch.stack( + [composite_spec[item] for composite_spec in self._specs], dim=self.dim + ) + elif isinstance(item, tuple): + # quick check that the index is along the stacked dim + # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs + if item[0] is Ellipsis: + if len(item) == 1: + return self + elif self.dim == len(self.shape) - 1 and len(item) == 2: + # we can return + return self._specs[item[1]] + elif len(item) > 2: + # check that there is only one non-slice index + assigned = False + dim_idx = self.dim + for i, _item in enumerate(item[1:]): + if ( + isinstance(_item, slice) + and not ( + _item.start is None + and _item.stop is None + and _item.step is None + ) + ) or not isinstance(_item, slice): + if assigned: + raise RuntimeError( + "Found more than one meaningful index in a stacked composite spec." + ) + item = _item + dim_idx = i + 1 + assigned = True + if not assigned: + return self + if dim_idx != self.dim: + raise RuntimeError( + f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}." 
+                        )
+                    out = self._specs[item]
+                    if isinstance(out, TensorSpec):
+                        return out
+                    return torch.stack(list(out), 0)
+                else:
+                    raise IndexError(
+                        f"Indexing a {self.__class__.__name__} with [..., idx] is only permitted if the stack dimension is the last dimension. "
+                        f"Got self.dim={self.dim} and self.shape={self.shape}."
+                    )
+            elif len(item) >= 2 and item[-1] is Ellipsis:
+                return self[item[:-1]]
+            elif any(_item is Ellipsis for _item in item):
+                raise IndexError("Cannot index along multiple dimensions.")
+            # Ellipsis is now ruled out
+            elif any(_item is None for _item in item):
+                raise IndexError(
+                    f"Cannot index a {self.__class__.__name__} with None values"
+                )
+            # Must be an index with slices then
+            else:
+                for i, _item in enumerate(item):
+                    if i == self.dim:
+                        out = self._specs[_item]
+                        if isinstance(out, TensorSpec):
+                            return out
+                        return torch.stack(list(out), 0)
+                    elif isinstance(_item, slice):
+                        # then the slice must be trivial
+                        if not (_item.step is _item.start is _item.stop is None):
+                            raise IndexError(
+                                f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed."
+                            )
+                else:
+                    return self
+        else:
+            if not self.dim == 0:
+                raise IndexError(
+                    f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}."
+                )
+            out = self._specs[item]
+            if isinstance(out, TensorSpec):
+                return out
+            return torch.stack(list(out), 0)
+
+    @property
+    def shape(self):
+        shape = list(self._specs[0].shape)
+        dim = self.dim
+        if dim < 0:
+            dim = len(shape) + dim + 1
+        shape.insert(dim, len(self._specs))
+        return torch.Size(shape)
+
+    def clone(self) -> T:
+        return torch.stack([spec.clone() for spec in self._specs], self.dim)
+
+    def expand(self, *shape):
+        if len(shape) == 1 and not isinstance(shape[0], (int,)):
+            return self.expand(*shape[0])
+        expand_shape = shape[: -len(self.shape)]
+        existing_shape = self.shape
+        shape_check = shape[-len(self.shape) :]
+        for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)):
+            if size1 != size2 and size1 != 1:
+                raise RuntimeError(
+                    f"Expanding a non-singleton dimension: existing shape={size1} vs expand={size2}"
+                )
+            elif size1 != size2 and size1 == 1 and _i == self.dim:
+                # if we're expanding along the stack dim we just need to clone the existing spec
+                return torch.stack(
+                    [self._specs[0].clone() for _ in range(size2)], self.dim
+                ).expand(*shape)
+        if _i != len(self.shape) - 1:
+            raise RuntimeError(
+                f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}."
+            )
+        # remove the stack dim from the expanded shape, which we know to match
+        unstack_shape = list(expand_shape) + [
+            s for i, s in enumerate(shape_check) if i != self.dim
+        ]
+        return torch.stack(
+            [spec.expand(unstack_shape) for spec in self._specs],
+            self.dim + len(expand_shape),
+        )
+
+    def zero(self, shape=None) -> TensorDictBase:
+        if shape is not None:
+            dim = self.dim + len(shape)
+        else:
+            dim = self.dim
+        return torch.stack([spec.zero(shape) for spec in self._specs], dim)
+
+    def rand(self, shape=None) -> TensorDictBase:
+        if shape is not None:
+            dim = self.dim + len(shape)
+        else:
+            dim = self.dim
+        return torch.stack([spec.rand(shape) for spec in self._specs], dim)
+
+    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T:
+        return torch.stack([spec.to(dest) for spec in self._specs], self.dim)
+
+
+class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
+    """A lazy representation of a stack of tensor specs.
+
+    Stacks tensor-specs together along one dimension.
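+
+    A minimal illustration, mirroring the behaviour exercised by the new
+    ``TestStack`` tests (the variable names are only for the example):
+
+        >>> c1 = BoundedTensorSpec(-1, 1, shape=(2,), dtype=torch.float64)
+        >>> c2 = BoundedTensorSpec(-1, 1, shape=(2,), dtype=torch.float32)
+        >>> c = torch.stack([c1, c2], 0)  # the dtypes differ, so a lazy stack is returned
+        >>> c.shape
+        torch.Size([2, 2])
+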
+ When random samples are drawn, a stack of samples is returned if possible. + If not, an error is thrown. + + Indexing is allowed but only along the stack dimension. + + This class is aimed to be used in multi-task and multi-agent settings, where + heterogeneous specs may occur (same semantic but different shape). + + """ + + @property + def space(self): + return self._specs[0].space + + def __eq__(self, other): + # requires unbind to be implemented + pass + + def to_numpy(self, val: torch.Tensor, safe: bool = True) -> dict: + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedTensorSpec and val differ along the stacking " + "dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return val.detach().cpu().numpy() + + def __len__(self): + pass + + def project(self, val: TensorDictBase) -> TensorDictBase: + pass + + def __repr__(self): + shape_str = "shape=" + str(self.shape) + space_str = "space=" + str(self._specs[0].space) + device_str = "device=" + str(self.device) + dtype_str = "dtype=" + str(self.dtype) + domain_str = "domain=" + str(self._specs[0].domain) + sub_string = ", ".join( + [shape_str, space_str, device_str, dtype_str, domain_str] + ) + string = f"{self.__class__.__name__}(\n {sub_string})" + return string + + def __iter__(self): + pass + + def __setitem__(self, key, value): + pass + + @property + def device(self) -> DEVICE_TYPING: + return self._specs[0].device + + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + + def set(self, name, spec): + if spec is not None: + shape = spec.shape + if shape[: self.ndim] != self.shape: + raise ValueError( + "The shape of the spec and the CompositeSpec mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." 
+ ) + self._specs[name] = spec + @dataclass(repr=False) class OneHotDiscreteTensorSpec(TensorSpec): @@ -476,6 +765,8 @@ class OneHotDiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -487,9 +778,7 @@ def __init__( dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register - space = DiscreteBox( - n, - ) + space = DiscreteBox(n) if shape is None: shape = torch.Size((space.n,)) else: @@ -543,6 +832,39 @@ def expand(self, *shape): n=shape[-1], shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + return self.__class__( + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of OneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=shape[-1], + shape=shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] @@ -657,6 +979,8 @@ class BoundedTensorSpec(TensorSpec): """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, minimum: Union[float, torch.Tensor, np.ndarray], @@ -756,6 +1080,36 @@ def expand(self, *shape): dtype=self.dtype, ) + def squeeze(self, dim: int | None = None): + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + if dim is None: + minimum = self.space.minimum.squeeze().clone() + maximum = self.space.maximum.squeeze().clone() + else: + minimum = self.space.minimum.squeeze(dim).clone() + maximum = self.space.maximum.squeeze(dim).clone() + + return self.__class__( + minimum=minimum, + maximum=maximum, + shape=shape, + device=self.device, + dtype=self.dtype, + ) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + minimum=self.space.minimum.unsqueeze(dim).clone(), + maximum=self.space.maximum.unsqueeze(dim).clone(), + shape=shape, + device=self.device, + dtype=self.dtype, + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -853,6 +1207,8 @@ class UnboundedContinuousTensorSpec(TensorSpec): (should be an floating point dtype such as float, double etc.) """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -924,6 +1280,8 @@ class UnboundedDiscreteTensorSpec(TensorSpec): (should be an integer dtype such as long, uint8 etc.) 
""" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, @@ -1021,6 +1379,8 @@ class BinaryDiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -1030,7 +1390,7 @@ def __init__( ): dtype, device = _default_dtype_and_device(dtype, device) box = BinaryBox(n) - if shape is None: + if shape is None or not len(shape): shape = torch.Size((n,)) else: shape = torch.Size(shape) @@ -1077,6 +1437,28 @@ def expand(self, *shape): n=shape[-1], shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim: int | None = None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of BinaryDiscreteTensorSpec must remain unchanged" + ) + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of BinaryDiscreteTensorSpec must remain unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest @@ -1122,6 +1504,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, nvec: Sequence[int], @@ -1285,6 +1669,29 @@ def expand(self, *shape): nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of MultiOneHotDiscreteTensorSpec must remain unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + ) + class DiscreteTensorSpec(TensorSpec): """A discrete tensor spec. 
@@ -1316,6 +1723,8 @@ class DiscreteTensorSpec(TensorSpec): dtype: torch.dtype = torch.float domain: str = "" + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, n: int, @@ -1399,6 +1808,20 @@ def expand(self, *shape): n=self.space.n, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim=None): + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + return self.__class__( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + shape = _unsqueezed_shape(self.shape, dim) + return self.__class__( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: if isinstance(dest, torch.dtype): dest_dtype = dest @@ -1442,6 +1865,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): False """ + # SPEC_HANDLED_FUNCTIONS = {} + def __init__( self, nvec: Union[Sequence[int], torch.Tensor, int], @@ -1602,6 +2027,36 @@ def expand(self, *shape): nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype ) + def squeeze(self, dim: int | None = None): + if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): + raise ValueError( + "Final dimension of MultiDiscreteTensorSpec must remain unchanged" + ) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + if dim is None: + nvec = self.nvec.squeeze() + else: + nvec = self.nvec.squeeze(dim) + + return self.__class__( + nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + ) + + def unsqueeze(self, dim: int): + if dim in (len(self.shape), -1): + raise ValueError( + "Final dimension of MultiDiscreteTensorSpec must remain unchanged" + ) + shape = _unsqueezed_shape(self.shape, dim) + nvec = self.nvec.unsqueeze(dim) + return self.__class__( + nvec=nvec, shape=shape, device=self.device, dtype=self.dtype + ) + class CompositeSpec(TensorSpec): """A composition of TensorSpecs. @@ -1670,6 +2125,8 @@ class CompositeSpec(TensorSpec): shape: torch.Size domain: str = "composite" + SPEC_HANDLED_FUNCTIONS = {} + @classmethod def __new__(cls, *args, **kwargs): cls._device = torch.device("cpu") @@ -1720,7 +2177,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs): for key, value in kwargs.items(): self.set(key, value) - _device = device + _device = torch.device(device) if device is not None else device if len(kwargs): for key, item in self.items(): if item is None: @@ -1739,8 +2196,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs): _device = item_device elif item_device != _device: raise RuntimeError( - f"Setting a new attribute ({key}) on another device ({item.device} against {_device}). " - f"All devices of CompositeSpec must match." + f"Setting a new attribute ({key}) on another device " + f"({item.device} against {_device}). All devices of " + "CompositeSpec must match." ) self._device = _device if len(args): @@ -1754,11 +2212,10 @@ def __init__(self, *args, shape=None, device=None, **kwargs): f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." 
) for k, item in argdict.items(): - if item is None: - continue - if self._device is None: - self._device = item.device - self[k] = item + if item is not None: + if self._device is None: + self._device = item.device + self[k] = item @property def device(self) -> DEVICE_TYPING: @@ -1904,7 +2361,7 @@ def rand(self, shape=None) -> TensorDictBase: } return TensorDict( _dict, - batch_size=shape, + batch_size=[*shape, *self.shape], device=self._device, ) @@ -1929,9 +2386,7 @@ def keys( Default is ``False``. """ return _CompositeSpecKeysView( - self, - include_nested=include_nested, - leaves_only=leaves_only, + self, include_nested=include_nested, leaves_only=leaves_only ) def items(self) -> ItemsView: @@ -2006,7 +2461,9 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N continue try: if isinstance(item, TensorSpec) and item.device != self.device: - item = deepcopy(item).to(self.device) + item = deepcopy(item) + if self.device is not None: + item = item.to(self.device) except RuntimeError as err: if DEVICE_ERR_MSG in str(err): try: @@ -2046,6 +2503,247 @@ def expand(self, *shape): ) return out + def squeeze(self, dim: int | None = None): + if dim is not None: + if dim < 0: + dim += len(self.shape) + + shape = _squeezed_shape(self.shape, dim) + if shape is None: + return self + + try: + device = self.device + except RuntimeError: + device = self._device + + return CompositeSpec( + {key: value.squeeze(dim) for key, value in self.items()}, + shape=shape, + device=device, + ) + + if self.shape.count(1) == 0: + return self + + # we can't just recursively apply squeeze with dim=None because we don't want + # to squeeze non-batch dims of the values. Instead we find the first dim in + # the batch dims with size 1, squeeze that, then recurse on the root spec + out = self.squeeze(self.shape.index(1)) + return out.squeeze() + + def unsqueeze(self, dim: int): + if dim < 0: + dim += len(self.shape) + + shape = _unsqueezed_shape(self.shape, dim) + + try: + device = self.device + except RuntimeError: + device = self._device + + return CompositeSpec( + {key: value.unsqueeze(dim) for key, value in self.items()}, + shape=shape, + device=device, + ) + + +class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): + """A lazy representation of a stack of composite specs. + + Stacks composite specs together along one dimension. + When random samples are drawn, a LazyStackedTensorDict is returned. + + Indexing is allowed but only along the stack dimension. + + This class is aimed to be used in multi-task and multi-agent settings, where + heterogeneous specs may occur (same semantic but different shape). 
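+
+    Example (a doctest-style sketch that mirrors the new ``TestStackComposite``
+    tests; the composite specs below are illustrative only):
+
+        >>> c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
+        >>> c2 = CompositeSpec(
+        ...     a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+        ...     b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+        ...     shape=(1, 3),
+        ... )
+        >>> c = torch.stack([c1, c2], 0)  # a LazyStackedCompositeSpec
+        >>> c.shape
+        torch.Size([2, 1, 3])
+        >>> c[0] is c1
+        True
+        >>> c.rand()["a"].shape  # rand() returns a LazyStackedTensorDict
+        torch.Size([2, 1, 3])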
+ + """ + + def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + pass + + def __eq__(self, other): + pass + + def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: + if safe: + if val.shape[self.dim] != len(self._specs): + raise ValueError( + "Size of LazyStackedCompositeSpec and val differ along the " + "stacking dimension" + ) + for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): + spec.assert_is_in(v) + return {key: self[key].to_numpy(val) for key, val in val.items()} + + def __len__(self): + pass + + def values(self): + for key in self.keys(): + yield self[key] + + def items(self): + for key in self.keys(): + yield key, self[key] + + def keys( + self, + include_nested: bool = False, + leaves_only: bool = False, + ) -> KeysView: + return self._specs[0].keys( + include_nested=include_nested, leaves_only=leaves_only + ) + + def project(self, val: TensorDictBase) -> TensorDictBase: + pass + + def is_in(self, val: Union[dict, TensorDictBase]) -> bool: + pass + + def type_check( + self, + value: Union[torch.Tensor, TensorDictBase], + selected_keys: Union[str, Optional[Sequence[str]]] = None, + ): + pass + + def __repr__(self) -> str: + sub_str = ",\n".join( + [indent(f"{k}: {repr(item)}", 4 * " ") for k, item in self.items()] + ) + device_str = f"device={self._specs[0].device}" + shape_str = f"shape={self.shape}" + sub_str = ", ".join([sub_str, device_str, shape_str]) + return ( + f"LazyStackedCompositeSpec(\n{', '.join([sub_str, device_str, shape_str])})" + ) + + def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: + pass + + def __delitem__(self, key): + pass + + def __iter__(self): + pass + + def __setitem__(self, key, value): + pass + + @property + def device(self) -> DEVICE_TYPING: + return self._specs[0].device + + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + + def set(self, name, spec): + if spec is not None: + shape = spec.shape + if shape[: self.ndim] != self.shape: + raise ValueError( + "The shape of the spec and the CompositeSpec mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." + ) + self._specs[name] = spec + + +# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: +@TensorSpec.implements_for_spec(torch.stack) +def _stack_specs(list_of_spec, dim, out=None): + if out is not None: + raise NotImplementedError( + "In-place spec modification is not a feature of torchrl, hence " + "torch.stack(list_of_specs, dim, out=spec) is not implemented." + ) + if not len(list_of_spec): + raise ValueError("Cannot stack an empty list of specs.") + spec0 = list_of_spec[0] + if isinstance(spec0, TensorSpec): + device = spec0.device + all_equal = True + for spec in list_of_spec[1:]: + if not isinstance(spec, TensorSpec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of specs in the list." 
+ ) + if device != spec.device: + raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + all_equal = all_equal and spec == spec0 + if all_equal: + shape = list(spec0.shape) + if dim < 0: + dim += len(shape) + 1 + shape.insert(dim, len(list_of_spec)) + return spec0.clone().unsqueeze(dim).expand(shape) + return LazyStackedTensorSpec(*list_of_spec, dim=dim) + else: + raise NotImplementedError + + +@CompositeSpec.implements_for_spec(torch.stack) +def _stack_composite_specs(list_of_spec, dim, out=None): + if out is not None: + raise NotImplementedError( + "In-place spec modification is not a feature of torchrl, hence " + "torch.stack(list_of_specs, dim, out=spec) is not implemented." + ) + if not len(list_of_spec): + raise ValueError("Cannot stack an empty list of specs.") + spec0 = list_of_spec[0] + if isinstance(spec0, CompositeSpec): + device = spec0.device + all_equal = True + for spec in list_of_spec[1:]: + if not isinstance(spec, CompositeSpec): + raise RuntimeError( + "Stacking specs cannot occur: Found more than one type of spec in " + "the list." + ) + if device != spec.device: + raise RuntimeError(f"Devices differ, got {device} and {spec.device}") + all_equal = all_equal and spec == spec0 + if all_equal: + shape = list(spec0.shape) + if dim < 0: + dim += len(shape) + 1 + shape.insert(dim, len(list_of_spec)) + return spec0.clone().unsqueeze(dim).expand(shape) + return LazyStackedCompositeSpec(*list_of_spec, dim=dim) + else: + raise NotImplementedError + + +@TensorSpec.implements_for_spec(torch.squeeze) +def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: + return spec.squeeze(*args, **kwargs) + + +@CompositeSpec.implements_for_spec(torch.squeeze) +def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: + return spec.squeeze(*args, **kwargs) + + +@TensorSpec.implements_for_spec(torch.unsqueeze) +def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: + return spec.unsqueeze(*args, **kwargs) + + +@CompositeSpec.implements_for_spec(torch.unsqueeze) +def _unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: + return spec.unsqueeze(*args, **kwargs) + def _keys_to_empty_composite_spec(keys): """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.""" @@ -2071,6 +2769,37 @@ def _keys_to_empty_composite_spec(keys): return c +def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: + if dim is None: + if len(shape) == 1 or shape.count(1) == 0: + return None + new_shape = torch.Size([s for s in shape if s != 1]) + else: + if dim < 0: + dim += len(shape) + + if shape[dim] != 1: + return None + + new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim]) + return new_shape + + +def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: + n = len(shape) + if dim < -(n + 1) or dim > n: + raise ValueError( + f"Dimension out of range, expected value in the range [{-(n+1)}, {n}], but " + f"got {dim}" + ) + if dim < 0: + dim += n + 1 + + new_shape = list(shape) + new_shape.insert(dim, 1) + return torch.Size(new_shape) + + class _CompositeSpecKeysView: """Wrapper class that enables richer behaviour of `key in tensordict.keys()`.""" @@ -2084,9 +2813,7 @@ def __init__( self.leaves_only = leaves_only self.include_nested = include_nested - def __iter__( - self, - ): + def __iter__(self): for key, item in self.composite.items(): if self.include_nested and isinstance(item, CompositeSpec): for subkey in item.keys(