Commit c1acefd

[Feature] Contiguous stacking of matching specs (#960)

1 parent a912a2e

2 files changed: +388 -77 lines
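
Summary note: this change makes torch.stack over specs that match exactly (same class, shape, dtype and device) return a contiguous spec of that same class, with a new dimension inserted at the stack dim; only specs that genuinely differ still fall back to a LazyStacked* spec. A minimal sketch of the behavior the updated tests pin down, assuming the torchrl.data import path used by the test suite:

    import torch
    from torchrl.data import (
        CompositeSpec,
        LazyStackedCompositeSpec,
        UnboundedContinuousTensorSpec,
        UnboundedDiscreteTensorSpec,
    )

    # Matching specs now stack contiguously: the result keeps the spec's own
    # class and grows a dimension of size 2 at the stack dim.
    c1 = UnboundedContinuousTensorSpec(shape=(1, 3))
    c = torch.stack([c1, c1.clone()], 0)
    assert isinstance(c, UnboundedContinuousTensorSpec)
    assert c.shape == torch.Size([2, 1, 3])

    # Specs that differ (extra key "b") still produce a lazy stack.
    h1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
    h2 = CompositeSpec(
        a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
    )
    assert isinstance(torch.stack([h1, h2], 0), LazyStackedCompositeSpec)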

test/test_specs.py

Lines changed: 59 additions & 47 deletions
@@ -18,7 +18,6 @@
     CompositeSpec,
     DiscreteTensorSpec,
     LazyStackedCompositeSpec,
-    LazyStackedTensorSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
@@ -1716,7 +1715,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim):
         c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, BinaryDiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1761,7 +1760,7 @@ def test_stack_bounded(self, shape, stack_dim):
         c1 = BoundedTensorSpec(mini, maxi, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, BoundedTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1808,7 +1807,7 @@ def test_stack_discrete(self, shape, stack_dim):
         c1 = DiscreteTensorSpec(n, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, DiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1852,7 +1851,7 @@ def test_stack_multidiscrete(self, shape, stack_dim):
         c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, MultiDiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1896,7 +1895,7 @@ def test_stack_multionehot(self, shape, stack_dim):
         c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, MultiOneHotDiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1940,7 +1939,7 @@ def test_stack_onehot(self, shape, stack_dim):
         c1 = OneHotDiscreteTensorSpec(n, shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, OneHotDiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -1983,7 +1982,7 @@ def test_stack_unboundedcont(self, shape, stack_dim):
         c1 = UnboundedContinuousTensorSpec(shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, UnboundedContinuousTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -2023,7 +2022,7 @@ def test_stack_unboundeddiscrete(self, shape, stack_dim):
         c1 = UnboundedDiscreteTensorSpec(shape=shape)
         c2 = c1.clone()
         c = torch.stack([c1, c2], stack_dim)
-        assert isinstance(c, LazyStackedTensorSpec)
+        assert isinstance(c, UnboundedDiscreteTensorSpec)
         shape = list(shape)
         if stack_dim < 0:
             stack_dim = len(shape) + stack_dim + 1
@@ -2064,11 +2063,13 @@ def test_stack(self):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
         c2 = c1.clone()
         c = torch.stack([c1, c2], 0)
-        assert isinstance(c, LazyStackedCompositeSpec)
+        assert isinstance(c, CompositeSpec)

     def test_stack_index(self):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
+        )
         c = torch.stack([c1, c2], 0)
         assert c.shape == torch.Size([2])
         assert c[0] is c1
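
Note on the c2 change above (and in the hunks that follow): with contiguous stacking, stacking an exact clone no longer exercises the lazy path, so the tests now build a second spec with an extra key b to force a LazyStackedCompositeSpec. A quick sketch of the intent, reusing c1 from the test above:

    s = torch.stack([c1, c1.clone()], 0)
    assert isinstance(s, CompositeSpec)                 # contiguous result
    assert not isinstance(s, LazyStackedCompositeSpec)  # the lazy class is gone here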
@@ -2082,7 +2083,11 @@ def test_stack_index(self):
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_index_multdim(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         if stack_dim in (0, -3):
             assert isinstance(c[:], LazyStackedCompositeSpec)
@@ -2146,36 +2151,14 @@ def test_stack_index_multdim(self, stack_dim):
         assert c[:, :, 0, ...] is c1
         assert c[:, :, 1, ...] is c2

-    @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
-    def test_stack_expand_one(self, stack_dim):
-        c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c = torch.stack([c1], stack_dim)
-        if stack_dim in (0, -3):
-            c_expand = c.expand([4, 2, 1, 3])
-            assert c_expand.shape == torch.Size([4, 2, 1, 3])
-            assert c_expand.dim == 1
-        elif stack_dim in (1, -2):
-            c_expand = c.expand([4, 1, 2, 3])
-            assert c_expand.shape == torch.Size([4, 1, 2, 3])
-            assert c_expand.dim == 2
-        elif stack_dim in (2, -1):
-            c_expand = c.expand(
-                [
-                    4,
-                    1,
-                    3,
-                    2,
-                ]
-            )
-            assert c_expand.shape == torch.Size([4, 1, 3, 2])
-            assert c_expand.dim == 3
-        else:
-            raise NotImplementedError
-
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_expand_multi(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         if stack_dim in (0, -3):
             c_expand = c.expand([4, 2, 1, 3])
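
For context on the expand assertions here: a spec stacked at dim 0 has shape [2, 1, 3], and expand adds leading broadcast dimensions, so [4, 2, 1, 3] is the valid target in that branch (the other branches move where the stack size 2 sits). A small worked sketch under those assumptions:

    c = torch.stack([c1, c2], 0)        # shape: [2, 1, 3]
    c_expand = c.expand([4, 2, 1, 3])   # prepend a broadcast dim of size 4
    assert c_expand.shape == torch.Size([4, 2, 1, 3])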
@@ -2202,7 +2185,11 @@ def test_stack_expand_multi(self, stack_dim):
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_rand(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         r = c.rand()
         assert isinstance(r, LazyStackedTensorDict)
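
As the assertion above pins down, sampling from a lazily stacked composite spec yields a lazily stacked tensordict rather than a contiguous one. A rough usage sketch, assuming LazyStackedTensorDict comes from the tensordict package:

    from tensordict import LazyStackedTensorDict  # assumed import path

    r = torch.stack([c1, c2], 0).rand()  # c1, c2 as in the hunk above
    assert isinstance(r, LazyStackedTensorDict)
    assert "a" in r.keys()  # "a" is shared by both specs; "b" lives only in r[1]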
@@ -2220,7 +2207,11 @@ def test_stack_rand(self, stack_dim):
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_rand_shape(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         shape = [5, 6]
         r = c.rand(shape)
@@ -2239,7 +2230,11 @@ def test_stack_rand_shape(self, stack_dim):
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_zero(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         r = c.zero()
         assert isinstance(r, LazyStackedTensorDict)
@@ -2257,7 +2252,11 @@ def test_stack_zero(self, stack_dim):
     @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
     def test_stack_zero_shape(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
         shape = [5, 6]
         r = c.zero(shape)
@@ -2274,18 +2273,31 @@ def test_stack_zero_shape(self, stack_dim):
         assert (r["a"] == 0).all()

     @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
-    def test_to(self):
+    @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
+    def test_to(self, stack_dim):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], stack_dim)
+        assert isinstance(c, LazyStackedCompositeSpec)
         cdevice = c.to("cuda:0")
         assert cdevice.device != c.device
         assert cdevice.device == torch.device("cuda:0")
-        assert cdevice[0].device == torch.device("cuda:0")
+        if stack_dim < 0:
+            stack_dim += 3
+        index = (slice(None),) * stack_dim + (0,)
+        assert cdevice[index].device == torch.device("cuda:0")

     def test_clone(self):
         c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
-        c2 = c1.clone()
+        c2 = CompositeSpec(
+            a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+            b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
+            shape=(1, 3),
+        )
         c = torch.stack([c1, c2], 0)
         cclone = c.clone()
         assert cclone[0] is not c[0]
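
One detail worth calling out in the new test_to: since the stack dim is now parametrized, selecting entry 0 needs an index tuple rather than cdevice[0], which only addresses dim 0. (slice(None),) * stack_dim + (0,) is the programmatic form of indexing with 0 at that dim:

    import torch

    stack_dim = 2
    index = (slice(None),) * stack_dim + (0,)
    # index == (slice(None), slice(None), 0), i.e. the same as x[:, :, 0]
    x = torch.zeros(4, 5, 6)
    assert x[index].shape == torch.Size([4, 5])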
