From 4aae99a732ed46a38a73a64dbeecd6ad50b84490 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Oct 2023 14:14:29 +0100 Subject: [PATCH] init --- torchrl/objectives/value/advantages.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index db056d5ac4d..acd2307a0c3 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -542,7 +542,7 @@ def forward( >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -743,7 +743,7 @@ def forward( >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -955,7 +955,7 @@ def forward( >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -1198,7 +1198,7 @@ def forward( >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: