1818 CompositeSpec ,
1919 DiscreteTensorSpec ,
2020 LazyStackedCompositeSpec ,
21- LazyStackedTensorSpec ,
2221 MultiDiscreteTensorSpec ,
2322 MultiOneHotDiscreteTensorSpec ,
2423 OneHotDiscreteTensorSpec ,
@@ -1716,7 +1715,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim):
17161715 c1 = BinaryDiscreteTensorSpec (n = n , shape = shape )
17171716 c2 = c1 .clone ()
17181717 c = torch .stack ([c1 , c2 ], stack_dim )
1719- assert isinstance (c , LazyStackedTensorSpec )
1718+ assert isinstance (c , BinaryDiscreteTensorSpec )
17201719 shape = list (shape )
17211720 if stack_dim < 0 :
17221721 stack_dim = len (shape ) + stack_dim + 1
@@ -1761,7 +1760,7 @@ def test_stack_bounded(self, shape, stack_dim):
17611760 c1 = BoundedTensorSpec (mini , maxi , shape = shape )
17621761 c2 = c1 .clone ()
17631762 c = torch .stack ([c1 , c2 ], stack_dim )
1764- assert isinstance (c , LazyStackedTensorSpec )
1763+ assert isinstance (c , BoundedTensorSpec )
17651764 shape = list (shape )
17661765 if stack_dim < 0 :
17671766 stack_dim = len (shape ) + stack_dim + 1
@@ -1808,7 +1807,7 @@ def test_stack_discrete(self, shape, stack_dim):
18081807 c1 = DiscreteTensorSpec (n , shape = shape )
18091808 c2 = c1 .clone ()
18101809 c = torch .stack ([c1 , c2 ], stack_dim )
1811- assert isinstance (c , LazyStackedTensorSpec )
1810+ assert isinstance (c , DiscreteTensorSpec )
18121811 shape = list (shape )
18131812 if stack_dim < 0 :
18141813 stack_dim = len (shape ) + stack_dim + 1
@@ -1852,7 +1851,7 @@ def test_stack_multidiscrete(self, shape, stack_dim):
18521851 c1 = MultiDiscreteTensorSpec (nvec , shape = shape )
18531852 c2 = c1 .clone ()
18541853 c = torch .stack ([c1 , c2 ], stack_dim )
1855- assert isinstance (c , LazyStackedTensorSpec )
1854+ assert isinstance (c , MultiDiscreteTensorSpec )
18561855 shape = list (shape )
18571856 if stack_dim < 0 :
18581857 stack_dim = len (shape ) + stack_dim + 1
@@ -1896,7 +1895,7 @@ def test_stack_multionehot(self, shape, stack_dim):
18961895 c1 = MultiOneHotDiscreteTensorSpec (nvec , shape = shape )
18971896 c2 = c1 .clone ()
18981897 c = torch .stack ([c1 , c2 ], stack_dim )
1899- assert isinstance (c , LazyStackedTensorSpec )
1898+ assert isinstance (c , MultiOneHotDiscreteTensorSpec )
19001899 shape = list (shape )
19011900 if stack_dim < 0 :
19021901 stack_dim = len (shape ) + stack_dim + 1
@@ -1940,7 +1939,7 @@ def test_stack_onehot(self, shape, stack_dim):
19401939 c1 = OneHotDiscreteTensorSpec (n , shape = shape )
19411940 c2 = c1 .clone ()
19421941 c = torch .stack ([c1 , c2 ], stack_dim )
1943- assert isinstance (c , LazyStackedTensorSpec )
1942+ assert isinstance (c , OneHotDiscreteTensorSpec )
19441943 shape = list (shape )
19451944 if stack_dim < 0 :
19461945 stack_dim = len (shape ) + stack_dim + 1
@@ -1983,7 +1982,7 @@ def test_stack_unboundedcont(self, shape, stack_dim):
19831982 c1 = UnboundedContinuousTensorSpec (shape = shape )
19841983 c2 = c1 .clone ()
19851984 c = torch .stack ([c1 , c2 ], stack_dim )
1986- assert isinstance (c , LazyStackedTensorSpec )
1985+ assert isinstance (c , UnboundedContinuousTensorSpec )
19871986 shape = list (shape )
19881987 if stack_dim < 0 :
19891988 stack_dim = len (shape ) + stack_dim + 1
@@ -2023,7 +2022,7 @@ def test_stack_unboundeddiscrete(self, shape, stack_dim):
20232022 c1 = UnboundedDiscreteTensorSpec (shape = shape )
20242023 c2 = c1 .clone ()
20252024 c = torch .stack ([c1 , c2 ], stack_dim )
2026- assert isinstance (c , LazyStackedTensorSpec )
2025+ assert isinstance (c , UnboundedDiscreteTensorSpec )
20272026 shape = list (shape )
20282027 if stack_dim < 0 :
20292028 stack_dim = len (shape ) + stack_dim + 1
@@ -2064,11 +2063,13 @@ def test_stack(self):
20642063 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec ())
20652064 c2 = c1 .clone ()
20662065 c = torch .stack ([c1 , c2 ], 0 )
2067- assert isinstance (c , LazyStackedCompositeSpec )
2066+ assert isinstance (c , CompositeSpec )
20682067
20692068 def test_stack_index (self ):
20702069 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec ())
2071- c2 = c1 .clone ()
2070+ c2 = CompositeSpec (
2071+ a = UnboundedContinuousTensorSpec (), b = UnboundedDiscreteTensorSpec ()
2072+ )
20722073 c = torch .stack ([c1 , c2 ], 0 )
20732074 assert c .shape == torch .Size ([2 ])
20742075 assert c [0 ] is c1
@@ -2082,7 +2083,11 @@ def test_stack_index(self):
20822083 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
20832084 def test_stack_index_multdim (self , stack_dim ):
20842085 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2085- c2 = c1 .clone ()
2086+ c2 = CompositeSpec (
2087+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2088+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2089+ shape = (1 , 3 ),
2090+ )
20862091 c = torch .stack ([c1 , c2 ], stack_dim )
20872092 if stack_dim in (0 , - 3 ):
20882093 assert isinstance (c [:], LazyStackedCompositeSpec )
@@ -2146,36 +2151,14 @@ def test_stack_index_multdim(self, stack_dim):
21462151 assert c [:, :, 0 , ...] is c1
21472152 assert c [:, :, 1 , ...] is c2
21482153
2149- @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
2150- def test_stack_expand_one (self , stack_dim ):
2151- c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2152- c = torch .stack ([c1 ], stack_dim )
2153- if stack_dim in (0 , - 3 ):
2154- c_expand = c .expand ([4 , 2 , 1 , 3 ])
2155- assert c_expand .shape == torch .Size ([4 , 2 , 1 , 3 ])
2156- assert c_expand .dim == 1
2157- elif stack_dim in (1 , - 2 ):
2158- c_expand = c .expand ([4 , 1 , 2 , 3 ])
2159- assert c_expand .shape == torch .Size ([4 , 1 , 2 , 3 ])
2160- assert c_expand .dim == 2
2161- elif stack_dim in (2 , - 1 ):
2162- c_expand = c .expand (
2163- [
2164- 4 ,
2165- 1 ,
2166- 3 ,
2167- 2 ,
2168- ]
2169- )
2170- assert c_expand .shape == torch .Size ([4 , 1 , 3 , 2 ])
2171- assert c_expand .dim == 3
2172- else :
2173- raise NotImplementedError
2174-
21752154 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
21762155 def test_stack_expand_multi (self , stack_dim ):
21772156 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2178- c2 = c1 .clone ()
2157+ c2 = CompositeSpec (
2158+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2159+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2160+ shape = (1 , 3 ),
2161+ )
21792162 c = torch .stack ([c1 , c2 ], stack_dim )
21802163 if stack_dim in (0 , - 3 ):
21812164 c_expand = c .expand ([4 , 2 , 1 , 3 ])
@@ -2202,7 +2185,11 @@ def test_stack_expand_multi(self, stack_dim):
22022185 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
22032186 def test_stack_rand (self , stack_dim ):
22042187 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2205- c2 = c1 .clone ()
2188+ c2 = CompositeSpec (
2189+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2190+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2191+ shape = (1 , 3 ),
2192+ )
22062193 c = torch .stack ([c1 , c2 ], stack_dim )
22072194 r = c .rand ()
22082195 assert isinstance (r , LazyStackedTensorDict )
@@ -2220,7 +2207,11 @@ def test_stack_rand(self, stack_dim):
22202207 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
22212208 def test_stack_rand_shape (self , stack_dim ):
22222209 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2223- c2 = c1 .clone ()
2210+ c2 = CompositeSpec (
2211+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2212+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2213+ shape = (1 , 3 ),
2214+ )
22242215 c = torch .stack ([c1 , c2 ], stack_dim )
22252216 shape = [5 , 6 ]
22262217 r = c .rand (shape )
@@ -2239,7 +2230,11 @@ def test_stack_rand_shape(self, stack_dim):
22392230 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
22402231 def test_stack_zero (self , stack_dim ):
22412232 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2242- c2 = c1 .clone ()
2233+ c2 = CompositeSpec (
2234+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2235+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2236+ shape = (1 , 3 ),
2237+ )
22432238 c = torch .stack ([c1 , c2 ], stack_dim )
22442239 r = c .zero ()
22452240 assert isinstance (r , LazyStackedTensorDict )
@@ -2257,7 +2252,11 @@ def test_stack_zero(self, stack_dim):
22572252 @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
22582253 def test_stack_zero_shape (self , stack_dim ):
22592254 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2260- c2 = c1 .clone ()
2255+ c2 = CompositeSpec (
2256+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2257+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2258+ shape = (1 , 3 ),
2259+ )
22612260 c = torch .stack ([c1 , c2 ], stack_dim )
22622261 shape = [5 , 6 ]
22632262 r = c .zero (shape )
@@ -2274,18 +2273,31 @@ def test_stack_zero_shape(self, stack_dim):
22742273 assert (r ["a" ] == 0 ).all ()
22752274
22762275 @pytest .mark .skipif (not torch .cuda .device_count (), reason = "no cuda" )
2277- def test_to (self ):
2276+ @pytest .mark .parametrize ("stack_dim" , [0 , 1 , 2 , - 3 , - 2 , - 1 ])
2277+ def test_to (self , stack_dim ):
22782278 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2279- c2 = c1 .clone ()
2279+ c2 = CompositeSpec (
2280+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2281+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2282+ shape = (1 , 3 ),
2283+ )
22802284 c = torch .stack ([c1 , c2 ], stack_dim )
2285+ assert isinstance (c , LazyStackedCompositeSpec )
22812286 cdevice = c .to ("cuda:0" )
22822287 assert cdevice .device != c .device
22832288 assert cdevice .device == torch .device ("cuda:0" )
2284- assert cdevice [0 ].device == torch .device ("cuda:0" )
2289+ if stack_dim < 0 :
2290+ stack_dim += 3
2291+ index = (slice (None ),) * stack_dim + (0 ,)
2292+ assert cdevice [index ].device == torch .device ("cuda:0" )
22852293
22862294 def test_clone (self ):
22872295 c1 = CompositeSpec (a = UnboundedContinuousTensorSpec (shape = (1 , 3 )), shape = (1 , 3 ))
2288- c2 = c1 .clone ()
2296+ c2 = CompositeSpec (
2297+ a = UnboundedContinuousTensorSpec (shape = (1 , 3 )),
2298+ b = UnboundedDiscreteTensorSpec (shape = (1 , 3 )),
2299+ shape = (1 , 3 ),
2300+ )
22892301 c = torch .stack ([c1 , c2 ], 0 )
22902302 cclone = c .clone ()
22912303 assert cclone [0 ] is not c [0 ]
0 commit comments