Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 173 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from tensordict.utils import expand_right, NestedKey

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
Expand Down Expand Up @@ -1290,3 +1290,175 @@ def _step(
device=self.device,
)
return tensordict.select().set("next", tensordict)


class HeteroCountingEnv(EnvBase):
"""A heterogeneous, counting Env."""

def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
super().__init__(**kwargs)
self.n_agents = 3
self.max_steps = max_steps
self.start_val = start_val

count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
count[:] = self.start_val

self.register_buffer("count", count)

agent_obs_specs = []
agent_action_specs = []
for angent_id in range(self.n_agents):
agent_obs_specs.append(self.get_agent_obs_spec(angent_id))
agent_action_specs.append(self.get_agent_action_spec(angent_id))
agent_obs_specs = torch.stack(agent_obs_specs, dim=0)
agent_action_specs = torch.stack(agent_action_specs, dim=0)

self.unbatched_observation_spec = CompositeSpec(
agents=agent_obs_specs,
state=UnboundedContinuousTensorSpec(
shape=(
64,
64,
3,
)
),
)

self.unbatched_action_spec = CompositeSpec(
agents=agent_action_specs,
)
self.unbatched_reward_spec = CompositeSpec(
{
"agents": CompositeSpec(
{"reward": UnboundedContinuousTensorSpec(shape=(self.n_agents, 1))},
shape=(self.n_agents,),
)
}
)
self.unbatched_done_spec = CompositeSpec(
{
"agents": CompositeSpec(
{
"done": DiscreteTensorSpec(
n=2,
shape=(self.n_agents, 1),
dtype=torch.bool,
),
},
shape=(self.n_agents,),
)
}
)

self.action_spec = self.unbatched_action_spec.expand(
*self.batch_size, *self.unbatched_action_spec.shape
)
self.observation_spec = self.unbatched_observation_spec.expand(
*self.batch_size, *self.unbatched_observation_spec.shape
)
self.reward_spec = self.unbatched_reward_spec.expand(
*self.batch_size, *self.unbatched_reward_spec.shape
)
self.done_spec = self.unbatched_done_spec.expand(
*self.batch_size, *self.unbatched_done_spec.shape
)

def get_agent_obs_spec(self, i):
camera = BoundedTensorSpec(minimum=0, maximum=1, shape=(32, 32, 3))
vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
lidar = BoundedTensorSpec(minimum=0, maximum=5, shape=(20,))

agent_0_obs = UnboundedContinuousTensorSpec(shape=(1,))
agent_1_obs = BoundedTensorSpec(minimum=0, maximum=3, shape=(1, 2))
agent_2_obs = UnboundedContinuousTensorSpec(shape=(1, 2, 3))

# Agents all have the same camera
# All have vector entry but different shapes
# First 2 have lidar and last sonar
# All have a different key agent_i_obs with different n_dims
if i == 0:
return CompositeSpec(
{
"camera": camera,
"lidar": lidar,
"vector": vector_3d,
"agent_0_obs": agent_0_obs,
}
)
elif i == 1:
return CompositeSpec(
{
"camera": camera,
"lidar": lidar,
"vector": vector_2d,
"agent_1_obs": agent_1_obs,
}
)
elif i == 2:
return CompositeSpec(
{
"camera": camera,
"vector": vector_2d,
"agent_2_obs": agent_2_obs,
}
)
else:
raise ValueError(f"Index {i} undefined for 3 agents")

def get_agent_action_spec(self, i):
force_3d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(3,))
force_2d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(2,))

# Some have 2d action and some 3d
# TODO Introduce composite heterogeneous actions
if i == 0:
ret = force_3d
elif i == 1:
ret = force_2d
elif i == 2:
ret = force_2d
else:
raise ValueError(f"Index {i} undefined for 3 agents")

return CompositeSpec({"action": ret})

def _reset(
self,
tensordict: TensorDictBase = None,
**kwargs,
) -> TensorDictBase:
if tensordict is not None and "_reset" in tensordict.keys():
_reset = tensordict.get("_reset")
self.count[_reset] = self.start_val
else:
self.count[:] = self.start_val

reset_td = self.observation_spec.zero()
reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
reset_td.update(self.output_spec["_done_spec"].zero())

assert reset_td.batch_size == self.batch_size

return reset_td

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
td = self.observation_spec.zero()
self.count += 1
td.apply_(lambda x: x + expand_right(self.count, x.shape))
td.update(self.output_spec["_done_spec"].zero())
td.update(self.output_spec["_reward_spec"].zero())

assert td.batch_size == self.batch_size
td[self.done_key] = expand_right(
self.count > self.max_steps, self.done_spec.shape
)

return td.select().set("next", td)

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)
28 changes: 25 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
HeteroCountingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockSerialEnv,
Expand Down Expand Up @@ -1417,7 +1418,6 @@ def test_batch_unlocked(device):
env.step(td_expanded)


@pytest.mark.parametrize("device", get_default_devices())
def test_batch_unlocked_with_batch_size(device):
env = MockBatchedUnLockedEnv(device, batch_size=torch.Size([2]))
assert not env.batch_locked
Expand Down Expand Up @@ -1669,7 +1669,6 @@ def test_mp_collector(self, nproc):
class TestNestedSpecs:
@pytest.mark.parametrize("envclass", ["CountingEnv", "NestedCountingEnv"])
def test_nested_env(self, envclass):

if envclass == "CountingEnv":
env = CountingEnv()
elif envclass == "NestedCountingEnv":
Expand Down Expand Up @@ -1700,7 +1699,6 @@ def test_nested_env(self, envclass):

@pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)])
def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):

env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)

td_reset = env.reset()
Expand Down Expand Up @@ -1750,6 +1748,29 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
)


class TestHeteroEnvs:
@pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
def test_reset(self, batch_size):
env = HeteroCountingEnv(batch_size=batch_size)
env.reset()

@pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
def test_rand_step(self, batch_size):
env = HeteroCountingEnv(batch_size=batch_size)
td = env.reset()
assert (td["agents"][..., 0]["agent_0_obs"] == 0).all()
td = env.rand_step()
assert (td["next", "agents"][..., 0]["agent_0_obs"] == 1).all()
td = env.rand_step()
assert (td["next", "agents"][..., 1]["agent_1_obs"] == 2).all()

@pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
def test_rollout_one(self, batch_size, rollout_steps=1):
env = HeteroCountingEnv(batch_size=batch_size)
td = env.rollout(rollout_steps)
td.get("agents")


@pytest.mark.parametrize(
"envclass",
[
Expand All @@ -1768,6 +1789,7 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
MockBatchedUnLockedEnv,
MockSerialEnv,
NestedCountingEnv,
HeteroCountingEnv,
],
)
def test_mocking_envs(envclass):
Expand Down
Loading