diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py
index 6fc063c2411..fb184291c90 100644
--- a/examples/multiagent/sac.py
+++ b/examples/multiagent/sac.py
@@ -258,7 +258,6 @@ def train(cfg: "DictConfig"):  # noqa: F821
                     loss_vals["loss_actor"]
                     + loss_vals["loss_alpha"]
                     + loss_vals["loss_qvalue"]
-                    + loss_vals["loss_alpha"]
                 )
 
                 loss_value.backward()
diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py
index 4d35b18a360..c5ae154fcfd 100644
--- a/tutorials/sphinx-tutorials/multiagent_ppo.py
+++ b/tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -253,12 +253,11 @@
 #
 #
 
-print("action_spec:", env.action_spec)
-print("reward_spec:", env.reward_spec)
-print("done_spec:", env.done_spec)
+print("action_spec:", env.full_action_spec)
+print("reward_spec:", env.full_reward_spec)
+print("done_spec:", env.full_done_spec)
 print("observation_spec:", env.observation_spec)
 
-
 ######################################################################
 # Using the commands just shown we can access the domain of each value.
 # Doing this we can see that all specs apart from done have a leading shape ``(num_vmas_envs, n_agents)``.
@@ -270,35 +269,20 @@
 # In fact, specs that have the additional agent dimension
 # (i.e., they vary for each agent) will be contained in a inner "agents" key.
 #
-# To access the full structure of the specs we can use
-#
-
-print("full_action_spec:", env.input_spec["full_action_spec"])
-print("full_reward_spec:", env.output_spec["full_reward_spec"])
-print("full_done_spec:", env.output_spec["full_done_spec"])
-
-######################################################################
 # As you can see the reward and action spec present the "agent" key,
 # meaning that entries in tensordicts belonging to those specs will be nested in an "agents" tensordict,
 # grouping all per-agent values.
 #
-# To quickly access the key for each of these values in tensordicts, we can simply ask the environment for the
-# respective key, and
+# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the
+# respective keys, and
 # we will immediately understand which are per-agent and which shared.
 # This info will be useful in order to tell all other TorchRL components where to find each value
 #
 
-print("action_key:", env.action_key)
-print("reward_key:", env.reward_key)
-print("done_key:", env.done_key)
-
-######################################################################
-# To tie it all together, we can see that passing these keys to the full specs gives us the leaf domains
-#
+print("action_keys:", env.action_keys)
+print("reward_keys:", env.reward_keys)
+print("done_keys:", env.done_keys)
 
-assert env.action_spec == env.input_spec["full_action_spec"][env.action_key]
-assert env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]
-assert env.done_spec == env.output_spec["full_done_spec"][env.done_key]
 
 ######################################################################
 # Transforms
@@ -615,6 +599,9 @@
     action=env.action_key,
     sample_log_prob=("agents", "sample_log_prob"),
     value=("agents", "state_value"),
+    # These last 2 keys will be expanded to match the reward shape
+    done=("agents", "done"),
+    terminated=("agents", "terminated"),
 )
 
 
@@ -649,11 +636,18 @@
 episode_reward_mean_list = []
 for tensordict_data in collector:
     tensordict_data.set(
-        ("next", "done"),
+        ("next", "agents", "done"),
         tensordict_data.get(("next", "done"))
         .unsqueeze(-1)
-        .expand(tensordict_data.get(("next", env.reward_key)).shape),
-    )  # We need to expand the done to match the reward shape (this is expected by the value estimator)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    tensordict_data.set(
+        ("next", "agents", "terminated"),
+        tensordict_data.get(("next", "terminated"))
+        .unsqueeze(-1)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    # We need to expand the done and terminated to match the reward shape (this is expected by the value estimator)
 
     with torch.no_grad():
         GAE(
@@ -688,7 +682,7 @@
     collector.update_policy_weights_()
 
     # Logging
-    done = tensordict_data.get(("next", "agents", "done"))
+    done = tensordict_data.get(("next", "agents", "done"))
     episode_reward_mean = (
         tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item()
     )
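Note: the core change in the tutorial hunks above is broadcasting the environment's shared "done"/"terminated" flags to the per-agent reward shape before the GAE value estimator consumes them. The snippet below is a minimal, self-contained sketch of that pattern on a toy TensorDict; the sizes (num_envs, n_steps, n_agents) and the entry layout are illustrative assumptions, not the tutorial's actual VMAS environment.

import torch
from tensordict import TensorDict

# Hypothetical shapes: 2 vectorized envs, 10 collected steps, 3 agents.
num_envs, n_steps, n_agents = 2, 10, 3
td = TensorDict({}, batch_size=[num_envs, n_steps])

# Shared flags written by the env (no agent dimension).
td.set(("next", "done"), torch.zeros(num_envs, n_steps, 1, dtype=torch.bool))
td.set(("next", "terminated"), torch.zeros(num_envs, n_steps, 1, dtype=torch.bool))
# Per-agent reward nested under the "agents" sub-tensordict.
td.set(("next", "agents", "reward"), torch.zeros(num_envs, n_steps, n_agents, 1))

# Broadcast the shared flags to the per-agent reward shape,
# mirroring what the tutorial's collection loop now does before calling GAE.
reward_shape = td.get_item_shape(("next", "agents", "reward"))
for key in ("done", "terminated"):
    td.set(
        ("next", "agents", key),
        td.get(("next", key)).unsqueeze(-1).expand(reward_shape),
    )

print(td.get(("next", "agents", "done")).shape)  # torch.Size([2, 10, 3, 1])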