Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cd70b16
init
vmoens Jul 28, 2023
4126035
init
vmoens Jul 28, 2023
888fee5
amend
vmoens Aug 10, 2023
5fcb42c
init
vmoens Aug 11, 2023
871589e
amend
vmoens Aug 11, 2023
858da29
amend
vmoens Aug 11, 2023
473b200
Merge branch 'main' into event_mp
vmoens Aug 31, 2023
871bdc7
amend
vmoens Aug 31, 2023
bb352e5
amend
vmoens Aug 31, 2023
28cb428
amend
vmoens Aug 31, 2023
26b53bc
amend
vmoens Aug 31, 2023
77872a6
amend
vmoens Aug 31, 2023
c7ed82e
amend
vmoens Aug 31, 2023
fd089d1
lint
vmoens Aug 31, 2023
d6f304a
fixes
vmoens Aug 31, 2023
df0210e
amend
vmoens Aug 31, 2023
78a06d5
amend
vmoens Aug 31, 2023
45615c3
amend
vmoens Aug 31, 2023
2138931
Merge branch 'event_mp' into fix_lstm_penv
vmoens Aug 31, 2023
bc3abd2
amend
vmoens Aug 31, 2023
e0d81ef
tmp
vmoens Aug 31, 2023
b839208
Merge branch 'main' into fix_lstm_penv
vmoens Sep 1, 2023
9fc95fc
amend
vmoens Sep 1, 2023
8f3ed5e
amend
vmoens Sep 1, 2023
ff4bb70
Merge remote-tracking branch 'origin/main' into fix_lstm_penv
vmoens Sep 1, 2023
e6fa755
amend
vmoens Sep 1, 2023
38f74a2
amend
vmoens Sep 1, 2023
b568c3b
amend
vmoens Sep 1, 2023
71d1076
amend
vmoens Sep 1, 2023
5f7885e
init
vmoens Sep 1, 2023
6b46000
Merge branch 'main' into gru
vmoens Oct 1, 2023
b0434d9
Merge remote-tracking branch 'origin/main' into gru
vmoens Oct 4, 2023
b053bee
amend
vmoens Oct 4, 2023
274daef
Merge remote-tracking branch 'origin/main' into gru
vmoens Oct 5, 2023
c40e3db
Update test/mocking_classes.py
Oct 5, 2023
cec7987
Merge branch 'gru' of github.com:pytorch/rl into gru
vmoens Oct 5, 2023
da7e173
amend
vmoens Oct 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ algorithms, such as DQN, DDPG or Dreamer.
DistributionalDQNnet
DreamerActor
DuelingCnnDQNet
GRUModule
LSTMModule
ObsDecoder
ObsEncoder
Expand Down
262 changes: 260 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AdditiveGaussianWrapper,
DecisionTransformerInferenceWrapper,
DTActor,
GRUModule,
LSTMModule,
MLP,
NormalParamWrapper,
Expand Down Expand Up @@ -1645,9 +1646,9 @@ def test_set_temporal_mode(self):
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
assert lstm_module.set_recurrent_mode(False) is lstm_module
assert not lstm_module.set_recurrent_mode(False).temporal_mode
assert not lstm_module.set_recurrent_mode(False).recurrent_mode
assert lstm_module.set_recurrent_mode(True) is not lstm_module
assert lstm_module.set_recurrent_mode(True).temporal_mode
assert lstm_module.set_recurrent_mode(True).recurrent_mode
assert set(lstm_module.set_recurrent_mode(True).parameters()) == set(
lstm_module.parameters()
)
Expand Down Expand Up @@ -1822,6 +1823,263 @@ def create_transformed_env():
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()


class TestGRUModule:
    """Unit tests for ``GRUModule``.

    Covers constructor argument validation, recurrent-mode toggling,
    padded (non-contiguous) inputs, equivalence of single-step rollout
    vs. multi-step (recurrent) evaluation, hidden-state resets on
    ``is_init``, and hidden-state persistence across parallel envs.
    """

    def test_errs(self):
        """Invalid constructor arguments and a missing ``is_init`` key raise."""
        # batch_first=False is unsupported by the TensorDict integration.
        with pytest.raises(ValueError, match="batch_first"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=False,
                in_keys=["observation", "hidden"],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # GRU carries a single hidden state, so exactly two in_keys are expected.
        with pytest.raises(ValueError, match="in_keys"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=[
                    "observation",
                    "hidden0",
                    "hidden1",
                ],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # in_keys must be a list of keys, not a bare string.
        with pytest.raises(TypeError, match="incompatible function arguments"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys="abc",
                out_keys=["intermediate", ("next", "hidden")],
            )
        # in_key and in_keys are mutually exclusive.
        with pytest.raises(ValueError, match="in_keys"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_key="smth",
                in_keys=["observation", "hidden0", "hidden1"],
                out_keys=["intermediate", ("next", "hidden")],
            )
        # Exactly two out_keys are expected (output and next hidden state).
        with pytest.raises(ValueError, match="out_keys"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_keys=["intermediate", ("next", "hidden"), "other"],
            )
        # out_keys must be a list of keys, not a bare string.
        with pytest.raises(TypeError, match="incompatible function arguments"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_keys="abc",
            )
        # out_key and out_keys are mutually exclusive.
        with pytest.raises(ValueError, match="out_keys"):
            gru_module = GRUModule(
                input_size=3,
                hidden_size=12,
                batch_first=True,
                in_keys=["observation", "hidden"],
                out_key="smth",
                out_keys=["intermediate", ("next", "hidden"), "other"],
            )
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        # Calling the module on a tensordict lacking "is_init" must fail.
        td = TensorDict({"observation": torch.randn(3)}, [])
        with pytest.raises(KeyError, match="is_init"):
            gru_module(td)

    def test_set_temporal_mode(self):
        """set_recurrent_mode returns self for False, a shallow copy for True,
        and both share the same parameters."""
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        assert gru_module.set_recurrent_mode(False) is gru_module
        assert not gru_module.set_recurrent_mode(False).recurrent_mode
        assert gru_module.set_recurrent_mode(True) is not gru_module
        assert gru_module.set_recurrent_mode(True).recurrent_mode
        assert set(gru_module.set_recurrent_mode(True).parameters()) == set(
            gru_module.parameters()
        )

    def test_noncontiguous(self):
        """The module accepts a padded (hence non-contiguous) tensordict."""
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["bork", "h"],
            out_keys=["dork", ("next", "h")],
        )
        td = TensorDict(
            {
                "bork": torch.randn(3, 3),
                "is_init": torch.zeros(3, 1, dtype=torch.bool),
            },
            [3],
        )
        padded = pad(td, [0, 5])
        gru_module(padded)

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    def test_single_step(self, shape):
        """Two consecutive single steps produce different hidden states."""
        td = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        gru_module = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        td = gru_module(td)
        td_next = step_mdp(td, keep_other=True)
        td_next = gru_module(td_next)

        assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any()

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    @pytest.mark.parametrize("t", [1, 10])
    def test_single_step_vs_multi(self, shape, t):
        """A t-step recurrent pass matches t successive single-step passes."""
        td = TensorDict(
            {
                "observation": torch.arange(t, dtype=torch.float32)
                .unsqueeze(-1)
                .expand(*shape, t, 3),
                "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
            },
            [*shape, t],
        )
        gru_module_ss = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        gru_module_ms = gru_module_ss.set_recurrent_mode()
        gru_module_ms(td)
        td_ss = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        for _t in range(t):
            gru_module_ss(td_ss)
            td_ss = step_mdp(td_ss, keep_other=True)
            td_ss["observation"][:] = _t + 1
        torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :])

    @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
    def test_multi_consecutive(self, shape):
        """Single-step and recurrent passes agree when is_init resets the state
        mid-trajectory."""
        t = 20
        td = TensorDict(
            {
                "observation": torch.arange(t, dtype=torch.float32)
                .unsqueeze(-1)
                .expand(*shape, t, 3),
                "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool),
            },
            [*shape, t],
        )
        # Force a hidden-state reset partway through the trajectory.
        if shape:
            td["is_init"][0, ..., 13, :] = True
        else:
            td["is_init"][13, :] = True

        gru_module_ss = GRUModule(
            input_size=3,
            hidden_size=12,
            batch_first=True,
            in_keys=["observation", "hidden"],
            out_keys=["intermediate", ("next", "hidden")],
        )
        gru_module_ms = gru_module_ss.set_recurrent_mode()
        gru_module_ms(td)
        td_ss = TensorDict(
            {
                "observation": torch.zeros(*shape, 3),
                "is_init": torch.zeros(*shape, 1, dtype=torch.bool),
            },
            shape,
        )
        for _t in range(t):
            td_ss["is_init"][:] = td["is_init"][..., _t, :]
            gru_module_ss(td_ss)
            td_ss = step_mdp(td_ss, keep_other=True)
            td_ss["observation"][:] = _t + 1
        torch.testing.assert_close(
            td_ss["intermediate"], td["intermediate"][..., -1, :]
        )

    def test_gru_parallel_env(self):
        """Hidden states are carried over between steps in parallel envs."""
        from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

        # tests that hidden states are carried over with parallel envs
        gru_module = GRUModule(
            input_size=7,
            hidden_size=12,
            num_layers=2,
            in_key="observation",
            out_key="features",
        )

        def create_transformed_env():
            primer = gru_module.make_tensordict_primer()
            env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
            env = TransformedEnv(env)
            env.append_transform(InitTracker())
            env.append_transform(primer)
            return env

        env = ParallelEnv(
            create_env_fn=create_transformed_env,
            num_workers=2,
        )

        mlp = TensorDictModule(
            MLP(
                in_features=12,
                out_features=7,
                num_cells=[],
            ),
            in_keys=["features"],
            out_keys=["logits"],
        )

        actor_model = TensorDictSequential(gru_module, mlp)

        actor = ProbabilisticActor(
            module=actor_model,
            in_keys=["logits"],
            out_keys=["action"],
            distribution_class=torch.distributions.Categorical,
            return_log_prob=True,
        )
        for break_when_any_done in [False, True]:
            data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
            assert (data.get("recurrent_state") != 0.0).any()
            assert (data.get(("next", "recurrent_state")) != 0.0).all()


def test_safe_specs():

out_key = ("a", "b")
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
DistributionalQValueModule,
EGreedyModule,
EGreedyWrapper,
GRUModule,
LMHeadActorValueOperator,
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from .rnn import LSTMModule
from .rnn import GRUModule, LSTMModule
from .sequence import SafeSequential
from .world_models import WorldModelWrapper
Loading